BZOJ 4771 - 七彩树

Published on 2017-03-11

题目地址

描述

给定一棵 n(n105)n(n\le {10}^5) 个点的有根树,编号依次为 11nn,其中 11 号点是根节点。每个节点都被染上了某一种颜色,其中第 ii 个节点的颜色为 cic_i。如果 ci=cjc_i = c_j,那么我们认为点 ii 和点 jj 拥有相同的颜色。定义 depth(i)\mathrm{depth}(i)ii 节点与根节点的距离,为了方便起见,你可以认为树上相邻的两个点之间的距离为 1。站在这棵色彩斑斓的树前面,你将面临 m(m105)m(m\le {10}^5) 个问题。

每个问题包含两个整数 xxdd,表示询问 xx 子树里,depth\mathrm{depth} 不超过 depth(x)+d\mathrm{depth}(x) + d 的所有点中出现了多少种本质不同的颜色。请写一个程序,快速回答这些询问。

分析

这个题的做法比较神(也有可能是一类套路),涉及到线段树合并。线段树合并是动态开点的线段树由 棵单点有值的线段树合并到一棵 nn 个点有值的线段树。

合并的方法是递归进行的,如果要合并的两棵线段树都为空,就返回空。如果有一棵为空,就返回另一棵,否则暴力递归合并左右儿子。可以用势能分析证明,无论怎么合并,合并的总复杂度都是 O(nlogn)O(n\log n) 的。

本题的做法是,对每一个节点 uu,用线段树维护一个序列 sssis_i 表示在 uu 子树中,深度为 ii 的点的颜色个数,但注意,深度 ii 的点中记录的颜色个数,并不包含 中已经计算过的颜色,也就是说一个颜色,只在最浅的位置加上贡献。假设询问是 (u,d)(u, d),答案就是

我们 DFS,自上往下对每个节点维护 ss。考虑当前节点为 uu,如何维护 sis_i,首先一开始 sdepth(u)=1s_{\mathrm{depth}(u)} = 1。接着递归子树,需要合并子树 vv 的答案,我们直接合并 uuvv 的线段树。但这样同一种颜色是会重复计算的,但有一点好处,一个颜色只会被重复计算两次,由于我们需要「一个颜色,只在最浅的位置加上贡献」,所以我们必须扣除同一颜色较深的点的贡献。

对每个节点 uu 再用线段树维护一个序列 did_i,表示 uu 的子树中,颜色为 ii 的节点的最浅深度,我们还是合并 uu 和子树 vv 的线段树,当合并到叶子节点的时候,如果两个节点都有值,那么我们就保留较小值,然后在 uu 维护的序列 ss 中扣除较深节点的贡献,这样就能够每个颜色只在最浅的位置上计算贡献了。

线段树合并的复杂度为 O(nlogn)O(n\log n),查询也是 O(logn)O(\log n) 单次,时间复杂度 O((n+m)logn)O((n + m)\log n),注意空间复杂度也是 O(nlogn)O(n\log n),且空间的常数较大。

总结:线段树合并可以较为暴力地统计每个子树内的一些信息,本题就是利用了线段树合并这一工具,来达到消除重复颜色的目的,而且这个算法是在线的。

代码

//  Created by Sengxian on 2017/3/11.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  BZOJ 4771 线段树合并 + 可持久化线段树
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

#define len(x) ((int)x.size())
const int MAX_N = 100000 + 3;
vector<int> G[MAX_N];
int n, m, col[MAX_N];

struct Node *null1;
struct Node *null2;
struct Node {
    Node *lc, *rc;
    int val;

    Node(Node *null = NULL, int val = 0) : lc(null), rc(null), val(val) {}
} pool[10000000], *pit, *root1[MAX_N], *root2[MAX_N];

int dep[MAX_N];

#define mid (((l) + (r)) >> 1)

Node *modify(Node *o, int l, int r, int pos, int val) {
    Node *ne = pit++; *ne = *o;
    ne->val = val;
    if (r - l == 1) return ne;
    if (pos < mid) ne->lc = modify(ne->lc, l, mid, pos, val);
    else ne->rc = modify(ne->rc, mid, r, pos, val);
    return ne;
}

Node *modify1(Node *o, int l, int r, int pos, int val) {
    Node *ne = pit++; *ne = *o;
    ne->val += val;
    if (r - l == 1) return ne;
    if (pos < mid) ne->lc = modify1(ne->lc, l, mid, pos, val);
    else ne->rc = modify1(ne->rc, mid, r, pos, val);
    return ne;
}

Node *merge1(Node *t1, Node *t2, int l, int r) {
    if (t1 == null1 && t2 == null1) return null1;
    if (t2 == null1) return t1;
    if (t1 == null1) return t2;
    Node *ne = new (pit++) Node(null1, t1->val + t2->val);
    if (r - l == 1) return ne;
    ne->lc = merge1(t1->lc, t2->lc, l, mid);
    ne->rc = merge1(t1->rc, t2->rc, mid, r);
    return ne;
}

Node *merge2(Node *t1, Node *t2, int l, int r, Node *&t3) {
    if (t1 == null2 && t2 == null2) return null2;
    if (t2 == null2) return t1;
    if (t1 == null2) return t2;
    Node *ne = new (pit++) Node(null2, INT_MAX);
    if (r - l == 1) {
        ne->val = min(t1->val, t2->val);
        if (max(t1->val, t2->val) != INT_MAX) t3 = modify1(t3, 0, n, max(t1->val, t2->val), -1);
    }
    ne->lc = merge2(t1->lc, t2->lc, l, mid, t3);
    ne->rc = merge2(t1->rc, t2->rc, mid, r, t3);
    return ne;
}

int query(Node *o, int l, int r, int a, int b) {
    if (o->val == 0 || r <= a || l >= b) return 0;
    if (l >= a && r <= b) return o->val;
    return query(o->lc, l, mid, a, b) + query(o->rc, mid, r, a, b);
}

void dfs(int u) {
    root1[u] = modify(null1, 0, n, dep[u], 1);
    root2[u] = modify(null2, 0, n, col[u], dep[u]);

    for (int i = 0; i < len(G[u]); ++i) {
        int v = G[u][i];
        dep[v] = dep[u] + 1;
        dfs(v);
        root1[u] = merge1(root1[u], root1[v], 0, n);
        root2[u] = merge2(root2[u], root2[v], 0, n, root1[u]);
    }
}

void prepare() {
    pit = pool;
    null1 = pit++, null1->lc = null1->rc = null1, null1->val = 0;
    null2 = pit++, null2->lc = null2->rc = null2, null2->val = INT_MAX;
    dfs(0);
}

int main() {
    int caseNum = readInt();
    while (caseNum--) {
        n = readInt(), m = readInt();
        for (int i = 0; i < n; ++i) G[i].clear();
        for (int i = 0; i < n; ++i) col[i] = readInt() - 1;
        for (int i = 1; i < n; ++i) {
            int fa = readInt() - 1;
            G[fa].push_back(i);
        }

        prepare();

        int lastAns = 0;
        for (int i = 0; i < m; ++i) {
            int u = (readInt() ^ lastAns) - 1, d = readInt() ^ lastAns;;
            printf("%d\n", lastAns = query(root1[u], 0, n, dep[u], dep[u] + d + 1));
        }
    }
    return 0;
}