主席树

Published on 2016-04-04

最近搞了搞主席树,发现比想象中简单,又发现网上的讲解几乎看不懂,于是自己写一篇简易的指南,较难的问题慢慢补吧。

介绍

主席树是什么玩意呢?它是“函数式版本的线段树”,说的准确一点,他是 nn 棵完整的权值线段树,但是这 nn 棵树之间共用一些节点,使得内存开销仅为 O(nlogn)O(n\log n),由于权值线段树之间可以加减,所以我们可以得到序列任意区间的一棵权值线段树。

求区间第 K 大

我们用求区间第 K 大来理解主席树。
我们先来试着用权值线段树求整个序列的第 kk 大。
如果线段树的节点 uu 表示区间 [l,r)[l, r) 的话,那么这个节点的值就是整个序列中有多少个数在区间 [l,r)[l, r),容易看出,线段树的大小与序列的最大值有关,所以我们要离散化,尽可能的缩小节点数量。
kk 大与求第 kk 小实际上是类似的,我们求第 kk 小:如果 kk 小于等于左儿子的值的话,那么我们知道有不小于 kk 个数在左儿子代表的区间,那么第 kk 小自然在左儿子;反之,第 kk 小是右儿子的第 kvalleftChildk - \text{val}_{\text{leftChild}} 小。
递归到的叶子节点假如代表区间 [l,l+1)[l, l + 1) 的话,那么 ll 便是第 kk 小。

现在考虑求一段序列 [l,r)[l, r) 的第 kk 大,显然这比整个序列的第 kk 大多了一个维度,可以采取树套树,让之前的权值线段树的节点为一棵普通区间线段树,但这与主席树是无关的。
如果我们能把序列 [l,r)[l, r) 建成一棵权值线段树,那么我们就可以像刚才一样处理了,但是这样无论时间空间都无法承受,注意到因为节点保存的是出现次数,权值线段树有可加减性!即把各个节点的值相加减,表示的就是两个区间合并/做差的结果。
所以我们可以采取前缀和的思想!建立 nn 棵树,第 ii 棵树 TiT_i 表示把区间 [0,i)[0, i) 建成的权值线段树,那么区间 [l,r)[l, r),就可以表示成 TrTlT_r - T_l,当然我们不用真建,直接同时查询,相减即可。
但这样时间空间还是不行,一共有 nn 棵树,最坏有 n2n^2 个节点,怎么办呢?
考虑从 TiT_i 得到 Ti+1T_{i + 1} 的过程,设序列第 i+1i + 1 个数为 vv,那么我们就要对 TiT_i 执行一次单点修改,将 TiT_i 的位置 vv 的值加上 1,那么实际上只有一条链需要被修改!所以,我们完全可以在 TiT_i 上爬,但是我们每次重复利用不修改的部分,这样每次最多影响 logn\log n 个节点,空间变为 O(nlogn)O(n\log n),可以承受了!

模版

BZOJ 3524
注意到使用 null 指针的话,就不需要一开始建所有节点了,而且 null 的左右儿子还是 null,对答案无影响,直接当作完整的树写就行。

//  Created by Sengxian on 4/4/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 3932 主席树
#include <algorithm>
#include <iostream>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <vector>
#define mid (((l) + (r)) / 2)
using namespace std;

inline int ReadInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return n;
}

typedef long long ll;
const int maxn = 100000 + 3;
struct SegNode *pit, *null;
struct SegNode {
    SegNode *ls, *rs;
    ll s, sum;
    inline void maintain() {
        s = ls->s + rs->s;
        sum = ls->sum + rs->sum;
    }
    void *operator new(size_t) {return pit++;}
    SegNode(): s(0), sum(0) {ls = null, rs = null;}
}pool[maxn * 18 * 2], *root[maxn];

int n, m, S[maxn], E[maxn], P[maxn]; //n 任务总数 m 时间范围

vector<int> ps;
void compress() {
    for (int i = 0; i < m; ++i)
        ps.push_back(P[i]);
    sort(ps.begin(), ps.end());
    ps.erase(unique(ps.begin(), ps.end()), ps.end());
    for (int i = 0; i < m; ++i)
        P[i] = lower_bound(ps.begin(), ps.end(), P[i]) - ps.begin();
}

