BZOJ 3110 - [Zjoi2013]K大数查询
Published on 2016-03-31描述
有 个位置, 个操作。操作有两种,每次操作如果是 1 a b c
的形式表示在第 个位置到第 个位置,每个位置加入一个数 。如果是 2 a b c
形式,表示询问从第 个位置到第 个位置,第 大的数是多少。
样例输入
2 5 1 1 2 1 1 1 2 2 2 1 1 2 2 1 1 1 2 1 2 3
样例输出
1 2 1
分析
一道树套树的题,外层是权值线段树,里层是普通区间线段树。
对于权值线段树的节点 表示权值区间 ,其对应的普通线段树的节点 表示序列 中一共有多少个在权值区间 的树。
这样不难得到我们的查询算法,要查 的第 大,如果权值线段树根的右儿子代表的线段树区间 的和为 ,如果 大于 ,说明第 大在右儿子代表的权值区间。否则在左儿子代表的权值区间上面。
修改也很好修改,只有一个区间加标记,如果要在 中加一个 ,那么应该在外层线段树中将所有包含权值 的节点对应的线段树的 区间全部 +1。
剩下唯一的问题就是空间,理论上需要 的空间,我们可以动态开点,未开的点给到 null
,如果查询的时候走到 null
,不需新建直接返回 0;如果修改的时候走到 null
,那就新建节点,每次操作第一层最多影响 个节点,第二层最对影响 个节点,所以总空间复杂度是 。
3.8 号新加入了一组嘿嘿嘿的数据,好多人挂了。注意到 ,那么最多可以加 个节点!爆了 int
,解决办法是换成 unsigned int
!
代码
// Created by Sengxian on 3/30/16. // Copyright (c) 2016年 Sengxian. All rights reserved. // BZOJ 3110 树套树 #include <algorithm> #include <iostream> #include <cctype> #include <cassert> #include <cstring> #include <cstdio> #include <vector> using namespace std; inline int ReadInt() { int ch = getchar(), n = 0; bool flag = false; while (!isdigit(ch)) flag |= ch == '-', ch = getchar(); while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar(); return flag ? -n : n; } typedef long long ll; int n, m, vn, a, b, v; #define mid ((l + r) / 2) struct SegNode *null; struct SegNode { SegNode *ls, *rs; ll sum, addv; SegNode(): ls(null), rs(null), sum(0), addv(0) {} inline void plus(int len, int x) {sum += len * x, addv += x;} inline void pushdown(int l, int r) {if (addv) ls->plus(mid - l, addv), rs->plus(r - mid, addv), addv = 0;} inline void pushup() {sum = ls->sum + rs->sum;} inline void maintain(int l, int r) { if (r - l == 1) sum = addv; else sum = ls->sum + rs->sum + addv * (r - l); } }; void modify1D(SegNode* &o, int l, int r) { if (l >= b || r <= a) return; if (o == null) o = new SegNode(); if (l >= a && r <= b) o->addv++; else modify1D(o->ls, l, mid), modify1D(o->rs, mid, r); o->maintain(l, r); } ll query1D(SegNode* &o, int l, int r, ll add = 0) { if (l >= b || r <= a) return 0; if (l >= a && r <= b) return o->sum + add * (r - l); return query1D(o->ls, l, mid, add + o->addv) + query1D(o->rs, mid, r, add + o->addv); } void build1D(SegNode* &o, int l, int r) { o = new SegNode(); if (r - l > 1) build1D(o->ls, l, mid), build1D(o->rs, mid, r); } struct SegNode2D *null2D; struct SegNode2D { SegNode2D *ls, *rs; SegNode *val; SegNode2D(): ls(null2D), rs(null2D), val(null) {} }*root; void modify2D(SegNode2D* &o, int l, int r) { if (o == null2D) o = new SegNode2D(); if (r - l > 1) { if (v < mid) modify2D(o->ls, l, mid); else modify2D(o->rs, mid, r); } modify1D(o->val, 0, n); } ll Query2D(SegNode2D* o, int l, int r, int k) { if (r - l == 1) return l; ll s = query1D(o->rs->val, 0, n); if (k <= s) return Query2D(o->rs, mid, r, k); else return Query2D(o->ls, l, mid, k - s); } void init_null() { null = new SegNode(), null->ls = null, null->rs = null, null->sum = null->addv = 0; null2D = new SegNode2D(), null2D->ls = null2D, null2D->rs = null2D, null2D->val = null; root = null2D; } const int maxm = 50000 + 3; vector<int> cs; int opers[maxm], as[maxm], bs[maxm], vs[maxm]; int compress() { for (int i = 0; i < m; ++i) { opers[i] = ReadInt(), as[i] = ReadInt(), bs[i] = ReadInt(), vs[i] = ReadInt(); if (opers[i] == 1) cs.push_back(vs[i]); } sort(cs.begin(), cs.end()); cs.erase(unique(cs.begin(), cs.end()), cs.end()); for (int i = 0; i < m; ++i) if (opers[i] == 1) vs[i] = lower_bound(cs.begin(), cs.end(), vs[i]) - cs.begin(); return cs.size(); } int main() { init_null(); n = ReadInt(), m = ReadInt(); vn = compress(); for (int i = 0; i < m; ++i) { int oper = opers[i]; a = as[i] - 1, b = bs[i], v = vs[i]; if (oper == 1) { modify2D(root, 0, vn); } else if (oper == 2) { printf("%d\n", cs[Query2D(root, 0, vn, v)]); } else assert(false); } return 0; }