Problem description.
You are given a tree. If two distinct nodes are selected uniformly at random, what is the probability that the distance between them is a prime number?
Input
The first line contains a number N: the number of nodes in this tree.
Each of the following N-1 lines contains a pair a[i], b[i], meaning that there is an edge of length 1 between nodes a[i] and b[i].
Output
Output a single real number: the desired probability.
Your answer is accepted if its absolute difference from the reference answer is at most 10^-6.
Constraints
2 ≤ N ≤ 50,000
The input is guaranteed to be a tree.
Example
Input:
5
1 2
2 3
3 4
4 5
Output:
0.5
Explanation
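There are C(5, 2) = 10 possible pairs, and the following 5 of them are at a prime distance:
1-3, 2-4, 3-5: distance 2
1-4, 2-5: distance 3
Note that 1 is not a prime number, so the 4 adjacent pairs (distance 1) do not count; the answer is 5/10 = 0.5.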
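Approach: centroid decomposition. At each centroid, count the paths that pass through it.
Joining paths from different subtrees is a convolution of their depth counts: if A[i] and B[j] are the numbers of nodes at depth i and j in two different subtrees, the number of cross pairs at distance k is sum_i A[i] * B[k-i], so it can be computed with FFT.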
// Prime-distance pairs in a tree.
//
// Reads a tree with N nodes from stdin (first N, then N-1 edges "a b" with
// 1-based endpoints) and prints the probability that two distinct nodes chosen
// uniformly at random are at a prime distance.
//
// Method: centroid decomposition + FFT. For a centroid c of the current
// component, let total[d] be the number of component nodes at distance d from
// c (total[0] = 1 for c itself), and cnt_S[d] the same count restricted to the
// child subtree S. Then, for k >= 1,
//     (total * total)[k] - sum_S (cnt_S * cnt_S)[k]
// is the number of ORDERED pairs (x, y), x != y, whose path goes through c and
// has length k. Each FFT is sized by the depth of the component or subtree it
// covers, so a centroid costs O(size log size) and the whole run
// O(N log^2 N). All traversals are iterative, so a path-shaped tree cannot
// exhaust the call stack.
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <complex>
#include <cstring>
#include <cmath>
#include <cstddef>
#include <vector>

namespace {

using Complex = std::complex<double>;

const double kPi = std::acos(-1.0);

// In-place iterative radix-2 Cooley-Tukey FFT.
// a.size() must be a power of two (1 is allowed). When `invert` is true the
// inverse transform is computed, including the 1/size normalisation.
void fft(std::vector<Complex>& a, bool invert) {
    const std::size_t n = a.size();

    // Reorder the input into bit-reversed index order.
    for (std::size_t i = 1, j = 0; i < n; ++i) {
        std::size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }

    // Twiddle factors are evaluated directly rather than by repeated
    // multiplication, which would accumulate rounding error along long rows.
    std::vector<Complex> roots(n / 2);
    for (std::size_t k = 0; k < roots.size(); ++k) {
        const double angle = 2.0 * kPi * static_cast<double>(k) / static_cast<double>(n);
        roots[k] = Complex(std::cos(angle), invert ? -std::sin(angle) : std::sin(angle));
    }

    for (std::size_t len = 2; len <= n; len <<= 1) {
        const std::size_t half = len / 2;
        const std::size_t stride = n / len;  // roots[k * stride] = exp(+-2*pi*i*k/len)
        for (std::size_t start = 0; start < n; start += len) {
            for (std::size_t k = 0; k < half; ++k) {
                const Complex x = a[start + k];
                const Complex y = a[start + k + half] * roots[k * stride];
                a[start + k] = x + y;
                a[start + k + half] = x - y;
            }
        }
    }

    if (invert) {
        for (Complex& value : a) value /= static_cast<double>(n);
    }
}

// Returns c with c[k] = sum over i + j = k of a[i] * a[j], for k in
// [0, 2*|a| - 2] (empty input gives an empty result). Results are rounded to
// the nearest integer; this relies on the double-precision FFT error staying
// below 0.5, which holds for non-negative counts of the size used here
// (N <= 50,000).
std::vector<long long> self_convolve(const std::vector<long long>& a) {
    if (a.empty()) return {};
    const std::size_t result_size = 2 * a.size() - 1;
    std::size_t fft_size = 1;
    while (fft_size < result_size) fft_size <<= 1;

    std::vector<Complex> f(fft_size);
    for (std::size_t i = 0; i < a.size(); ++i) f[i] = Complex(static_cast<double>(a[i]), 0.0);
    fft(f, false);
    for (Complex& value : f) value *= value;
    fft(f, true);

    std::vector<long long> result(result_size);
    for (std::size_t i = 0; i < result_size; ++i) result[i] = std::llround(f[i].real());
    return result;
}

// dst[i] += src[i] for every index of src, growing dst with zeros if needed.
void add_into(std::vector<long long>& dst, const std::vector<long long>& src) {
    if (dst.size() < src.size()) dst.resize(src.size(), 0);
    for (std::size_t i = 0; i < src.size(); ++i) dst[i] += src[i];
}

// Counts pairs of tree nodes at prime distance via centroid decomposition.
class PrimeDistanceCounter {
public:
    // Prepares an edgeless graph on nodes 1..n (n >= 1) and sieves primes up
    // to n, which bounds every distance in an n-node tree.
    explicit PrimeDistanceCounter(int n)
        : adj_(static_cast<std::size_t>(n) + 1),
          removed_(static_cast<std::size_t>(n) + 1, 0),
          parent_(static_cast<std::size_t>(n) + 1, 0),
          dist_(static_cast<std::size_t>(n) + 1, 0),
          size_(static_cast<std::size_t>(n) + 1, 0),
          heavy_(static_cast<std::size_t>(n) + 1, 0),
          composite_(sieve(n)) {}

    // Adds the undirected edge a - b; both endpoints must lie in [1, n].
    void add_edge(int a, int b) {
        adj_[a].push_back(b);
        adj_[b].push_back(a);
    }

    // Returns the number of unordered pairs of distinct nodes at prime
    // distance. Assumes the added edges form a tree that contains node 1.
    long long count_prime_pairs() {
        std::fill(removed_.begin(), removed_.end(), 0);
        long long pairs = 0;
        std::vector<int> pending{1};  // one node from every component still to decompose
        while (!pending.empty()) {
            const int root = pending.back();
            pending.pop_back();
            const int centroid = find_centroid(root);
            removed_[centroid] = 1;
            pairs += count_through(centroid);
            for (int next : adj_[centroid]) {
                if (!removed_[next]) pending.push_back(next);
            }
        }
        return pairs;
    }

private:
    // Sieve of Eratosthenes: composite[d] != 0 iff d (0 <= d <= limit) is not
    // prime. 0 and 1 are explicitly marked non-prime.
    static std::vector<char> sieve(int limit) {
        const std::size_t last = static_cast<std::size_t>(limit);
        std::vector<char> composite(last + 1, 0);
        composite[0] = 1;
        if (last >= 1) composite[1] = 1;
        for (std::size_t i = 2; i * i <= last; ++i) {
            if (composite[i]) continue;
            for (std::size_t j = i * i; j <= last; j += i) composite[j] = 1;
        }
        return composite;
    }

    // Breadth-first traversal of the live (non-removed) component containing
    // `start`. Fills order_ with its nodes in BFS order (distances are
    // non-decreasing along it) and sets parent_ / dist_ for each visited node,
    // with dist_[start] == 0.
    void bfs(int start) {
        order_.clear();
        order_.push_back(start);
        parent_[start] = 0;  // nodes are 1-based, so 0 means "no parent"
        dist_[start] = 0;
        for (std::size_t head = 0; head < order_.size(); ++head) {
            const int u = order_[head];
            for (int v : adj_[u]) {
                if (v == parent_[u] || removed_[v]) continue;
                parent_[v] = u;
                dist_[v] = dist_[u] + 1;
                order_.push_back(v);
            }
        }
    }

    // Returns a centroid of the live component containing `root`: a node whose
    // removal leaves no piece with more than half of the component's nodes.
    int find_centroid(int root) {
        bfs(root);
        const int total = static_cast<int>(order_.size());
        for (int u : order_) {
            size_[u] = 1;
            heavy_[u] = 0;
        }
        // Reverse BFS order finishes every child before its parent.
        for (auto it = order_.rbegin(); it != order_.rend(); ++it) {
            const int p = parent_[*it];
            if (p == 0) continue;
            size_[p] += size_[*it];
            heavy_[p] = std::max(heavy_[p], size_[*it]);
        }
        for (int u : order_) {
            if (2 * std::max(heavy_[u], total - size_[u]) <= total) return u;
        }
        return root;  // not reached for a tree: a centroid always exists
    }

    // Counts unordered prime-distance pairs whose path passes through
    // `centroid`, restricted to its live component. The centroid must already
    // be marked removed so that each BFS stays inside one child subtree.
    long long count_through(int centroid) {
        std::vector<long long> total{1};      // total[d]: nodes at distance d (centroid at 0)
        std::vector<long long> same_subtree;  // ordered pairs inside one child subtree, by depth sum
        for (int child : adj_[centroid]) {
            if (removed_[child]) continue;
            bfs(child);
            // cnt[d]: nodes of this child's subtree at distance d from the centroid.
            std::vector<long long> cnt(static_cast<std::size_t>(dist_[order_.back()]) + 2, 0);
            for (int u : order_) ++cnt[static_cast<std::size_t>(dist_[u]) + 1];
            add_into(total, cnt);
            add_into(same_subtree, self_convolve(cnt));
        }

        // For k >= 1, all[k] - same_subtree[k] counts ordered pairs (x != y)
        // at distance k through the centroid; each unordered pair appears twice.
        // Distances above n never occur, so indices past the sieve are skipped.
        const std::vector<long long> all = self_convolve(total);
        const std::size_t limit = std::min(all.size(), composite_.size());
        long long pairs = 0;
        for (std::size_t k = 2; k < limit; ++k) {
            if (composite_[k]) continue;
            const long long same = k < same_subtree.size() ? same_subtree[k] : 0;
            pairs += (all[k] - same) / 2;
        }
        return pairs;
    }

    std::vector<std::vector<int>> adj_;
    std::vector<char> removed_;  // node already used as a centroid
    std::vector<int> parent_;    // BFS scratch: parent within the current traversal
    std::vector<int> dist_;      // BFS scratch: distance from the traversal start
    std::vector<int> size_;      // centroid search scratch: subtree sizes
    std::vector<int> heavy_;     // centroid search scratch: largest child subtree
    std::vector<int> order_;     // BFS scratch: visit order
    std::vector<char> composite_;
};

}  // namespace

int main() {
    int n = 0;
    if (std::scanf("%d", &n) != 1) return 1;
    if (n < 2) {  // no pair of distinct nodes exists; avoid dividing by zero
        std::printf("%.10f\n", 0.0);
        return 0;
    }

    PrimeDistanceCounter counter(n);
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d%d", &a, &b) != 2 || a < 1 || a > n || b < 1 || b > n) {
            std::fprintf(stderr, "invalid edge on input line %d\n", i + 1);
            return 1;
        }
        counter.add_edge(a, b);
    }

    const double pair_count = static_cast<double>(n) * static_cast<double>(n - 1) / 2.0;
    std::printf("%.10f\n", static_cast<double>(counter.count_prime_pairs()) / pair_count);
    return 0;
}
没有帐号? 立即注册