void init() {
    pit = pool;
    null = new SegNode();
    null->s = 0, null->sum = 0;
    null->ls = null, null->rs = null;
}

struct state {
    int t, p, op;
    state(int t, int p, int op): t(t), p(p), op(op) {}
    bool operator < (const state &s) const {
        return t < s.t;
    }
};

SegNode *modify(const SegNode *o, int l, int r, int v, int op) {
    if (l >= r) return null;
    SegNode *ne = new SegNode();
    *ne = *o;
    if (r - l == 1) {
        ne->s += op;
        ne->sum += ps[v] * op;
        return ne;
    }
    if (v < mid) ne->ls = modify(ne->ls, l, mid, v, op);
    else ne->rs = modify(ne->rs, mid, r, v, op);
    ne->maintain();
    return ne;
}

ll query(const SegNode *o, int l, int r, int k) {
    if (l >= r || k == 0) return 0;
    if (o->s <= k) return o->sum;
    if (r - l == 1) return (ll)ps[l] * k;
    if (k <= o->ls->s) return query(o->ls, l, mid, k);
    else return o->ls->sum + query(o->rs, mid, r, k - o->ls->s);
}

int main() {
    init();
    m = ReadInt(), n = ReadInt();
    for (int i = 0; i < m; ++i)
        S[i] = ReadInt() - 1, E[i] = ReadInt(), P[i] = ReadInt(); //[S, E)
    compress(); //离散
    vector<state> events;
    for (int i = 0; i < m; ++i)
        events.push_back(state(S[i], P[i], 1)), events.push_back(state(E[i], P[i], -1)); //时间结束的时候,需要减!
    sort(events.begin(), events.end());
    for (int i = 0, j = 0; i < n; ++i) {
        SegNode *now = null;
        if (i) now = root[i - 1];
        while (j < events.size() && events[j].t == i) {
            now = modify(now, 0, ps.size(), events[j].p, events[j].op);
            ++j;
        }
        root[i] = now;
    }
    ll pre = 1, k;
    while (n--) {
        int x = ReadInt() - 1, a = ReadInt(), b = ReadInt(), c = ReadInt();
        k = 1 + (a * pre + b) % c;
        printf("%lld\n", pre = query(root[x], 0, ps.size(), k));
    }
    return 0;
}

带修改求区间第 K 大

我们刚才用主席树在 O(logn)O(\log n) 的时间内可以求出第 kk 大,但是只是无修改版本,如何动态求区间第 kk 大呢?如果不要求强制在线,可以使用整体二分,单次复杂度 O(log2n)O(\log^2 n),还是很优秀的:
--> 整体二分解决 BZOJ 1901
如果要求强制在线,怎么破?主席树依然能派上用场。
我们首先对原序列建立一棵主席树。注意到主席树维护的是前缀序列,所以这棵主席树无论如何是动不了了,既然动不了,我们就维护增量,由于只有权值线段树之间能够相互加减,所以我们维护树状数组套动态开点的权值线段树,这样我们就可以通过加法,得到任意一段区间的增量,再加上原序列,可得到修改过的一段序列的对应的权值线段树。
具体怎么维护增量呢?我们先看简单的情形,即实现对一个序列,实现区间增加,单点查值。
我们还是使用树状数组,如果将区间 [l,r)[l, r) 加上一个数,那么我们就将数状数组 ll 位置 + v,rr 位置 -v,那么树状数组的 sum(i)sum(i) 就表示位置 ii 的增量,加上原序列 ii 的值就是修改成的值。
现在树状数组的区间表示权值线段树,若原来树状数组节点代表 [l,r)[l, r) 的和,那么这个节点表示的权值线段树的就是序列 [l,r)[l, r) 的增量,即如果有一个在 [l,r)[l, r) 内增加 vv 这个数,那么线段树位置 vv +1,如果删除 vv 这个数,那么线段树位置 vv -1。
设原主席树为 TT,树状数组的前缀和为 sumsum,那么区间 [l,r)[l, r) 对应的权值线段树就是:

TrTl+sum(r)sum(l)T_r - T_l + sum(r) - sum(l)

然后就可以放心的跑下去找了。

模版

BZOJ 1901:只写了整体二分版本QAQ

树上路径第 K 大

现在第 K 大跑到树上来了嘿嘿嘿,怎么搞?
很好搞啊,首先我们对根节点建一棵主席树,然后DFS下去,每一个节点在他父亲的主席树上面加点(显然是新建一条链),最后就可以得到建出每个节点到根的对应的主席树。
树上的前缀和之类的,一般与 LCA 有关,那么路径 (u,v)(u, v) 对应的权值线段树就应该是:

Tu+TvTLCA(u,v)Tfa(LCA(u, v))T_u + T_v - T_{\text{LCA}(u, v)} - T_{fa(\text{LCA(u, v)})}

就可以干权值线段树可以干的任何事情了嘿嘿嘿。
如果有修改的话,还是建出主席树(听说这样快),然后与上题一样,用树状数组套动态开点的权值线段树维护每个点到根的链上的增量,这时候每个点的下标应该是它的 DFS 序下标!那么如果把 vv 的权值由 c1c_1 变为 c2c_2,那么就在 id[v]id[v] 处减 c1c_1c2c_2(加减显然对应权值线段树上的修改),在 id[v]+s[v]id[v] + s[v] 处加 c1c_1c2c_2s[v]s[v] 为以 vv 为根的子树大小,由于 vv 权值变了,那么受影响的就只有 vv 及其子树到根的增量,其他的都是没变的,DFS 序的好处就是子树在树状数组中是连续的一段,那么开头加结尾减就是完美!

模版

BZOJ 2588(SPOJ COT):

//  Created by Sengxian on 4/4/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 2588 树上路径第 k 大
#include <algorithm>
#include <iostream>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#include <queue>
#define mid (((l) + (r)) / 2)
using namespace std;

inline int ReadInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return n;
}
typedef long long ll;

const int maxn = 100000 + 3;
struct SegNode *pit, *null;
struct SegNode {
    SegNode *ls, *rs;
    int s;
    inline void maintain() {
        s = ls->s + rs->s;
    }
    SegNode(): ls(null), rs(null), s(0) {}
}pool[maxn * 18], *root[maxn];

void init() {
    pit = pool;
    null = new SegNode();
    null->ls = null, null->rs = null;
}

SegNode* modify(const SegNode *o, int l, int r, int v) {
    if (l >= r) return null;
    SegNode *ne = pit++;
    *ne = *o;
    if (r - l == 1)
        ne->s++;
    else {
        if (v < mid) ne->ls = modify(ne->ls, l, mid, v);
        else ne->rs = modify(ne->rs, mid, r, v);
        ne->maintain();
    }
    return ne;
}

vector<int> G[maxn], Ws;
int n, m, w[maxn], ancestor[maxn][18], depth[maxn];

void compress() {
    for (int i = 0; i < n; ++i)
        Ws.push_back(w[i]);
    sort(Ws.begin(), Ws.end());
    Ws.erase(unique(Ws.begin(), Ws.end()), Ws.end());
    for (int i = 0; i < n; ++i)
        w[i] = lower_bound(Ws.begin(), Ws.end(), w[i]) - Ws.begin();
}

void process() {
    for (int w = 1; (1 << w) < n; ++w)
        for (int i = 0; i < n; ++i) if (depth[i] - (1 << w) >= 0)
            ancestor[i][w] = ancestor[ancestor[i][w - 1]][w - 1];
}

int LCA(int a, int b) {
    if (depth[a] < depth[b]) swap(a, b);
    int lim = log2(depth[a]);
    for (int i = lim; i >= 0; --i)
        if (depth[a] - (1 << i) >= depth[b])
            a = ancestor[a][i];
    if (a == b) return a;
    for (int i = lim; i >= 0; --i)
        if (depth[a] - (1 << i) >= 0 && ancestor[a][i] != ancestor[b][i]) {
            a = ancestor[a][i];
            b = ancestor[b][i];
        }
    return ancestor[a][0];
}

