题目链接:https://loj.ac/problem/2537
考虑到题目是一颗二叉树,有一个简单的O(N^2)的DP做法。
f(i,j)表示i号点取值为数集中第j大数的概率,每一次我们从子树合并上来。
j被取到有两种情况:
- j更大并且当前选择的是最大值
- j更小并且当前选择的是最小值
分情况讨论,发现要转移需要一个选到比j小的值的概率和(比j大的可以直接用1减比j小的)。
那么我们维护两个前缀和(对两边的元素分别计算),即可完成一次合并。
如何优化这个算法呢?
假设两边子树可以产生的j值集合分别为A,B。
考虑A值域上这样连续的一段数:满足不存在B[i]使得A[l]<=B[i]<=A[r]。
我们发现实际上这样的数的答案是一样的。
合并?求和?区间乘?想到魔改线段树合并。
我们考虑使用线段树合并来同时寻找/更改区间。
每一个叶子节点建立权值线段树,初始为自己的值。
我们合并线段树的同时,维护前面A集合中数的概率和以及B集合中数的概率和。
在合并的过程中,一定会出现合并到头(两棵树有至少一棵不存在节点)的情况。
此时,如果A的线段树存在节点,说明这一段值域连续的A集合的答案应该一样。
从B来考虑同样如此,我们做一个区间乘操作即可。
因为从n条链一路合并上去,所以复杂度为O(n log n)。
#include<cstdio> #include<algorithm> #include<iostream> #include<cstring> #define LL long long using namespace std; const int mod = 998244353; const int maxn = 3e5 + 5; const int N = 1.05e9; LL fpow(LL a, int b) { LL ans = 1; while(b) { if(b & 1) ans = ans * a % mod; a = a * a % mod, b >>= 1; } return ans; } const LL inv1e4 = fpow(10000, mod - 2); int tmp, ans; void mul(int &x, const int &a) { x = 1ll * x * a % mod; } void Add(int &x, const int &a) { x += a; if(x >= mod) x -= mod; } namespace SGT { struct node { int ch[2], sum, x; } t[maxn * 128]; const int L = 0, R = 1; int cnt; int Nnode() { return t[++cnt].x = 1, cnt; } void push_up(int rt) { t[rt].sum = (t[t[rt].ch[L]].sum + t[t[rt].ch[R]].sum) % mod; } void push_down(int rt) { if(t[rt].x != 1) { mul(t[t[rt].ch[L]].x, t[rt].x); mul(t[t[rt].ch[R]].x, t[rt].x); mul(t[t[rt].ch[L]].sum, t[rt].x); mul(t[t[rt].ch[R]].sum, t[rt].x); t[rt].x = 1; } } void insert(int p, int tl, int tr, int dt, int &rt) { if(!rt) rt = Nnode(); if(tl == tr) return t[rt].sum = dt, void(); int mid = tl + tr >> 1; push_down(rt); if(p <= mid) insert(p, tl, mid, dt, t[rt].ch[L]); else insert(p, mid + 1, tr, dt, t[rt].ch[R]); push_up(rt); } int tmpa, tmpb; void merge(int &a, int b, int p) { if(a + b == 0) return ; if(!a) { t[++cnt] = t[b]; a = cnt, Add(tmpb, t[b].sum); tmp = (1ll * tmpa * p % mod + (1ll * (1 - tmpa) * (1 - p) % mod + mod) % mod) % mod; mul(t[a].x, tmp); mul(t[a].sum, tmp); return void(); } if(!b) { t[++cnt] = t[a], a = cnt; Add(tmpa, t[a].sum); tmp = (1ll * tmpb * p % mod + (1ll * (1 - tmpb) * (1 - p) % mod + mod) % mod) % mod; mul(t[a].x, tmp); mul(t[a].sum, tmp); return void(); } push_down(a), push_down(b); t[++cnt] = t[a], a = cnt; merge(t[a].ch[L], t[b].ch[L], p); merge(t[a].ch[R], t[b].ch[R], p); push_up(a); } void work(int &a, int b, int p) { tmpa = tmpb = 0; merge(a, b, p); } int query(int p, int tl, int tr, int rt) { if(tl == tr) return t[rt].sum; int mid = tl + tr >> 1; push_down(rt); if(p <= mid) return query(p, tl, mid, t[rt].ch[L]); else return query(p, mid + 1, tr, t[rt].ch[R]); } } int vs[maxn], vis[maxn], vcnt; int rt[maxn], fa[maxn], p[maxn], nleaf[maxn], n; int main() { //freopen("4.in", "r", stdin); scanf("%d", &n); for(register int i = 1; i <= n; ++i) scanf("%d", &fa[i]), nleaf[fa[i]] = 1; for(register int i = 1; i <= n; ++i) { scanf("%d", &p[i]); if(!nleaf[i]) { SGT::insert(p[i], 1, N, 1, rt[i]); vs[++vcnt] = p[i]; } else p[i] = p[i] * inv1e4 % mod; } sort(vs + 1, vs + vcnt + 1); for(register int i = n; i > 1; --i) { if(!vis[fa[i]]) rt[fa[i]] = rt[i], vis[fa[i]] = 1; else SGT::work(rt[fa[i]], rt[i], p[fa[i]]); //if(i == 4) cerr << SGT::query(2, 1, N, rt[2]) << endl; } //cerr << "#" << fpow(5, mod - 2) << endl; for(register int i = 1; i <= vcnt; ++i) { tmp = SGT::query(vs[i], 1, N, rt[1]); //cerr << vs[i] << " " << tmp << endl; mul(tmp, tmp), mul(tmp, i), mul(tmp, vs[i]); Add(ans, tmp); } printf("%d", ans); return 0; }
没有帐号? 立即注册