主席树
Published on 2016-04-04最近搞了搞主席树,发现比想象中简单,又发现网上的讲解几乎看不懂,于是自己写一篇简易的指南,较难的问题慢慢补吧。
介绍
主席树是什么玩意呢?它是“函数式版本的线段树”,说的准确一点,他是 棵完整的权值线段树,但是这 棵树之间共用一些节点,使得内存开销仅为 ,由于权值线段树之间可以加减,所以我们可以得到序列任意区间的一棵权值线段树。
求区间第 K 大
我们用求区间第 K 大来理解主席树。
我们先来试着用权值线段树求整个序列的第 大。
如果线段树的节点 表示区间 的话,那么这个节点的值就是整个序列中有多少个数在区间 ,容易看出,线段树的大小与序列的最大值有关,所以我们要离散化,尽可能的缩小节点数量。
第 大与求第 小实际上是类似的,我们求第 小:如果 小于等于左儿子的值的话,那么我们知道有不小于 个数在左儿子代表的区间,那么第 小自然在左儿子;反之,第 小是右儿子的第 小。
递归到的叶子节点假如代表区间 的话,那么 便是第 小。
现在考虑求一段序列 的第 大,显然这比整个序列的第 大多了一个维度,可以采取树套树,让之前的权值线段树的节点为一棵普通区间线段树,但这与主席树是无关的。
如果我们能把序列 建成一棵权值线段树,那么我们就可以像刚才一样处理了,但是这样无论时间空间都无法承受,注意到因为节点保存的是出现次数,权值线段树有可加减性!即把各个节点的值相加减,表示的就是两个区间合并/做差的结果。
所以我们可以采取前缀和的思想!建立 棵树,第 棵树 表示把区间 建成的权值线段树,那么区间 ,就可以表示成 ,当然我们不用真建,直接同时查询,相减即可。
但这样时间空间还是不行,一共有 棵树,最坏有 个节点,怎么办呢?
考虑从 得到 的过程,设序列第 个数为 ,那么我们就要对 执行一次单点修改,将 的位置 的值加上 1,那么实际上只有一条链需要被修改!所以,我们完全可以在 上爬,但是我们每次重复利用不修改的部分,这样每次最多影响 个节点,空间变为 ,可以承受了!
模版
BZOJ 3524
注意到使用 null 指针的话,就不需要一开始建所有节点了,而且 null 的左右儿子还是 null,对答案无影响,直接当作完整的树写就行。
// Created by Sengxian on 4/4/16. // Copyright (c) 2016年 Sengxian. All rights reserved. // BZOJ 3932 主席树 #include <algorithm> #include <iostream> #include <cctype> #include <cstring> #include <cstdio> #include <vector> #define mid (((l) + (r)) / 2) using namespace std; inline int ReadInt() { static int n, ch; n = 0, ch = getchar(); while (!isdigit(ch)) ch = getchar(); while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar(); return n; } typedef long long ll; const int maxn = 100000 + 3; struct SegNode *pit, *null; struct SegNode { SegNode *ls, *rs; ll s, sum; inline void maintain() { s = ls->s + rs->s; sum = ls->sum + rs->sum; } void *operator new(size_t) {return pit++;} SegNode(): s(0), sum(0) {ls = null, rs = null;} }pool[maxn * 18 * 2], *root[maxn]; int n, m, S[maxn], E[maxn], P[maxn]; //n 任务总数 m 时间范围 vector<int> ps; void compress() { for (int i = 0; i < m; ++i) ps.push_back(P[i]); sort(ps.begin(), ps.end()); ps.erase(unique(ps.begin(), ps.end()), ps.end()); for (int i = 0; i < m; ++i) P[i] = lower_bound(ps.begin(), ps.end(), P[i]) - ps.begin(); } void init() { pit = pool; null = new SegNode(); null->s = 0, null->sum = 0; null->ls = null, null->rs = null; } struct state { int t, p, op; state(int t, int p, int op): t(t), p(p), op(op) {} bool operator < (const state &s) const { return t < s.t; } }; SegNode *modify(const SegNode *o, int l, int r, int v, int op) { if (l >= r) return null; SegNode *ne = new SegNode(); *ne = *o; if (r - l == 1) { ne->s += op; ne->sum += ps[v] * op; return ne; } if (v < mid) ne->ls = modify(ne->ls, l, mid, v, op); else ne->rs = modify(ne->rs, mid, r, v, op); ne->maintain(); return ne; } ll query(const SegNode *o, int l, int r, int k) { if (l >= r || k == 0) return 0; if (o->s <= k) return o->sum; if (r - l == 1) return (ll)ps[l] * k; if (k <= o->ls->s) return query(o->ls, l, mid, k); else return o->ls->sum + query(o->rs, mid, r, k - o->ls->s); } int main() { init(); m = ReadInt(), n = ReadInt(); for (int i = 0; i < m; ++i) S[i] = ReadInt() - 1, E[i] = ReadInt(), P[i] = ReadInt(); //[S, E) compress(); //离散 vector<state> events; for (int i = 0; i < m; ++i) events.push_back(state(S[i], P[i], 1)), events.push_back(state(E[i], P[i], -1)); //时间结束的时候,需要减! sort(events.begin(), events.end()); for (int i = 0, j = 0; i < n; ++i) { SegNode *now = null; if (i) now = root[i - 1]; while (j < events.size() && events[j].t == i) { now = modify(now, 0, ps.size(), events[j].p, events[j].op); ++j; } root[i] = now; } ll pre = 1, k; while (n--) { int x = ReadInt() - 1, a = ReadInt(), b = ReadInt(), c = ReadInt(); k = 1 + (a * pre + b) % c; printf("%lld\n", pre = query(root[x], 0, ps.size(), k)); } return 0; }
带修改求区间第 K 大
我们刚才用主席树在 的时间内可以求出第 大,但是只是无修改版本,如何动态求区间第 大呢?如果不要求强制在线,可以使用整体二分,单次复杂度 ,还是很优秀的:
--> 整体二分解决 BZOJ 1901
如果要求强制在线,怎么破?主席树依然能派上用场。
我们首先对原序列建立一棵主席树。注意到主席树维护的是前缀序列,所以这棵主席树无论如何是动不了了,既然动不了,我们就维护增量,由于只有权值线段树之间能够相互加减,所以我们维护树状数组套动态开点的权值线段树,这样我们就可以通过加法,得到任意一段区间的增量,再加上原序列,可得到修改过的一段序列的对应的权值线段树。
具体怎么维护增量呢?我们先看简单的情形,即实现对一个序列,实现区间增加,单点查值。
我们还是使用树状数组,如果将区间 加上一个数,那么我们就将数状数组 位置 + v, 位置 -v,那么树状数组的 就表示位置 的增量,加上原序列 的值就是修改成的值。
现在树状数组的区间表示权值线段树,若原来树状数组节点代表 的和,那么这个节点表示的权值线段树的就是序列 的增量,即如果有一个在 内增加 这个数,那么线段树位置 +1,如果删除 这个数,那么线段树位置 -1。
设原主席树为 ,树状数组的前缀和为 ,那么区间 对应的权值线段树就是:
然后就可以放心的跑下去找了。
模版
BZOJ 1901:只写了整体二分版本QAQ
树上路径第 K 大
现在第 K 大跑到树上来了嘿嘿嘿,怎么搞?
很好搞啊,首先我们对根节点建一棵主席树,然后DFS下去,每一个节点在他父亲的主席树上面加点(显然是新建一条链),最后就可以得到建出每个节点到根的对应的主席树。
树上的前缀和之类的,一般与 LCA 有关,那么路径 对应的权值线段树就应该是:
就可以干权值线段树可以干的任何事情了嘿嘿嘿。
如果有修改的话,还是建出主席树(听说这样快),然后与上题一样,用树状数组套动态开点的权值线段树维护每个点到根的链上的增量,这时候每个点的下标应该是它的 DFS 序下标!那么如果把 的权值由 变为 ,那么就在 处减 加 (加减显然对应权值线段树上的修改),在 处加 减 。 为以 为根的子树大小,由于 权值变了,那么受影响的就只有 及其子树到根的增量,其他的都是没变的,DFS 序的好处就是子树在树状数组中是连续的一段,那么开头加结尾减就是完美!
模版
BZOJ 2588(SPOJ COT):
// Created by Sengxian on 4/4/16. // Copyright (c) 2016年 Sengxian. All rights reserved. // BZOJ 2588 树上路径第 k 大 #include <algorithm> #include <iostream> #include <cctype> #include <cstring> #include <cstdio> #include <cmath> #include <vector> #include <queue> #define mid (((l) + (r)) / 2) using namespace std; inline int ReadInt() { static int n, ch; n = 0, ch = getchar(); while (!isdigit(ch)) ch = getchar(); while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar(); return n; } typedef long long ll; const int maxn = 100000 + 3; struct SegNode *pit, *null; struct SegNode { SegNode *ls, *rs; int s; inline void maintain() { s = ls->s + rs->s; } SegNode(): ls(null), rs(null), s(0) {} }pool[maxn * 18], *root[maxn]; void init() { pit = pool; null = new SegNode(); null->ls = null, null->rs = null; } SegNode* modify(const SegNode *o, int l, int r, int v) { if (l >= r) return null; SegNode *ne = pit++; *ne = *o; if (r - l == 1) ne->s++; else { if (v < mid) ne->ls = modify(ne->ls, l, mid, v); else ne->rs = modify(ne->rs, mid, r, v); ne->maintain(); } return ne; } vector<int> G[maxn], Ws; int n, m, w[maxn], ancestor[maxn][18], depth[maxn]; void compress() { for (int i = 0; i < n; ++i) Ws.push_back(w[i]); sort(Ws.begin(), Ws.end()); Ws.erase(unique(Ws.begin(), Ws.end()), Ws.end()); for (int i = 0; i < n; ++i) w[i] = lower_bound(Ws.begin(), Ws.end(), w[i]) - Ws.begin(); } void process() { for (int w = 1; (1 << w) < n; ++w) for (int i = 0; i < n; ++i) if (depth[i] - (1 << w) >= 0) ancestor[i][w] = ancestor[ancestor[i][w - 1]][w - 1]; } int LCA(int a, int b) { if (depth[a] < depth[b]) swap(a, b); int lim = log2(depth[a]); for (int i = lim; i >= 0; --i) if (depth[a] - (1 << i) >= depth[b]) a = ancestor[a][i]; if (a == b) return a; for (int i = lim; i >= 0; --i) if (depth[a] - (1 << i) >= 0 && ancestor[a][i] != ancestor[b][i]) { a = ancestor[a][i]; b = ancestor[b][i]; } return ancestor[a][0]; } int query(const SegNode *a, const SegNode *b, const SegNode *c, const SegNode *d, int l, int r, int k) { if (r - l == 1) return Ws[l]; int s = a->ls->s + b->ls->s - c->ls->s - d->ls->s; if (k <= s) return query(a->ls, b->ls, c->ls, d->ls, l, mid, k); else return query(a->rs, b->rs, c->rs, d->rs, mid, r, k - s); } void dfs(const SegNode *o, int u, int fa) { ancestor[u][0] = fa, depth[u] = fa == -1 ? 0 : depth[fa] + 1; root[u] = modify(o, 0, Ws.size(), w[u]); for (int i = 0; i < (int)G[u].size(); ++i) { int v = G[u][i]; if (v != fa) dfs(root[u], v, u); } } int main() { init(); n = ReadInt(), m = ReadInt(); for (int i = 0; i < n; ++i) w[i] = ReadInt(); compress(); for (int i = 0; i < n - 1; ++i) { int f = ReadInt() - 1, t = ReadInt() - 1; G[f].push_back(t); G[t].push_back(f); } dfs(null, 0, -1); process(); int lastAns = 0; while (m--) { int u = (ReadInt() ^ lastAns) - 1, v = ReadInt() - 1, k = ReadInt(); int lca = LCA(u, v); printf("%d", lastAns = query(root[u], root[v], root[lca], lca == 0 ? null : root[ancestor[lca][0]], 0, Ws.size(), k)); if (m) putchar('\n'); } return 0; }
BZOJ 1146:
// Created by Sengxian on 4/4/16. // Copyright (c) 2016年 Sengxian. All rights reserved. // BZOJ 1146 主席树 带修改树上路径第 k 大 #include <algorithm> #include <cctype> #include <cstring> #include <cstdio> #include <cmath> #include <vector> #include <queue> #define mid (((l) + (r)) / 2) #define lowbit(x) ((x) & -(x)) using namespace std; inline int ReadInt() { static int n, ch; n = 0, ch = getchar(); while (!isdigit(ch)) ch = getchar(); while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar(); return n; } typedef long long ll; const int maxn = 80000 + 3, maxq = 80000 + 3; struct SegNode *pit, *null; struct SegNode { SegNode *ls, *rs; int s; inline void maintain() { s = ls->s + rs->s; } SegNode(): ls(null), rs(null), s(0) {} }pool[maxn * 85], *root[maxn], *Fen[maxn], *add[100], *dec[100]; void init() { pit = pool; null = new SegNode(); null->ls = null, null->rs = null; } vector<int> G[maxn], Ws; int n, m, w[maxn], ancestor[maxn][18], depth[maxn], k[maxn], a[maxn], b[maxn], id[maxn], s[maxn], timestamp = 0; void compress() { for (int i = 0; i < n; ++i) Ws.push_back(w[i]); for (int i = 0; i < m; ++i) { k[i] = ReadInt(), a[i] = ReadInt() - 1, b[i] = ReadInt() - (k[i] > 0); if (k[i] == 0) Ws.push_back(b[i]); } sort(Ws.begin(), Ws.end()); Ws.erase(unique(Ws.begin(), Ws.end()), Ws.end()); for (int i = 0; i < n; ++i) w[i] = lower_bound(Ws.begin(), Ws.end(), w[i]) - Ws.begin(); for (int i = 0; i < m; ++i) if (k[i] == 0) b[i] = lower_bound(Ws.begin(), Ws.end(), b[i]) - Ws.begin(); } void process() { for (int w = 1; (1 << w) < n; ++w) for (int i = 0; i < n; ++i) if (depth[i] - (1 << w) >= 0) ancestor[i][w] = ancestor[ancestor[i][w - 1]][w - 1]; } int LCA(int a, int b) { if (depth[a] < depth[b]) swap(a, b); int lim = log2(depth[a]); for (int i = lim; i >= 0; --i) if (depth[a] - (1 << i) >= depth[b]) a = ancestor[a][i]; if (a == b) return a; for (int i = lim; i >= 0; --i) if (depth[a] - (1 << i) >= 0 && ancestor[a][i] != ancestor[b][i]) { a = ancestor[a][i]; b = ancestor[b][i]; } return ancestor[a][0]; } SegNode* modify(const SegNode *o, int l, int r, int v, int op) { if (l >= r) return null; SegNode *ne = pit++; *ne = *o; if (r - l == 1) ne->s += op; else { if (v < mid) ne->ls = modify(ne->ls, l, mid, v, op); else ne->rs = modify(ne->rs, mid, r, v, op); ne->maintain(); } return ne; } int query(SegNode* add[], int addc, SegNode* dec[], int decc, int l, int r, int k) { if (r - l == 1) return Ws[l]; int s = 0; for (int i = 0; i < addc; ++i) s += add[i]->ls->s; for (int i = 0; i < decc; ++i) s -= dec[i]->ls->s; for (int i = 0; i < addc; ++i) if (k <= s) add[i] = add[i]->ls; else add[i] = add[i]->rs; for (int i = 0; i < decc; ++i) if (k <= s) dec[i] = dec[i]->ls; else dec[i] = dec[i]->rs; if (k <= s) { return query(add, addc, dec, decc, l, mid, k); }else return query(add, addc, dec, decc, mid, r, k - s); } int dfs(const SegNode *o, int u, int fa) { id[u] = timestamp++, ancestor[u][0] = fa, depth[u] = fa == -1 ? 0 : depth[fa] + 1; root[u] = modify(o, 0, Ws.size(), w[u], 1); s[u] = 1; for (int i = 0; i < (int)G[u].size(); ++i) { int v = G[u][i]; if (v != fa) s[u] += dfs(root[u], v, u); } return s[u]; } int main() { init(); n = ReadInt(), m = ReadInt(); for (int i = 0; i < n; ++i) w[i] = ReadInt(); for (int i = 0; i < n - 1; ++i) { int f = ReadInt() - 1, t = ReadInt() - 1; G[f].push_back(t); G[t].push_back(f); } compress(); dfs(null, 0, -1); process(); for (int i = 1; i <= n; ++i) Fen[i] = null; for (int i = 0; i < m; ++i) { if (k[i] == 0) { for (int j = id[a[i]] + 1; j <= n; j += lowbit(j)) { Fen[j] = modify(Fen[j], 0, Ws.size(), w[a[i]], -1); Fen[j] = modify(Fen[j], 0, Ws.size(), b[i], 1); } for (int j = id[a[i]] + s[a[i]] + 1; j <= n; j += lowbit(j)) { Fen[j] = modify(Fen[j], 0, Ws.size(), w[a[i]], 1); Fen[j] = modify(Fen[j], 0, Ws.size(), b[i], -1); } w[a[i]] = b[i]; }else { int lca = LCA(a[i], b[i]), length = depth[a[i]] + depth[b[i]] - 2 * depth[lca] + 1; k[i] = length - k[i] + 1; if (k[i] <= 0) puts("invalid request!"); else { int addc = 0, decc = 0; add[addc++] = root[a[i]], add[addc++] = root[b[i]], dec[decc++] = root[lca]; if (lca) dec[decc++] = root[ancestor[lca][0]]; for (int j = id[a[i]] + 1; j > 0; j -= lowbit(j)) add[addc++] = Fen[j]; for (int j = id[b[i]] + 1; j > 0; j -= lowbit(j)) add[addc++] = Fen[j]; for (int j = id[lca] + 1; j > 0; j -= lowbit(j)) dec[decc++] = Fen[j]; if (lca) for (int j = id[ancestor[lca][0]] + 1; j > 0; j -= lowbit(j)) dec[decc++] = Fen[j]; printf("%d\n", query(add, addc, dec, decc, 0, Ws.size(), k[i])); } } } return 0; }