BZOJ 3160 万径人踪灭 FFT + 回文自动机 + 生成函数
FFT 字符串 生成函数   发布于 2019-09-22   491人围观  0条评论
FFT 字符串 生成函数   发表于 2019-09-22   491人围观  0条评论

题目大意

 给定一个字符集仅含a和b的字符串s,求字符串中不连续的回文子序列个数mod 1000000007的值。

题解

如果是连续的回文子序列,也就是回文串,那就很好求了,直接运用Manacher算法就可以了。

但是这题是让算所有的不连续的回文子序列……

然而这个是没法直接算出来的……

所以我们反过来想,我们可以把所有回文子序列的数量求出来,然后减去回文子串的数量,就是不连续的回文子序列的数量。

上面说了,如果是回文子串的数量,可以直接用Manacher算法算出来。

所以现在我们考虑的就是全体回文子序列的数量该怎么算。

对一个回文子序列,一定有一个中心,假设两个对应字符位置分别为x,y,那么,中心的位置一定是(x+y)/2。

我们可以将那个/2拿掉,这样算出来的新的中心位置也是唯一确定的,也就是和本质的回文中心位置一一对应。

对于一个位置,假设有x对(两个字符在同一位置也算一对)对应字符确定的中心位置在此位置,那么一定有以此位置为回文中心的回文子序列个数为2^x - 1。

如果我们直接暴力求出,那么复杂度是O(n2)的,n是100000的数据是吃不消的。

注意到字符集只含两个字符,再观察对应字符确定的新的中心位置的计算公式(pos = x + y),我们可以对于字符a和字符b分别构造生成函数:

假设我们当前对字符c进行计算,则有生成函数f(x) = Σ((s[i] == c) ? 1 : 0) * xi

可以看出,对字符c,有多少对对应字符的中心位置确定在所有的位置的数量与f(x)对f(x)的卷积在对应位置的系数值是有关系的(其实是一个近似于/2的关系,不过涉及到要讨论位置的奇偶,这里不再赘述)。

卷积可以所以相当于我们对字符a和b分别构造生成函数,分别进行FFT,然后对于每一个位置,将两个FFT算出的结果相加,然后+1,再除以2,即为当前位置为中心位置的对应字符对数(这里+1再/2(整除)是一个巧妙地避开对两个对应字符在相同位置(也就是说这一对字符完全就是一个字符)的讨论的做法,请自行理解)。之后按照上面的公式处理以下即可。

这样我们就求出了所有回文子序列的个数。

至于回文子串的个数,可以直接用Manacher算法来解。

但是……

博主不会Manacher……而且也懒得学……

可是我会PAM(回文自动机)啊!回文自动机可比Manacher好用多了(个人观点不代表官方意见)!

回文自动机里面的cnt[i]表示当前在回文自动机节点i,以当前节点代表的字符Ci结尾的回文串的数量。

于是就可以直接搞了,直接套一个PAM板子就可以了,一点都不麻烦。

至于复杂度?回文自动机的复杂度是O(nlog(字符集大小)),这里字符集大小为2,所以log那一项就是1,所以跟O(n)没有区别……

代码

