Problem description.
You are given a tree. If we select 2 distinct nodes uniformly at random, what's the probability that the distance between these 2 nodes is a prime number?
Input
The first line contains a number N: the number of nodes in this tree.
The following N-1 lines contain pairs a[i] and b[i], which means there is an edge with length 1 between a[i] and b[i].
Output
Output a real number denoting the required probability.
Your answer is accepted if it differs from the reference answer by no more than 10^-6.
Constraints
2 ≤ N ≤ 50,000
The input is guaranteed to be a tree.
Example
Input:
5
1 2
2 3
3 4
4 5
Output:
0.5
Explanation
There are C(5, 2) = 10 ways to choose the pair, and 5 of them have a prime distance:
1-3, 2-4, 3-5: distance 2
1-4, 2-5: distance 3
Note that 1 is not a prime number.
点分治,考虑每一层怎么统计答案。
因为子树中的路径拼起来是一个卷积形式,可以用FFT。
#include<algorithm>
#include<cmath>
#include<complex>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<vector>
#define LL long long
using namespace std;
// Capacity used to size the per-node arrays (adjacency heads, subtree sizes,
// sieve tables) and, doubled, the half-edge pool and FFT work buffers.
const int maxn = 100000 + 5;
// The constant pi; the FFT derives its roots of unity from it.
const double Pi = acos(-1.0);
// Minimal complex number used by the FFT.
// Only addition, subtraction and multiplication are provided, which is all
// the butterfly and the pointwise product need.
struct E {
    // Both parts default to zero so a default-constructed value is never
    // indeterminate.
    double real = 0.0;
    double imag = 0.0;

    E() = default;
    E(double r, double i) : real(r), imag(i) {}