int query(const SegNode *a, const SegNode *b, const SegNode *c, const SegNode *d, int l, int r, int k) {
    if (r - l == 1) return Ws[l];
    int s = a->ls->s + b->ls->s - c->ls->s - d->ls->s;
    if (k <= s) return query(a->ls, b->ls, c->ls, d->ls, l, mid, k);
    else return query(a->rs, b->rs, c->rs, d->rs, mid, r, k - s);
}

void dfs(const SegNode *o, int u, int fa) {
    ancestor[u][0] = fa, depth[u] = fa == -1 ? 0 : depth[fa] + 1;
    root[u] = modify(o, 0, Ws.size(), w[u]);
    for (int i = 0; i < (int)G[u].size(); ++i) {
        int v = G[u][i];
        if (v != fa) dfs(root[u], v, u);
    }
}

int main() {
    init();
    n = ReadInt(), m = ReadInt();
    for (int i = 0; i < n; ++i)
        w[i] = ReadInt();
    compress();
    for (int i = 0; i < n - 1; ++i) {
        int f = ReadInt() - 1, t = ReadInt() - 1;
        G[f].push_back(t);
        G[t].push_back(f);
    }
    dfs(null, 0, -1);
    process();
    int lastAns = 0;
    while (m--) {
        int u = (ReadInt() ^ lastAns) - 1, v = ReadInt() - 1, k = ReadInt();
        int lca = LCA(u, v);
        printf("%d", lastAns = query(root[u], root[v], root[lca], lca == 0 ? null : root[ancestor[lca][0]], 0, Ws.size(), k));
        if (m) putchar('\n');
    }
    return 0;
}

BZOJ 1146:

//  Created by Sengxian on 4/4/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 1146 主席树 带修改树上路径第 k 大
#include <algorithm>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#include <queue>
#define mid (((l) + (r)) / 2)
#define lowbit(x) ((x) & -(x))
using namespace std;

inline int ReadInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return n;
}
typedef long long ll;

const int maxn = 80000 + 3, maxq = 80000 + 3;
struct SegNode *pit, *null;
struct SegNode {
    SegNode *ls, *rs;
    int s;
    inline void maintain() {
        s = ls->s + rs->s;
    }
    SegNode(): ls(null), rs(null), s(0) {}
}pool[maxn * 85], *root[maxn], *Fen[maxn], *add[100], *dec[100];

void init() {
    pit = pool;
    null = new SegNode();
    null->ls = null, null->rs = null;
}

vector<int> G[maxn], Ws;
int n, m, w[maxn], ancestor[maxn][18], depth[maxn], k[maxn], a[maxn], b[maxn], id[maxn], s[maxn], timestamp = 0;

void compress() {
    for (int i = 0; i < n; ++i)
        Ws.push_back(w[i]);
    for (int i = 0; i < m; ++i) {
        k[i] = ReadInt(), a[i] = ReadInt() - 1, b[i] = ReadInt() - (k[i] > 0);
        if (k[i] == 0) Ws.push_back(b[i]);
    }
    sort(Ws.begin(), Ws.end());
    Ws.erase(unique(Ws.begin(), Ws.end()), Ws.end());
    for (int i = 0; i < n; ++i)
        w[i] = lower_bound(Ws.begin(), Ws.end(), w[i]) - Ws.begin();
    for (int i = 0; i < m; ++i)
        if (k[i] == 0) b[i] = lower_bound(Ws.begin(), Ws.end(), b[i]) - Ws.begin();
}

void process() {
    for (int w = 1; (1 << w) < n; ++w)
        for (int i = 0; i < n; ++i) if (depth[i] - (1 << w) >= 0)
            ancestor[i][w] = ancestor[ancestor[i][w - 1]][w - 1];
}

int LCA(int a, int b) {
    if (depth[a] < depth[b]) swap(a, b);
    int lim = log2(depth[a]);
    for (int i = lim; i >= 0; --i)
        if (depth[a] - (1 << i) >= depth[b])
            a = ancestor[a][i];
    if (a == b) return a;
    for (int i = lim; i >= 0; --i)
        if (depth[a] - (1 << i) >= 0 && ancestor[a][i] != ancestor[b][i]) {
            a = ancestor[a][i];
            b = ancestor[b][i];
        }
    return ancestor[a][0];
}

