The Preliminary Contest for ICPC Asia Shanghai 2019 C.Triple FFT + 生成函数
FFT 生成函数   发布于 2019-09-21   376人围观  0条评论
FFT 生成函数   发表于 2019-09-21   376人围观  0条评论

题目大意

给定一个正整数n,有三个数组A,B,C,每个数组有n个数,求问有多少个三元组(i,j,k)满足|Ai-Bj|≤Ck,|Ck-Bj|≤Ai,|Ai-Ck|≤Bj

题解

首先我们发现,其实三元组满足的条件其实非常像三角形的构成条件,所以我们可以叙述成

“最小值 + 中间值 ≤ 最大值”

满足这样的三元组个数。

官方题解是正着算的,然后还有什么去重什么麻烦的东西,看起来令人头大……

所以我们为何不反过来算呢?

首先所有的取法总数是n3,然后我们只需要计算以下不满足的三元组总数即可了。

但是怎么做?

大概的思路就是对于某一个数组中的所有数,我们对每一个进行枚举,把当前枚举的数作为三元组的最大值,看剩下两个数组中有多少个二元组满足两个数加起来小于当前这个数。

具体的思路就是这样,但是暴力算时间复杂度会崩,是O(n3)的。

我们可以把所有二元组的和都预处理出来,用桶计数,表示当前二元组的和为i的组数有多少个。然后我们对要枚举的数列先从小到大排序,再进行枚举,用前缀和的思想找上述的不满足条件的二元组的个数即可,这样的做法是O(n2)的。

但是还不够,因为n最大是1e5。

这时候我们就要应用生成函数的思想,对于每一个数y,把y放在x的指数上,y在这个数组中出现的次数作为这一项的系数。这样任意两项乘在一起就相当于一个加和。

我们要算所有二元组的加和,用上面的生成函数的思想,就可以把暴力转化成为一个普通卷积,而卷积可以用FFT来加速。于是就可以优化为O(nlogn)。

整体需要3次DFT和3次IDFT。

然后按照上述方法前缀和枚举即可。

but……因为FFT是用桶的思路,所以每一次FFT都是用最大值代入的时间复杂度中的n,而且FFT常数巨大,这就导致我们直接FFT会超时。

我们关注到题目中N > 1000的数据组最多只有20组,于是我们可以分数据范围,n<1000 的直接n2暴力,n > 1000的使用FFT。这样就不会超时了。