    // Component-wise sum.
    inline E operator + (const E &a) const {
        return E(real + a.real, imag + a.imag);
    }
    // Component-wise difference.
    inline E operator - (const E &a) const {
        return E(real - a.real, imag - a.imag);
    }
    // Complex product: (x + yi)(c + di) = (xc - yd) + (xd + yc)i.
    inline E operator * (const E &a) const {
        return E(a.real * real - a.imag * imag, a.real * imag + a.imag * real);
    }
};
// One directed half of an undirected tree edge, stored as a node of a
// singly linked adjacency list.
struct edge {
    int v;     // node this half-edge leads to
    int next;  // index in e[] of the next half-edge leaving the same node; 0 ends the list
};
// Pool of half-edges. Slots are handed out starting at index 1, so index 0
// can serve as the list terminator.
edge e[maxn << 1];
int head[maxn];  // head[x]: index in e[] of the most recently added half-edge leaving x (0 if none)
int cnt;         // number of half-edges allocated so far
int n;           // number of nodes, read from the input
int u, v;        // scratch endpoints used while reading the edge list
// Prepends the directed half-edge u -> v to u's adjacency list.
// An undirected edge is represented by calling this once in each direction.
void adde(const int &u, const int &v) {
    ++cnt;
    // Standard aggregate initialisation; the previous `(edge){...}` form is a
    // C99 compound literal that C++ compilers accept only as an extension.
    e[cnt] = edge{v, head[u]};
    head[u] = cnt;
}
LL ans;         // running count of node pairs whose distance is accepted by the sieve
int siz[maxn];  // siz[x]: subtree size of x, recomputed by FG and dfs for the current component
int h[maxn];    // h[x]: height of x's subtree measured in nodes, filled by dfs
int vis[maxn];  // vis[x] == 1 once x has been chosen as a centroid and removed
int G;          // receives the centroid located by FG
// Locates a centroid of the still-active component that contains u.
//   u   - node currently being visited
//   p   - node we arrived from (not revisited)
//   tot - number of nodes in the component
//   rt  - set to a node whose removal leaves pieces of size <= tot / 2;
//         if two nodes qualify, the one finished last wins
// Also fills siz[] for the traversal rooted at the initial u.
void FG(int u, int p, int tot, int & rt) {
    int mxs = 0;  // largest piece that would remain below u
    siz[u] = 1;
    // Plain `int` index: the `register` specifier was removed in C++17.
    for (int i = head[u]; i; i = e[i].next) {
        const int v = e[i].v;
        if (vis[v] || v == p) continue;
        FG(v, u, tot, rt);
        siz[u] += siz[v];
        mxs = std::max(mxs, siz[v]);
    }
    // The piece on the parent side holds everything outside u's subtree.
    mxs = std::max(tot - siz[u], mxs);
    if (mxs <= tot / 2) rt = u;
}
int pri[maxn];  // pri[0]: number of primes found so far; pri[1..pri[0]]: the primes in increasing order
int ban[maxn];  // ban[x] != 0 marks x as not prime
// Linear sieve: fills pri[] with the primes up to n and sets ban[x] = 1 for
// every non-prime x in [0, n].
void GPri(int n) {
    // Neither 0 nor 1 is prime. ban[0] was previously left unset, so any
    // consumer testing !ban[0] would have treated distance 0 as prime.
    ban[0] = 1;
    ban[1] = 1;
    for (int i = 2; i <= n; ++i) {
        if (!ban[i]) pri[++pri[0]] = i;
        // Widen the product so the bound check cannot overflow int.
        for (int j = 1; j <= pri[0] && 1LL * i * pri[j] <= n; ++j) {
            ban[i * pri[j]] = 1;
            // Stop once pri[j] divides i, so each composite is crossed out
            // exactly once, by its smallest prime factor.
            if (i % pri[j] == 0) break;
        }
    }
}
// FFT-based convolution helpers used to combine depth histograms and to
// accumulate per-distance pair counts.
namespace POLY {
    E tmp[maxn << 1];   // pointwise product, then the inverse-transform result
    E ta[maxn << 1];    // working copy of the first operand, so the caller's array is not transformed
    int N;              // current transform length (a power of two), set by init
    int t;              // N / 2: value of the top bit when building the bit-reversal table
    int RL[maxn << 1];  // RL[i]: bit-reversed index of i for the current N
    LL as[maxn << 1];   // as[d]: number of pairs at distance d gathered since the last Gans()

    // Adds as[i] to the global answer for every index i not flagged in ban[],
    // then clears as[0..N] ready for the next round.
    void Gans() {
        for (int i = 0; i <= N; ++i) {
            if (!ban[i]) ans += as[i];
            as[i] = 0;
        }
    }

    // Chooses the smallest power of two N >= l and rebuilds the bit-reversal
    // table for that length.
    void init(int l) {
        for (N = 1; N < l; N <<= 1);
        t = N >> 1;
        for (int i = 0; i <= N; ++i)
            RL[i] = (RL[i >> 1] >> 1) | ((i & 1) * t);
    }

    // In-place iterative radix-2 FFT of a[0..N-1].
    // type = 1 is the forward transform, type = -1 the inverse; the inverse
    // is left unscaled (the caller divides by N).
    void FFT(E * a, int type) {
        // Bit-reversal permutation.
        for (int i = 0; i < N; ++i)
            if (i < RL[i]) std::swap(a[i], a[RL[i]]);
        // i is the half-length of the butterflies at this stage.
        for (int i = 1; i < N; i <<= 1) {
            E Wn(std::cos(Pi / i), std::sin(type * Pi / i));
            for (int p = i << 1, j = 0; j < N; j += p) {
                E w(1, 0);
                for (int k = 0; k < i; ++k, w = w * Wn) {
                    E x = a[j + k], y = a[i + j + k] * w;
                    a[j + k] = x + y;
                    a[i + j + k] = x - y;
                }
            }
        }
    }