SegNode* modify(const SegNode *o, int l, int r, int v, int op) {
    if (l >= r) return null;
    SegNode *ne = pit++;
    *ne = *o;
    if (r - l == 1)
        ne->s += op;
    else {
        if (v < mid) ne->ls = modify(ne->ls, l, mid, v, op);
        else ne->rs = modify(ne->rs, mid, r, v, op);
        ne->maintain();
    }
    return ne;
}

int query(SegNode* add[], int addc, SegNode* dec[], int decc, int l, int r, int k) {
    if (r - l == 1) return Ws[l];
    int s = 0;
    for (int i = 0; i < addc; ++i) s += add[i]->ls->s;
    for (int i = 0; i < decc; ++i) s -= dec[i]->ls->s;
    for (int i = 0; i < addc; ++i) 
        if (k <= s) add[i] = add[i]->ls;
        else add[i] = add[i]->rs;
    for (int i = 0; i < decc; ++i)
        if (k <= s) dec[i] = dec[i]->ls;
        else dec[i] = dec[i]->rs;
    if (k <= s) {
        return query(add, addc, dec, decc, l, mid, k);
    }else return query(add, addc, dec, decc, mid, r, k - s);
}

int dfs(const SegNode *o, int u, int fa) {
    id[u] = timestamp++, ancestor[u][0] = fa, depth[u] = fa == -1 ? 0 : depth[fa] + 1;
    root[u] = modify(o, 0, Ws.size(), w[u], 1);
    s[u] = 1;
    for (int i = 0; i < (int)G[u].size(); ++i) {
        int v = G[u][i];
        if (v != fa) s[u] += dfs(root[u], v, u);
    }
    return s[u];
}

int main() {
    init();
    n = ReadInt(), m = ReadInt();
    for (int i = 0; i < n; ++i)
        w[i] = ReadInt();
    for (int i = 0; i < n - 1; ++i) {
        int f = ReadInt() - 1, t = ReadInt() - 1;
        G[f].push_back(t);
        G[t].push_back(f);
    }
    compress();
    dfs(null, 0, -1);
    process();
    for (int i = 1; i <= n; ++i) Fen[i] = null;
    for (int i = 0; i < m; ++i) {
        if (k[i] == 0) {
            for (int j = id[a[i]] + 1; j <= n; j += lowbit(j)) {
                Fen[j] = modify(Fen[j], 0, Ws.size(), w[a[i]], -1);
                Fen[j] = modify(Fen[j], 0, Ws.size(), b[i], 1);
            }
            for (int j = id[a[i]] + s[a[i]] + 1; j <= n; j += lowbit(j)) {
                Fen[j] = modify(Fen[j], 0, Ws.size(), w[a[i]], 1);
                Fen[j] = modify(Fen[j], 0, Ws.size(), b[i], -1);
            }
            w[a[i]] = b[i];
        }else {
            int lca = LCA(a[i], b[i]), length = depth[a[i]] + depth[b[i]] - 2 * depth[lca] + 1;
            k[i] = length - k[i] + 1;
            if (k[i] <= 0) puts("invalid request!");
            else {
                int addc = 0, decc = 0;
                add[addc++] = root[a[i]], add[addc++] = root[b[i]], dec[decc++] = root[lca];
                if (lca) dec[decc++] = root[ancestor[lca][0]];
                for (int j = id[a[i]] + 1; j > 0; j -= lowbit(j)) add[addc++] = Fen[j];
                for (int j = id[b[i]] + 1; j > 0; j -= lowbit(j)) add[addc++] = Fen[j];
                for (int j = id[lca] + 1; j > 0; j -= lowbit(j)) dec[decc++] = Fen[j];
                if (lca) for (int j = id[ancestor[lca][0]] + 1; j > 0; j -= lowbit(j)) dec[decc++] = Fen[j];
                printf("%d\n", query(add, addc, dec, decc, 0, Ws.size(), k[i]));
            }
        }
    }
    return 0;
}