# BZOJ 3160 万径人踪灭 FFT + 回文自动机 + 生成函数

## 题目大意

给定一个仅由字符 a 和 b 组成的字符串 s，统计满足下列条件的子序列个数，答案对 1000000007 取模：所选位置关于某一中心对称，且每对对称位置上的字符相同（即位置与字符都对称的回文子序列）；同时所选位置不能构成一段连续区间（即不计入回文子串）。

## 代码

```cpp
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;                 // 64-bit integer type used for counts and modular sums

const int maxn = 150005;              // base capacity: PAM arrays use maxn, FFT arrays use maxn<<2
const ll mod = 1000000007LL;          // modulus of the final answer
const double pi = std::acos(-1.0);    // pi, used by init() to build FFT roots of unity

ll ans = 0;                           // answer accumulated by main()
// Minimal complex number for the FFT: r is the real part, i the imaginary part.
// Binary operators are const so they work on const operands; += returns *this
// like the built-in compound assignments.
struct comp{
    double r,i;
    comp(double re = 0,double im = 0):r(re),i(im){}
    comp operator + (const comp &o) const {return comp(r+o.r,i+o.i);}
    comp operator - (const comp &o) const {return comp(r-o.r,i-o.i);}
    comp &operator += (const comp &o){r += o.r;i += o.i;return *this;}
    comp operator * (const comp &o) const {return comp(r*o.r-i*o.i,r*o.i+i*o.r);}
}a[maxn<<2],b[maxn<<2],w[maxn<<2];   // a, b: FFT work buffers used by main(); w: root table filled by init()
struct PAM{
int p,last,cur,len[maxn],nt[maxn][2],pos[maxn],fail[maxn],n,S[maxn];
ll cnt[maxn];
int newnode(int l){
nt[p][1] = nt[p][0] = 0;
len[p] = l;
cnt[p] = 0;
return p++;
}
inline void init(){
p = n = last = 0;
newnode(0);
newnode(-1);
S[0] = -1;fail[0] = 1;
return;
}
int get_fail(int x){
while(S[n-len[x]-1] != S[n])x = fail[x];
return x;
}
c -= 'a';
S[++n] = c;
int cur = get_fail(last);
if(!nt[cur][c]){
int now = newnode(len[cur]+2);
if(len[now] > 1)pos[now] = post;
fail[now] = nt[get_fail(fail[cur])][c];
nt[cur][c] = now;
}
last = nt[cur][c];
cnt[last]++;
return;
}
void count(){
for(int i = p-1;i >= 0;--i){
cnt[fail[i]] += cnt[i];
cnt[fail[i]] %= mod;
}
return;
}
}run;
// Shared state: len is the transform length and lg = log2(len) (both set by
// init()); n is the input length (set by main()).
int len;
int n;
int lg;
char s[maxn<<2];    // input string, read by main()
int rev[maxn<<2];   // bit-reversal permutation for the current len, built by init()
// Prepare the FFT tables for a transform holding k + m coefficients. Sets
// len to the smallest power of two >= k + m (1 if k + m <= 1), lg = log2(len),
// rev[0..len) to the bit-reversal permutation, and w[2], w[4], ..., w[len]
// to the roots exp(2*pi*I / i). Updates the globals len, lg, rev and w.
void init(int k,int m){
    lg = 0;   // reset so init() can safely be called more than once
    for(len = 1;len < m + k;len <<= 1,lg++);
    rev[0] = 0;
    // Start at 1: for len == 1 (lg == 0) the shift below would be by -1 (UB).
    for(int i = 1;i < len;++i)
        rev[i] = (rev[i>>1]>>1)|((i&1)<<(lg-1));
    // dft() only reads w at power-of-two indices.
    for(int i = 2;i <= len;i <<= 1)
        w[i] = comp(cos(2*pi/i),sin(2*pi/i));
    return;
}
// In-place iterative radix-2 FFT of a[0..len); len must equal the length the
// rev/w tables were built for by init(). f = 1 uses the roots w[]; f = -1 uses
// their conjugates and divides the result by len, i.e. the inverse transform.
void dft(comp *a,int len,int f){
    for(int i = 0;i < len;++i)
        if(i < rev[i])
            swap(a[i],a[rev[i]]);
    for(int i = 2;i <= len;i <<= 1){
        int half = i >> 1;
        comp wn = w[i];
        wn.i *= f;   // conjugate root for the inverse transform
        for(int j = 0;j < len;j += i){
            comp wk(1,0);
            for(int k = j;k < j + half;k++,wk = wk * wn){
                comp x = a[k];
                comp y = a[k + half] * wk;
                a[k] = x + y;
                a[k + half] = x - y;
            }
        }
    }
    if(f == -1)
        for(int i = 0;i < len;++i){
            // Normalise both parts so the output is a true inverse, not only its real part.
            a[i].r /= len;
            a[i].i /= len;
        }
    return;
}
// f[c]: number of ordered index pairs (j, k) with j + k == c and s[j] == s[k],
// filled by main() from the FFT results.
int f[maxn<<2];

// power[e] = 2^e mod `mod`, filled by main().
ll power[maxn<<2];
// Reads s (letters 'a'/'b' only) and prints, modulo `mod`, the number of
// subsequences whose positions are mirror-symmetric about some centre, whose
// mirrored characters are equal, and which are not contiguous substrings:
//   answer = sum over centres c of (2^((f[c] + 1) / 2) - 1) - #palindromic substrings,
// where f[c] counts ordered pairs (j, k) with j + k == c and s[j] == s[k].
int main(){
    if(scanf("%s",s) != 1)return 0;
    n = strlen(s);
    ans = 0;

    // Palindromic substrings: insert every character into the palindromic
    // tree, then propagate occurrence counts exactly once. Each node v >= 2 is
    // a distinct palindrome occurring cnt[v] times; all of those occurrences
    // are contiguous and must be excluded from the answer.
    run.init();
    for(int i = 0;i < n;++i)
        run.insert(s[i],i);
    run.count();
    for(int i = 2;i < run.p;++i){
        ans -= run.cnt[i];
        if(ans < 0)ans += mod;
    }

    // f[c] via FFT: self-convolve the indicator sequences of 'a' and of 'b'.
    init(n,n);
    for(int i = 0;i < len;++i){
        // Test i < n first so s is never read past the input.
        a[i] = (i < n && s[i] == 'a') ? comp(1,0) : comp(0,0);
        b[i] = (i < n && s[i] == 'b') ? comp(1,0) : comp(0,0);
    }
    dft(a,len,1);
    dft(b,len,1);
    for(int i = 0;i < len;++i){
        a[i] = a[i] * a[i];
        b[i] = b[i] * b[i];
    }
    dft(a,len,-1);
    dft(b,len,-1);
    for(int i = 0;i < len;++i)
        f[i] = (int)(a[i].r + 0.5) + (int)(b[i].r + 0.5);   // round away FFT noise

    // f[c] <= n, so the largest exponent needed is (n + 1) / 2 <= n.
    power[0] = 1;
    for(int i = 1;i <= n;++i)
        power[i] = power[i-1] * 2 % mod;

    // Centre c has (f[c] + 1) / 2 mirrored position pairs: the pair j == k is
    // counted once in f[c] and every other pair twice. Any non-empty subset
    // of those pairs is a symmetric palindromic subsequence.
    for(int i = 0;i < len;++i){
        ans += power[(f[i] + 1) >> 1] - 1;
        ans %= mod;
    }
    printf("%lld\n",(ans % mod + mod) % mod);
    return 0;
}
```