    // Convolves a with b and adds the rounded coefficients into as[].
    //   a - first operand; read only (copied into ta)
    //   b - second operand; transformed in place and zeroed before returning
    //   l - exclusive upper bound on the indices holding non-zero entries in
    //       a and b; the transform length is chosen from 2 * l
    void solve(E * a, E * b, int l) {
        init(l << 1);
        memset(ta, 0, sizeof(E) * (N + 3));
        memcpy(ta, a, sizeof(E) * (N + 2));
        FFT(ta, 1);
        FFT(b, 1);
        for (int i = 0; i < N; ++i)
            tmp[i] = ta[i] * b[i];
        FFT(tmp, -1);
        for (int i = 0; i < N; ++i) {
            tmp[i].real /= N;                    // scale the unnormalised inverse transform
            as[i] += (LL)(tmp[i].real + 0.5);    // round to the nearest integer count
        }
        memset(b, 0, sizeof(E) * (N + 3));
    }
}
E a[maxn];  // depth histogram of the child subtrees already merged at the current centroid
E b[maxn];  // depth histogram of the child subtree being merged; POLY::solve zeroes it again
// Walks the active subtree below u and records every node's depth.
//   b  - histogram to update: b[d].real grows by dt for each node at depth d
//   u  - current node; p - node we arrived from (not revisited)
//   d  - depth of u measured from the current centroid
//   dt - weight added to the histogram per node
//   s  - amount added to POLY::as[d] per node (1 to count the pair formed
//        with the centroid itself, 0 to skip it)
// Returns the number of nodes visited. The function was declared `int` but
// previously fell off the end without returning, which is undefined behaviour.
int add(E * b, int u, int p, int d, int dt, int s) {
    b[d].real += dt;
    POLY::as[d] += s;
    int visited = 1;
    for (int i = head[u]; i; i = e[i].next) {
        const int v = e[i].v;
        if (vis[v] || v == p) continue;
        visited += add(b, v, u, d + 1, dt, s);
    }
    return visited;
}
// Computes siz[] (subtree size) and h[] (subtree height counted in nodes,
// so a leaf has height 1) for the active component, rooted at the first u.
//   u - current node; p - node we arrived from (not revisited)
void dfs(int u, int p) {
    h[u] = 1;
    siz[u] = 1;
    // Plain `int` index: the `register` specifier was removed in C++17.
    for (int i = head[u]; i; i = e[i].next) {
        const int v = e[i].v;
        if (vis[v] || v == p) continue;
        dfs(v, u);
        siz[u] += siz[v];
        h[u] = std::max(h[v] + 1, h[u]);
    }
}
// Ordering predicate on node ids: a precedes b when a's subtree height
// h[a] is strictly smaller than h[b].
bool cmp(const int &a, const int &b) {
    const int heightA = h[a];
    const int heightB = h[b];
    return heightA < heightB;
}
void solve(int u, int tot) {
int l = 0, v, sz, hd = head[u];
vis[u] = 1, dfs(u, 0), l = h[u];
while(hd && vis[e[hd].v]) hd = e[hd].next;
if(hd) add(a, e[hd].v, u, 1, 1, 1);
for(register int i = e[hd].next; i; i = e[i].next) {
v = e[i].v;
if(vis[v]) continue;
add(b, v, u, 1, 1, 1);
POLY::solve(a, b, l);
add(a, v, u, 1, 1, 0);
}
POLY::Gans();
memset(a, 0, sizeof(E) * (l + 3));
for(register int i = head[u]; i; i = e[i].next) {
v = e[i].v, G = 0;
if(vis[v]) continue;
FG(v, u, siz[v], G);
solve(G, siz[v]);
}
}
int main() {
//freopen("tst.in", "r", stdin);
scanf("%d", &n), GPri(n);
for(register int i = 1; i < n; ++i) {
scanf("%d%d", &u, &v);
adde(u, v), adde(v, u);
}
FG(1, 0, n, G);
solve(G, n);
printf("%.10lf", (double)ans / (1ll * n * (n - 1) / 2));
return 0;
}
rockdu
没有帐号? 立即注册