#include <bits/stdc++.h>
using namespace std;
const int maxn = 150005;
const double pi = acos(-1.0);
typedef long long ll;
const ll mod = 1e9+7;
ll ans = 0;
struct comp{
	double r,i;
	comp(double a = 0,double b = 0):r(a),i(b){}
	comp operator + (comp a){return comp(r+a.r,i+a.i);}
	comp operator - (comp a){return comp(r-a.r,i-a.i);}
	void operator += (comp a){r += a.r;i += a.i;}
	comp operator * (comp a){return comp(r*a.r-i*a.i,r*a.i+i*a.r);}
}a[maxn<<2],b[maxn<<2],w[maxn<<2];
struct PAM{
	int p,last,cur,len[maxn],nt[maxn][2],pos[maxn],fail[maxn],n,S[maxn];
	ll cnt[maxn];
	int newnode(int l){
		nt[p][1] = nt[p][0] = 0;
		len[p] = l;
		cnt[p] = 0;
		return p++;
	}
	inline void init(){
		p = n = last = 0;
		newnode(0);
		newnode(-1);
		S[0] = -1;fail[0] = 1;
		return;
	}
	int get_fail(int x){
		while(S[n-len[x]-1] != S[n])x = fail[x];
		return x;
	}
	inline void add(char c,int post){
		c -= 'a';
		S[++n] = c;
		int cur = get_fail(last);
		if(!nt[cur][c]){
			int now = newnode(len[cur]+2);
			if(len[now] > 1)pos[now] = post;
			fail[now] = nt[get_fail(fail[cur])][c];
			nt[cur][c] = now;
		}
		last = nt[cur][c];
		cnt[last]++;
		return;
	}
	void count(){
		for(int i = p-1;i >= 0;--i){
			cnt[fail[i]] += cnt[i];
			cnt[fail[i]] %= mod;
		}
		return;
	}
}run;
int len,n,lg;
char s[maxn<<2];
int rev[maxn<<2];
void init(int k,int m){
	for(len = 1;len < m + k;len <<= 1,lg++);
	for(int i = 0;i < len;++i)
		rev[i] = (rev[i>>1]>>1)|((i&1)<<(lg-1));
	for(int i = 2;i <= len;++i)
		w[i] = comp(cos(2*pi/i),sin(2*pi/i));
	return;
}
void dft(comp *a,int len,int f){
    for(int i = 0;i < len;++i)
        if(i < rev[i])
            swap(a[i],a[rev[i]]);
    for(int i = 2;i <= len;i <<= 1){
        int now = i >> 1;
        comp wn = w[i];wn.i *= f;
        for(int j = 0;j < len;j += i){
            comp wk = comp(1,0),x,y;
            for(int k = j;k < j + now;k++,wk = wk * wn){
                x = a[k];
                y = a[k + now] * wk;
                a[k] = x + y;a[k + now] = x - y;
            }
        }
    }
    if(f == -1)
        for(int i = 0;i < len;++i)
            a[i].r /= len;
    return;
}
int f[maxn<<2];
ll power[maxn<<2];
int main(){
	scanf("%s",s);
	n = strlen(s);
	ans = 0;
	run.init();
	for(int i = 0;i < n;++i)
		run.add(s[i],i);
	run.count();
	for(int i = 2;i < run.p;++i){
		ans -= run.cnt[i];
		if(ans < 0)ans += mod;
	}
	init(n,n);
	for(int i = 0;i < len;++i){
		if(s[i] == 'a' && i < n)
			a[i] = comp(1,0);
		else a[i] = comp(0,0);
	}
	dft(a,len,1);
	for(int i = 0;i < len;++i){
		a[i] = a[i] * a[i];
	}
	dft(a,len,-1);

	for(int i = 0;i < n;++i){
		if(s[i] == 'b' && i < n)
			b[i] = comp(1,0);
		else b[i] = comp(0,0);
	}
	dft(b,len,1);
	for(int i = 0;i < len;++i){
		b[i] = b[i] * b[i];
	}
	dft(b,len,-1);

	memset(f,0,sizeof(f));
	for(int i = 0;i < len;++i){
		f[i] = (int)(b[i].r + 0.5) + (int)(a[i].r + 0.5);
	}
	power[0] = 1;
	for(int i = 1;i < maxn*4;++i){
		power[i] = (power[i-1]*2) % mod;
	}
	for(int i = 0;i < len;++i){
		ans += power[f[i]+1>>1]-1;
		ans %= mod;
	}
	printf("%lld\n",(ans%mod + mod) % mod);
	return 0;
}

上一篇: luogu P3321 BZOJ 3992 [SDOI2015]序列统计 NTT + 生成函数 + 快速幂

下一篇: The Preliminary Contest for ICPC Asia Shanghai 2019 C.Triple FFT + 生成函数

立即登录,发表评论
没有帐号?立即注册
0 条评论