题目链接:https://loj.ac/problem/2537
考虑到题目是一颗二叉树,有一个简单的O(N^2)的DP做法。
f(i,j)表示i号点取值为数集中第j大数的概率,每一次我们从子树合并上来。
j被取到有两种情况:
- j更大并且当前选择的是最大值
- j更小并且当前选择的是最小值
分情况讨论,发现要转移需要一个选到比j小的值的概率和(比j大的可以直接用1减比j小的)。
那么我们维护两个前缀和(对两边的元素分别计算),即可完成一次合并。
如何优化这个算法呢?
假设两边子树可以产生的j值集合分别为A,B。
考虑A值域上这样连续的一段数:满足不存在B[i]使得A[l]<=B[i]<=A[r]。
我们发现实际上这样的数的答案是一样的。
合并?求和?区间乘?想到魔改线段树合并。
我们考虑使用线段树合并来同时寻找/更改区间。
每一个叶子节点建立权值线段树,初始为自己的值。
我们合并线段树的同时,维护前面A集合中数的概率和以及B集合中数的概率和。
在合并的过程中,一定会出现合并到头(两棵树有至少一棵不存在节点)的情况。
此时,如果A的线段树存在节点,说明这一段值域连续的A集合的答案应该一样。
从B来考虑同样如此,我们做一个区间乘操作即可。
因为从n条链一路合并上去,所以复杂度为O(n log n)。
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#define LL long long
using namespace std;
const int mod = 998244353;
const int maxn = 3e5 + 5;
const int N = 1.05e9;
LL fpow(LL a, int b) {
LL ans = 1;
while(b) {
if(b & 1) ans = ans * a % mod;
a = a * a % mod, b >>= 1;
}
return ans;
}
const LL inv1e4 = fpow(10000, mod - 2);
int tmp, ans;
void mul(int &x, const int &a) {
x = 1ll * x * a % mod;
}
void Add(int &x, const int &a) {
x += a; if(x >= mod) x -= mod;
}
namespace SGT {
struct node {
int ch[2], sum, x;
} t[maxn * 128];
const int L = 0, R = 1;
int cnt;
int Nnode() {
return t[++cnt].x = 1, cnt;
}
void push_up(int rt) {
t[rt].sum = (t[t[rt].ch[L]].sum + t[t[rt].ch[R]].sum) % mod;
}
void push_down(int rt) {
if(t[rt].x != 1) {
mul(t[t[rt].ch[L]].x, t[rt].x);
mul(t[t[rt].ch[R]].x, t[rt].x);
mul(t[t[rt].ch[L]].sum, t[rt].x);
mul(t[t[rt].ch[R]].sum, t[rt].x);
t[rt].x = 1;
}
}
void insert(int p, int tl, int tr, int dt, int &rt) {
if(!rt) rt = Nnode();
if(tl == tr)
return t[rt].sum = dt, void();
int mid = tl + tr >> 1;
push_down(rt);
if(p <= mid) insert(p, tl, mid, dt, t[rt].ch[L]);
else insert(p, mid + 1, tr, dt, t[rt].ch[R]);
push_up(rt);
}
int tmpa, tmpb;
void merge(int &a, int b, int p) {
if(a + b == 0) return ;
if(!a) {
t[++cnt] = t[b];
a = cnt, Add(tmpb, t[b].sum);
tmp = (1ll * tmpa * p % mod +
(1ll * (1 - tmpa) * (1 - p) % mod + mod) % mod) % mod;
mul(t[a].x, tmp);
mul(t[a].sum, tmp);
return void();
}
if(!b) {
t[++cnt] = t[a], a = cnt;
Add(tmpa, t[a].sum);
tmp = (1ll * tmpb * p % mod +
(1ll * (1 - tmpb) * (1 - p) % mod + mod) % mod) % mod;
mul(t[a].x, tmp);
mul(t[a].sum, tmp);
return void();
}
push_down(a), push_down(b);
t[++cnt] = t[a], a = cnt;
merge(t[a].ch[L], t[b].ch[L], p);
merge(t[a].ch[R], t[b].ch[R], p);
push_up(a);
}
void work(int &a, int b, int p) {
tmpa = tmpb = 0;
merge(a, b, p);
}
int query(int p, int tl, int tr, int rt) {
if(tl == tr)
return t[rt].sum;
int mid = tl + tr >> 1;
push_down(rt);
if(p <= mid) return query(p, tl, mid, t[rt].ch[L]);
else return query(p, mid + 1, tr, t[rt].ch[R]);
}
}
int vs[maxn], vis[maxn], vcnt;
int rt[maxn], fa[maxn], p[maxn], nleaf[maxn], n;
int main() {
//freopen("4.in", "r", stdin);
scanf("%d", &n);
for(register int i = 1; i <= n; ++i)
scanf("%d", &fa[i]), nleaf[fa[i]] = 1;
for(register int i = 1; i <= n; ++i) {
scanf("%d", &p[i]);
if(!nleaf[i]) {
SGT::insert(p[i], 1, N, 1, rt[i]);
vs[++vcnt] = p[i];
} else p[i] = p[i] * inv1e4 % mod;
}
sort(vs + 1, vs + vcnt + 1);
for(register int i = n; i > 1; --i) {
if(!vis[fa[i]]) rt[fa[i]] = rt[i], vis[fa[i]] = 1;
else SGT::work(rt[fa[i]], rt[i], p[fa[i]]);
//if(i == 4) cerr << SGT::query(2, 1, N, rt[2]) << endl;
}
//cerr << "#" << fpow(5, mod - 2) << endl;
for(register int i = 1; i <= vcnt; ++i) {
tmp = SGT::query(vs[i], 1, N, rt[1]);
//cerr << vs[i] << " " << tmp << endl;
mul(tmp, tmp), mul(tmp, i), mul(tmp, vs[i]);
Add(ans, tmp);
}
printf("%d", ans);
return 0;
}
rockdu
没有帐号? 立即注册