BZOJ4771:七彩树

【题目描述】

给定一棵n个点的有根树,编号依次为1到n,其中1号点是根节点。每个节点都被染上了某一种颜色,其中第i个节点的颜色为c[i]。如果c[i]=c[j],那么我们认为点i和点j拥有相同的颜色。定义depth[i]为i节点与根节点的距离,为了方便起见,你可以认为树上相邻的两个点之间的距离为1。站在这棵色彩斑斓的树前面,你将面临m个问题。

每个问题包含两个整数x和d,表示询问x子树里且depth不超过depth[x]+d的所有点中出现了多少种本质不同的颜色。请写一个程序,快速回答这些询问。

【输入】

第一行包含一个正整数T(1<=T<=500),表示测试数据的组数。

每组数据中,第一行包含两个正整数n(1<=n<=100000)和m(1<=m<=100000),表示节点数和询问数。

第二行包含n个正整数,其中第i个数为c[i](1<=c[i]<=n),分别表示每个节点的颜色。

第三行包含n-1个正整数,其中第i个数为f[i+1](1<=f[i]<i),表示节点i+1的父亲节点的编号。< p="">

接下来m行,每行两个整数x(1<=x<=n)和d(0<=d<n),依次表示每个询问。< p="">

输入数据经过了加密,对于每个询问,如果你读入了x和d,那么真实的x和d分别是x xor last和d xor last,

其中last表示这组数据中上一次询问的答案,如果这是当前数据的第一组询问,那么last=0。

输入数据保证n和m的总和不超过500000。

 

【输出】

对于每个询问输出一行一个整数,即答案。

【输入样例】

1
5 8
1 3 3 2 2
1 1 3 3
1 0
0 0
3 0
1 3
2 1
2 0
6 2
4 1

【输出样例】

1
2
3
1
1
2
1
1

【来源】


我们可以模仿HH的项链这道题。

考虑维护每一个点子树内每种颜色最近的位置在哪里。这样我们只要统计这些位置在深度范围内的有多少个就行了。

考虑使用线段树合并,对每个节点维护线段树,每一个下标i的值表示第i种颜色出现的最浅位置。这个可以直接线段树合并,当发现有一个叶子节点在两颗线段树中都有的时候,新的线段树直接取更浅的那个就行了。

但是这样是没办法统计答案的。

考虑维护答案线段树,与上面的线段树互反,下标i的值表示最浅处在i的颜色有多少个。我们同时维护这两颗线段树,先把答案线段树直接合并起来,对位相加。然后在原来线段树合并的同时更新答案线段树:当原来的线段树有位置重复时我们在答案线段树中把更深的那个点深度的位置处单点-1。

在线的话只要把线段树合并的过程可持久化就行了。

总复杂度为:O(n log n)

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
using namespace std;
const int maxn = 1e5 + 5;
const int inf = 0x3f3f3f3f;
inline char gc(){
    static char buf[20000000], *p1 = buf, *p2 = buf;
    return (p1 == p2) && (p2 = (p1 = buf) + fread(buf, 1, 20000000, stdin), p1 == p2) ? EOF : *p1++;
}
inline void read(int & x) {
    x = 0; static char c = gc();
    while(!isdigit(c)) c = gc();
    while(isdigit(c)) x = x * 10 + c - '0', c = gc();
}
void write(int x) {
    if(x > 9) write(x / 10);
    putchar((x % 10) + '0');
}
struct edge {
    int v, next;
}e[maxn];
int head[maxn], cnt, n, m, u, v, t, lastans;
int rt[maxn], crt[maxn], c[maxn], fa[maxn], depth[maxn];
void adde(const int &u, const int &v) {
    e[++cnt] = (edge) {v, head[u]};
    head[u] = cnt;
}
 