代码

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
typedef long long ll;
const double  pi = acos(-1.0);
struct comp{
	double r,i;
	comp(double a = 0,double b = 0):r(a),i(b){};
	comp operator + (comp a){return comp(r+a.r,i+a.i);}
	comp operator - (comp a){return comp(r-a.r,i-a.i);}
	comp operator * (comp a){return comp(r*a.r-i*a.i,r*a.i+i*a.r);}
}a[maxn<<2],b[maxn<<2],c[maxn<<2],w[maxn<<2],reta[maxn<<2],retb[maxn<<2],retc[maxn<<2];
int rev[maxn<<2];
int A[maxn],B[maxn],C[maxn],cnta[maxn],cntb[maxn],cntc[maxn];
int len,lg;
int t,cas,n;
ll ta,tb,tc;
ll ans;
void init(int n,int m){
	lg = 0;
	for(len = 1;len < n + m;len <<= 1,++lg);
	for(int i = 0;i < len;++i)
		rev[i] = (rev[i>>1]>>1) | ((i&1) << (lg-1));
	for(int i = 2;i <= len;++i)
		w[i] = comp(cos(2*pi/i),sin(2*pi/i));
	return;
}
void dft(comp *a,int len,int f){
	for(int i = 0;i < len;++i)
		if(i < rev[i])
			swap(a[i],a[rev[i]]);
	for(int l = 2;l <= len;l <<= 1){
		int now = l >> 1;
		comp wn = w[l];
		wn.i *= f;
		for(int j = 0;j < len;j += l){
			comp wk = comp(1,0),x,y;
			for(int k =j;k < j + now;k++,wk = wk * wn){
				x = a[k];
				y = wk * a[k + now];
				a[k] = x + y;
				a[k+now] = x - y;
			}
		}
	}
	if(f == -1)
		for(int i = 0;i < len;++i)
			a[i].r /= len;
	return;
}
inline int max(int a,int b){
	return (a > b) ? a : b;
}
int x;
void work1(int n,int cas){
	int mx = 0;
	for(int i = 1;i <= n;++i){
		cin >> x;A[x]++;
		a[x].r += 1.0;
		mx = max(mx,x);
	}
	for(int i = 1;i <= n;++i){
		cin >> x;B[x]++;
		b[x].r += 1.0;
		mx = max(mx,x);
	}
	for(int i = 1;i <= n;++i){
		cin >> x;C[x]++;
		c[x].r += 1.0;
		mx = max(mx,x);
	}
	mx += 1;
	init(mx,mx);
	dft(a,len,1);
	dft(b,len,1);
	dft(c,len,1);
	ans = (ll)n * (ll)n * (ll)n;
	for(int i = 0;i < len;++i){
		reta[i] = b[i] * c[i];
		retb[i] = a[i] * c[i];
		retc[i] = a[i] * b[i];
	}
	dft(reta,len,-1);dft(retb,len,-1);dft(retc,len,-1);
	ta = tb = tc = 0;
	for(int i = 0;i <= mx;++i){
		if(A[i] != 0)ans -= A[i] * ta;
		if(B[i] != 0)ans -= B[i] * tb;
		if(C[i] != 0)ans -= C[i] * tc;
		ta += (ll)(reta[i].r + 0.5);
		tb += (ll)(retb[i].r + 0.5);
		tc += (ll)(retc[i].r + 0.5);
	}
	cout << "Case #" << cas << ": " << ans << "\n";
	for(int i = 0;i < len;++i){
		a[i] = b[i] = c[i] = comp(0,0);
		reta[i] = retb[i] = retc[i] = comp(0,0);
	}
	for(int i = 0;i <= mx;++i){
		A[i] = B[i] = C[i] = 0;
	}
}
void work2(int n,int cas){
	int mx = 0;
	for(int i = 1;i <= n;++i)
		cin >> A[i],mx = max(mx,A[i]),cnta[A[i]]++;
	for(int i = 1;i <= n;++i)
		cin >> B[i],mx = max(mx,B[i]),cntb[B[i]]++;
	for(int i = 1;i <= n;++i)
		cin >> C[i],mx = max(mx,C[i]),cntc[C[i]]++;
	for(int i = 0;i < mx;++i){
		reta[i] = retb[i] = retc[i] = comp(0,0);
	}
	for(int i = 1;i <= n;++i)
		for(int j = 1;j <= n;++j){
			reta[B[i]+C[j]].r += 1;
			retb[A[i] + C[j]].r += 1;
			retc[A[i] + B[j]].r += 1;
		}
	ll ans = (ll)n * n * n;
	ta = tb = tc = 0;
	for(int i = 0;i <= mx;++i){
		if(cnta[i] != 0)ans -= cnta[i] * ta;
		if(cntb[i] != 0)ans -= cntb[i] * tb;
		if(cntc[i] != 0)ans -= cntc[i] * tc;
		ta += (ll)(reta[i].r + 0.5);
		tb += (ll)(retb[i].r + 0.5);
		tc += (ll)(retc[i].r + 0.5);
	}
	cout << "Case #" << cas << ": " << ans << "\n";
	for(int i = 0;i <= max(n,mx);++i){
		A[i] = B[i] = C[i] = 0;
		cnta[i] = cntb[i] = cntc[i] = 0;
		reta[i] = retb[i] = retc[i] = comp(0,0);
	}
	return;
}
int main(){
	ios_base::sync_with_stdio(0);
	cin.tie(0);
	cin >> t;
	while(t--){
		cas++;
		cin >> n;
		if(n > 1000)work1(n,cas);
		else work2(n,cas);
	}
	return 0;
}

 

上一篇: BZOJ 3160 万径人踪灭 FFT + 回文自动机 + 生成函数

下一篇: HDU 6209 The Intersection 分数二分

立即登录,发表评论
没有帐号?立即注册
0 条评论