Rockdu's Blog
“Where there is will, there is a way”
亲 您的浏览器不支持html5的audio标签
Toggle navigation
Rockdu's Blog
主页
数据结构
字符串算法
图论
数论、数学
动态规划
基础算法
[其它内容]
计算几何
科研笔记
归档
标签
LOJ#6289. 花朵
? 解题记录 ?
? LOJ ?
? 树链剖分 ?
? FFT|NTT ?
? 动态规划 ?
2019-03-14 11:31:06
764
0
0
rockdu
? 解题记录 ?
? LOJ ?
? 树链剖分 ?
? FFT|NTT ?
? 动态规划 ?
题目链接:[传送门](https://loj.ac/problem/6289) 题解:不难发现,我们有一个优秀的$N^2$的$dp$做法,记录$f(i,j,0/1)$表示$i$子树内选$j$个,当前根选没选。 我们要知道分治$NTT$的复杂度:对于$k$个总度数$\sum deg=k$的多项式来说,计算它们乘积的复杂度可以做到:$O((\sum deg)log(\sum deg)logk)$。 因此,我们采用重链剖分,先转移所有轻链。 不难发现, $g(u,0)=\prod_{v\in son} (g(v,0)+g(v,1))$ $g(u,1)=w_u\prod_{v\in son} g(v,0)$ 因此可以对所有儿子分治$NTT$。 再考虑重链上的转移,我们记$f[0/1][0/1]$表示两边的选择情况,然后分治拼起来就可以了。 总复杂度因为树链剖分,所以是$O(nlog^3n)$ 代码能力低到窒息,写加调整整3个多小时。 ``` #include<cstdio> #include<cstring> #include<vector> #include<algorithm> using namespace std; const int maxn = 1e5 + 5; const int mod = 998244353; const int G = 3; /// namespace Polynomial /// struct PLG; NTT(PLG A); mul(PLG A, PLG B); /// namespace HLD /// dfs1(int u, int p), dfs2(int u, int top); /// namespace DP /// getL(int u) /// mergeH(vector<PLG > g[2]) /// f[1][1] = f[1][0] * f[0][1] + f[1][0] * f[1][1] + f[1][1] * f[0][1] /// f[1][0] = f[1][0] * f[0][0] + f[1][0] * f[1][0] + f[1][1] * f[0][0] /// f[0][1] = f[0][0] * f[0][1] + f[0][0] * f[1][1] + f[0][1] * f[0][1] /// f[0][0] = f[0][0] * f[0][0] + f[0][0] * f[1][0] + f[0][1] * f[0][0] /// int mul(const int &a, const int &b) { return 1ll * a * b % mod; } void Add(int &x, const int &a) { x += a; if(x >= mod) x -= mod; } void Dec(int &x, const int &a) { x -= a; if(x < 0) x += mod; } int tot, totg; namespace Poly { const int maxd = 1e6 + 5; struct PLG { vector<int > x; int deg() const {return x.size() - 1;} void ext(int n) {x.resize(n + 1);} void prt() { int len = deg(); for(register int i = 0; i <= len; ++i) printf("%d ", x[i]); putchar('\n'); } }; int RL[maxd], N, mxp2, w[maxd]; void init(int n) { for(N = 1, mxp2 = 0; N <= n; N <<= 1, ++mxp2) ; for(register int i = 0; i < N; ++i) RL[i] = (RL[i >> 1] >> 1) | ((i & 1) << mxp2 - 1); } int fpow(int a, int b) { int ans = 1; while(b) { if(b & 1) ans = 1ll * ans * a % mod; a = 1ll * a * a % mod, b >>= 1; } return ans; } void NTT(PLG &A, int type) { int tmp, x, y; A.ext(N); for(register int i = 0; i < N; ++i) if(i < RL[i]) swap(A.x[i], A.x[RL[i]]); for(register int i = 1; i < N; i <<= 1) { int Wn = fpow(G, (mod - 1) / (i << 1)); if(type == -1) Wn = fpow(Wn, mod - 2); w[0] = 1; for(register int j = 1; j <= i; ++j) w[j] = mul(w[j - 1], Wn); for(register int j = 0, p = i << 1; j < N; j += p) { for(register int k = 0; k < i; ++k) { x = A.x[j + k], y = mul(A.x[j + i + k], w[k]); A.x[j + k] = (x + y) % mod; A.x[j + i + k] = (x - y + mod) % mod; } } } } PLG operator *(PLG A, PLG B) { int len = A.deg() + B.deg(); //A.prt(), B.prt(); init(len); tot += len; NTT(A, 1), NTT(B, 1); for(register int i = 0; i < N; ++i) { A.x[i] = mul(A.x[i], B.x[i]); } NTT(A, -1); int invN = fpow(N, mod - 2); for(register int i = 0; i < N; ++i) { A.x[i] = mul(A.x[i], invN); } A.ext(len); return A; } PLG operator - (const PLG &A, const PLG &B) { PLG ans; int len = max(A.deg(), B.deg()); ans = A, ans.ext(len); for(register int i = B.deg(); i >= 0; --i) Dec(ans.x[i], B.x[i]); return ans; } PLG operator + (const PLG &A, const PLG &B) { PLG ans; int len = max(A.deg(), B.deg()); ans = A, ans.ext(len); for(register int i = B.deg(); i >= 0; --i) Add(ans.x[i], B.x[i]); return ans; } } struct edge { int v, next; } e[maxn << 1]; int head[maxn], cnt; void adde(const int &u, const int &v) { e[++cnt] = (edge) {v, head[u]}; head[u] = cnt; } namespace HLD { int top[maxn], son[maxn], siz[maxn], fa[maxn]; void dfs1(int u, int p) { fa[u] = p, siz[u] = 1; for(register int i = head[u]; i; i = e[i].next) { int v = e[i].v; if(v == p) continue; dfs1(v, u); siz[u] += siz[v]; if(!son[u] || siz[v] > siz[son[u]]) son[u] = v; } } void dfs2(int u, int tp) { top[u] = tp; if(son[u]) dfs2(son[u], tp); for(register int i = head[u]; i; i = e[i].next) { int v = e[i].v; if(v == fa[u] || v == son[u]) continue; dfs2(v, v); } } void work() { dfs1(1, 0), dfs2(1, 1); } } int n, w[maxn], m, u, v; Poly::PLG g[2][maxn]; namespace DP { void mergeH(int u); struct MES {Poly::PLG f[2][2];}; MES operator + (const MES &A, const MES &B) { MES ans; ans.f[1][1] = A.f[1][0] * (B.f[0][1] + B.f[1][1]) + A.f[1][1] * B.f[0][1]; ans.f[1][0] = A.f[1][0] * (B.f[0][0] + B.f[1][0]) + A.f[1][1] * B.f[0][0]; ans.f[0][1] = A.f[0][0] * (B.f[0][1] + B.f[1][1]) + A.f[0][1] * B.f[0][1]; ans.f[0][0] = A.f[0][0] * (B.f[0][0] + B.f[1][0]) + A.f[0][1] * B.f[0][0]; return ans; } MES solve(const vector<int > &chain, int l, int r) { if(l == r) { MES ret; ret.f[0][0] = g[0][chain[l]]; ret.f[0][1].x.push_back(0); ret.f[1][0].x.push_back(0); ret.f[1][1] = g[1][chain[l]]; return ret; } int mid = l + r >> 1; return solve(chain, l, mid) + solve(chain, mid + 1, r); } struct MESL {Poly::PLG f[2];}; MESL operator +(const MESL &A, const MESL &B) { MESL ans; ans.f[0] = A.f[0] * B.f[0]; ans.f[1] = A.f[1] * B.f[1]; return ans; } MESL solveL(const vector<int > &chain, int l, int r) { if(l > r) { MESL ret; ret.f[0].x.push_back(1); ret.f[1].x.push_back(1); return ret; } if(l == r) { MESL ret; ret.f[0] = g[0][chain[l]]; ret.f[1] = g[0][chain[l]] + g[1][chain[l]]; return ret; } int mid = l + r >> 1; return solveL(chain, l, mid) + solveL(chain, mid + 1, r); } void getL(int u) { using namespace HLD; //printf("%d\n", u); g[1][u].x.push_back(0); g[1][u].x.push_back(w[u]); g[0][u].x.push_back(1); vector<int > chain; for(register int i = head[u]; i; i = e[i].next) { int v = e[i].v; if(v == son[u] || v == fa[u]) continue; mergeH(v); chain.push_back(v); } MESL ans = solveL(chain, 0, chain.size() - 1); g[1][u] = g[1][u] * ans.f[0]; g[0][u] = g[0][u] * ans.f[1]; } void mergeH(int u) { using namespace HLD; vector<int > chain; int now = u; while(now) chain.push_back(now), now = son[now]; for(register int i = chain.size() - 1; i >= 0; --i) { getL(chain[i]); //printf("## %d\n", chain[i]); //g[0][chain[i]].prt(); //g[1][chain[i]].prt(); } MES ans = solve(chain, 0, chain.size() - 1); g[0][u] = ans.f[0][1] + ans.f[0][0]; g[1][u] = ans.f[1][1] + ans.f[1][0]; } } int main() { //freopen("3-1.in", "r", stdin); //test(); scanf("%d%d", &n, &m); for(register int i = 1; i <= n; ++i) scanf("%d", &w[i]); for(register int i = 1; i < n; ++i) { scanf("%d%d", &u, &v); adde(u, v), adde(v, u); } HLD::work(); DP::mergeH(1); //printf("%d\n", tot); if(m > g[0][1].deg() && m > g[1][1].deg()) printf("0"); else printf("%d", (g[0][1] + g[1][1]).x[m] % mod); return 0; } ```
上一篇:
LOJ#517. 「LibreOJ β Round #2」计算几何瞎暴力
下一篇:
TCO19 SRM 739 Div1 解题记录
0
赞
764 人读过
新浪微博
微信
腾讯微博
QQ空间
人人网
提交评论
立即登录
, 发表评论.
没有帐号?
立即注册
0
条评论
More...
文档导航
没有帐号? 立即注册