namespace SGT2 {
    struct node {
        int ch[2], sum;
    }t[maxn * 64];
    int cnt;
    const int L = 0, R = 1;
    void init() {cnt = 0;}
    int Newnode() {
        t[++cnt] = (node){0, 0, 0};
        return cnt;
    }
    void push_up(int rt) {
        t[rt].sum = t[t[rt].ch[R]].sum + t[t[rt].ch[L]].sum;
    }
    void merge(int &x, int y) {
        if(!y) return ;
        if(!x) return x = y, void();
        t[++cnt] = t[x], x = cnt;
        t[x].sum += t[y].sum;
        merge(t[x].ch[L], t[y].ch[L]);
        merge(t[x].ch[R], t[y].ch[R]);
    }
    void add(int p, int tl, int tr, int & rt, int dt) {
        t[++cnt] = t[rt], rt = cnt;
        if(tl == tr) return t[rt].sum += dt, void();
        int mid = tl + tr >> 1;
        if(p <= mid) add(p, tl, mid, t[rt].ch[L], dt);
        else add(p, mid + 1, tr, t[rt].ch[R], dt);
        push_up(rt);
    }
    int query(int l, int r, int tl, int tr, int rt) {
        if(!rt) return 0;
        if(l == tl && r == tr) return t[rt].sum;
        int mid = tl + tr >> 1;
        if(r <= mid) return query(l, r, tl, mid, t[rt].ch[L]);
        else if(l > mid) return query(l, r, mid + 1, tr, t[rt].ch[R]);
        else return query(l, mid, tl, mid, t[rt].ch[L]) + 
                    query(mid + 1, r, mid + 1, tr, t[rt].ch[R]);
    }
}
 
namespace SGT {
    struct node {
        int ch[2], mn;
    }t[maxn * 64];
    int cnt;
    const int L = 0, R = 1;
    void init() {
        t[0] = (node) {0, 0, 0};
        cnt = 0;
    }
    int Newnode() {
        t[++cnt] = (node) {0, 0, 0};
        return cnt;
    }
    void merge(int & nrt, int & x, int y) {
        if(!y) return ;
        if(!x) return x = y, void();
        t[++cnt] = t[x], x = cnt;
        if(!t[x].ch[L] && !t[x].ch[R]) {
            if(t[x].mn < t[y].mn) SGT2::add(t[y].mn, 1, n, nrt, -1);
            else SGT2::add(t[x].mn, 1, n, nrt, -1);
            t[x].mn = min(t[x].mn, t[y].mn);
            return ;
        }
        merge(nrt, t[x].ch[L], t[y].ch[L]);
        merge(nrt, t[x].ch[R], t[y].ch[R]);
    }
    void ins(int p, int tl, int tr, int & rt, int dt) {
        if(!rt) rt = Newnode();
        if(tl == tr) return t[rt].mn = dt, void();
        int mid = tl + tr >> 1;
        if(p <= mid) ins(p, tl, mid, t[rt].ch[L], dt);
        else ins(p, mid + 1, tr, t[rt].ch[R], dt);
    }
}
 
void dfs(int u, int d) {
    crt[u] = rt[u] = 0, depth[u] = d;
    SGT::ins(c[u], 1, n, crt[u], d);
    SGT2::add(d, 1, n, rt[u], 1);
    for(register int i = head[u]; i; i = e[i].next) {
        int v = e[i].v;
        dfs(v, d + 1);
        SGT2::merge(rt[u], rt[v]);
        SGT::merge(rt[u], crt[u], crt[v]);
    }
}
 
int main() {
    read(t);
    while(t--) {
        SGT::init(), SGT2::init(), lastans = 0;
        read(n), read(m), cnt = 0;
        memset(head, 0, sizeof(int) * (n + 5));
        for(register int i = 1; i <= n; ++i) read(c[i]);
        for(register int i = 2; i <= n; ++i)
            read(fa[i]), adde(fa[i], i);
        dfs(1, 1);
        for(register int i = 1; i <= m; ++i) {
            read(u), read(v);
            u ^= lastans, v ^= lastans;
            lastans = SGT2::query(depth[u], depth[u] + v, 1, n, rt[u]);
            write(lastans), putchar('\n');
        }
    }
    return 0;
}



上一篇: 多项式基础

下一篇: Codeforces #364 Div1 E. Cool Slogans

553 人读过
立即登录, 发表评论.
没有帐号? 立即注册
1 条评论
文档导航