QTREE 6 - Query on a tree VI

Published on 2016-03-18

题目地址:https://www.spoj.com/problems/QTREE6/

分析

链分治。因为答案是求和,所以要去重:从哪条链爬上去的,就要减去那条链的 $maxL$(代码中为 `cntL`)。要想清楚合并的时候什么值该加,什么值不该加。
还有转换颜色的时候,被转换节点自身维护的值要重新更新,但一定不要再次遍历儿子,否则菊花图(星形图)会被卡到超时。应该维护两个值,一个记录同色轻儿子链的贡献(`sum`),一个记录不同色轻儿子链的贡献(`nosum`),修改颜色时直接 swap 即可。

代码

//  Created by Sengxian on 3/17/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  Spoj QTREE VI 链分治(巨坑)
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
using namespace std;

// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped (a leading '-' is
// therefore ignored, exactly as before). If EOF is reached before any digit
// is seen, 0 is returned instead of looping forever: isdigit(EOF) is false,
// so the old skip loop never terminated on truncated input.
inline int ReadInt() {
    int n = 0, ch = getchar();
    while (ch != EOF && !isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + (ch - '0'), ch = getchar();
    return n;
}

// Capacity of every per-node array: 100000 nodes plus a little slack
// (presumably the problem's limit on n — confirm against the statement).
const int maxn = 100000 + 3;
// Adjacency lists of the tree; nodes are 0-indexed (main subtracts 1 on input).
vector<int> G[maxn];
bool color[maxn]; // 0 - black, 1 - white; zero-initialized, so every node starts black
// fa[u]     : parent of u (-1 for the root), filled by dfs1
// id[u]     : position of u in the heavy-child-first DFS order, filled by dfs2
// idR[t]    : inverse of id, i.e. the node stored at position t
// s[u]      : size of the subtree rooted at u, filled by dfs1
// sz[h]     : number of nodes on the heavy chain whose head is h
// bel[u]    : head (topmost node) of the heavy chain containing u
// timestamp : next free DFS position handed out by dfs2
int n, fa[maxn], id[maxn], idR[maxn], s[maxn], sz[maxn], bel[maxn], timestamp = 0;
// root[h] : index in seg[] of the root of the segment tree built over the chain headed by h
int root[maxn];

int dfs1(int u, int f) {
    fa[u] = f, s[u] = 1;
    for (int i = 0; i < (int)G[u].size(); ++i) {
        int v = G[u][i];
        if (v != f) s[u] += dfs1(v, u);
    }
    return s[u];
}

void dfs2(int u, int num) {
    idR[timestamp] = u, id[u] = timestamp++, bel[u] = num;
    sz[num]++;
    int Max = 0, idx = -1;
    for (int i = 0; i < (int)G[u].size(); ++i) {
        int v = G[u][i];
        if (v != fa[u] && s[v] > Max) {Max = s[v], idx = v;}
    }
    if (Max == 0) return;
    dfs2(idx, num);
    for (int i = 0; i < (int)G[u].size(); ++i) {
        int v = G[u][i];
        if (v != fa[u] && v != idx) dfs2(v, v);
    }
}

// Capacity of the shared segment-tree node pool. buildTree uses 2L-1 nodes for
// a chain of length L, so all chains together need fewer than 2n nodes.
const int maxNode = (1 << 17) * 2 + 10;
// Summary of a range [l, r) of chain positions (DFS order; position l is the
// end closest to the chain head):
//   cleft  : length of the longest single-colored prefix starting at l
//   cntL   : sum of sum[] over that prefix, i.e. the number of nodes reachable
//            from idR[l] through same-colored nodes inside this range and the
//            light subtrees hanging off it
//   cright / cntR : the same for the longest single-colored suffix ending at r-1
struct Node {
    int cntL, cntR, cleft, cright;
}seg[maxNode];
// ls/rs: child indices of an internal node in seg[]; segCnt: next free pool index.
int ls[maxNode], rs[maxNode], segCnt = 0;
// sum[u]  : 1 + total cntL of u's light-child chains whose head has u's color
// nosum[u]: 1 + total cntL of u's light-child chains whose head has the other color
// Both include u itself, so swapping them on a color flip keeps u counted in sum[u].
int sum[maxn], nosum[maxn];

// Combines the summaries of the halves [l, mid) and [mid, r) of a chain range,
// with mid = (l + r) / 2 (the same split used by buildTree/modify/query).
// The prefix (cntL/cleft) extends into the right half only if the whole left
// half is one color and that color matches position mid. Symmetrically, the
// suffix (cntR/cright) extends into the left half only if the whole right half
// is one color and position mid-1 matches it.
inline Node merge(const Node &lnode, const Node &rnode, int l, int r) {
    Node newNode; int mid = (l + r) / 2;
    newNode.cntL = lnode.cntL, newNode.cleft = lnode.cleft;
    if (lnode.cleft == mid - l && color[idR[l]] == color[idR[mid]]) newNode.cntL += rnode.cntL, newNode.cleft += rnode.cleft;  // Note: do not add the l..mid+1 stretch separately; it is already contained in the child ranges
    newNode.cntR = rnode.cntR, newNode.cright = rnode.cright;
    // rnode.cleft == r - mid means the entire right half has one color.
    if (rnode.cleft == r - mid && color[idR[mid - 1]] == color[idR[mid]]) newNode.cntR += lnode.cntR, newNode.cright += lnode.cright;
    return newNode;
}

inline void maintain(int o, int u) {
    seg[o].cntL = seg[o].cntR = sum[u], seg[o].cleft = seg[o].cright = 1;
}

// Builds the segment tree over chain positions [l, r), rooted at seg node o.
// At a leaf (chain node u), it first builds the tree of every light-child chain
// of u, allocating root[v] from the pool. It then initializes sum[u]/nosum[u]
// from those chains' cntL by comparing colors, and finally fills the leaf with
// maintain(). Internal nodes are combined bottom-up with merge().
void buildTree(int o, int l, int r) {
    if (r - l == 1) {
        int u = idR[l];
        nosum[u] = 1, sum[u] = 1;  // both count u itself, so a color flip can just swap them
        for (int i = 0; i < (int)G[u].size(); ++i) {
            int v = G[u][i];
            // A child on a different chain is a light child heading its own
            // chain, which occupies positions [id[v], id[v] + sz[v]).
            if (v != fa[u] && bel[u] != bel[v]) {
                buildTree(root[v] = segCnt++, id[v], id[v] + sz[v]);
                if (color[v] == color[u]) sum[u] += seg[root[v]].cntL;
                else nosum[u] += seg[root[v]].cntL;
            }
        }
        maintain(o, u);
    }else {
        int mid = (l + r) / 2;
        buildTree(ls[o] = segCnt++, l, mid);
        buildTree(rs[o] = segCnt++, mid, r);
        seg[o] = merge(seg[ls[o]], seg[rs[o]], l, r);
    }
}

// One-time preprocessing: roots the tree at node 0, computes the heavy-light
// decomposition, then builds the segment trees of all chains starting from the
// root chain (whose head is node 0).
void process() {
    const int top = 0;
    dfs1(top, -1);
    dfs2(top, top);
    root[top] = segCnt++;
    const int first = id[top];
    buildTree(root[top], first, first + sz[top]);
}

// Returns the number of nodes connected to tar through nodes of tar's color.
// o and [l, r) are the segment tree and position range of the current part of
// tar's chain; the top-level call passes the chain root, e.g.
//   query(root[bel[u]], id[bel[u]], id[bel[u]] + sz[bel[u]], u, -1).
// from: head of the lower chain the search climbed out of (-1 on the first
// call). That chain's cntL was already counted by the lower call and is also
// included in sum[tar], so it is subtracted once here.
int query(int o, int l, int r, int tar, int from) {
    int ans = 0;
    if (r - l == 1) {
        // Leaf of tar: sum[tar] = tar plus its same-colored light-child chains.
        ans = seg[o].cntL;
        if (from != -1) ans -= seg[root[from]].cntL; // subtract the double-counted part!
        int head = bel[tar];
        // If every position from the chain head down to tar has one color and
        // the head's parent shares it, the component continues on the parent's
        // chain: climb there.
        if (head != 0 && seg[root[head]].cleft >= l - id[head] + 1) {
            int fh = fa[head], fhh = bel[fh];
            if (color[fh] == color[tar]) ans += query(root[fhh], id[fhh], id[fhh] + sz[fhh], fh, head);
        }
    }else {
        int mid = (l + r) / 2, idx = id[tar];
        if (idx < mid) {
            // tar in the left half: if positions idx..mid-1 and mid all share
            // tar's color, the right half's same-colored prefix joins the component.
            if (seg[ls[o]].cright >= mid - idx && color[tar] == color[idR[mid]]) ans += seg[rs[o]].cntL; // Note: do not add the idx..mid+1 stretch here; it is always added further down the recursion!
            ans += query(ls[o], l, mid, tar, from);
        }else {
            // tar in the right half: if positions mid..idx and mid-1 all share
            // tar's color, the left half's same-colored suffix joins the component.
            if (seg[rs[o]].cleft >= idx - mid + 1 && color[tar] == color[idR[mid - 1]]) ans += seg[ls[o]].cntR;
            ans += query(rs[o], mid, r, tar, from);
        }
    }
    return ans;
}

// Nodes whose chain summaries must be refreshed after a color flip; filled by findPath.
vector<int> path;
// Appends u, then the parent of u's chain head, then the parent of that
// node's chain head, and so on, stopping after the node on the root chain
// (the root's parent is -1).
void findPath(int u) {
    for (int cur = u; cur != -1; cur = fa[bel[cur]])
        path.push_back(cur);
}

// The node whose color was just flipped; main sets it before calling modify.
int X;
// Propagates the color flip of X through every affected chain summary.
// path (reversed by the caller) lists, from the root chain downward, the node
// on each chain from which the route to X leaves that chain; it ends with X.
// i indexes path, and o/[l, r) is the segment tree of path[i]'s chain, so the
// leaf reached is u == path[i].
// At that leaf, the next chain's (head nexT) old cntL is removed from whichever
// bucket held it. That chain is then updated recursively, and its new cntL is
// added back to the bucket matching the current colors. At X itself, swapping
// sum/nosum re-classifies all of X's light children without visiting them.
void modify(int o, int l, int r, int i) {
    if (r - l == 1) {
        int u = idR[l];
        if (i + 1 != (int)path.size()) {
            int nexT = bel[path[i + 1]];
            if (color[u] == color[nexT] && nexT != X) sum[u] -= seg[root[nexT]].cntL; // Decide whether this value was added to sum[u]: for the just-recolored node (nexT == X) it was iff the colors now differ, otherwise iff they are equal!
            else if (color[u] != color[nexT] && nexT == X) sum[u] -= seg[root[nexT]].cntL;
            else nosum[u] -= seg[root[nexT]].cntL;
            modify(root[nexT], id[nexT], id[nexT] + sz[nexT], i + 1);
            if (color[u] == color[nexT]) sum[u] += seg[root[nexT]].cntL;
            else nosum[u] += seg[root[nexT]].cntL;
        }else swap(sum[u], nosum[u]);
        maintain(o, u);
    }else {
        int mid = (l + r) / 2;
        if (id[path[i]] < mid) modify(ls[o], l, mid, i);
        else modify(rs[o], mid, r, i);
        seg[o] = merge(seg[ls[o]], seg[rs[o]], l, r);
    }
}

// Reads n and the n-1 edges (1-indexed in the input), preprocesses the tree,
// then handles m operations:
//   0 u : print the size of the same-colored component containing u
//   anything else (asserted to be 1) : flip the color of u
int main() {
    n = ReadInt();
    for (int e = 1; e < n; ++e) {
        const int a = ReadInt() - 1;
        const int b = ReadInt() - 1;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    process();
    for (int m = ReadInt(); m != 0; --m) {
        const int op = ReadInt();
        const int u = ReadInt() - 1;
        assert(op == 0 || op == 1);
        if (op != 0) {
            // Flip u, then refresh the summaries on every chain from the root
            // chain down to u; X tells modify which node was recolored.
            color[u] = !color[u];
            path.clear();
            findPath(u);
            reverse(path.begin(), path.end());
            X = u;
            modify(root[0], id[0], id[0] + sz[0], 0);
        } else {
            const int h = bel[u];
            printf("%d\n", query(root[h], id[h], id[h] + sz[h], u, -1));
        }
    }
    return 0;
}