BZOJ 3110 - [Zjoi2013]K大数查询

Published on 2016-03-31

题目地址

描述

N(N50000)N(N\le 50000) 个位置,M(M50000)M(M\le 50000) 个操作。操作有两种,每次操作如果是 1 a b c 的形式表示在第 aa 个位置到第 bb 个位置,每个位置加入一个数 cc。如果是 2 a b c 形式,表示询问从第 aa 个位置到第 bb 个位置,第 cc 大的数是多少。

样例输入

2 5
1 1 2 1
1 1 2 2
2 1 1 2
2 1 1 1
2 1 2 3

样例输出

1
2
1

分析

一道树套树的题,外层是权值线段树,里层是普通区间线段树。
对于权值线段树的节点 uu 表示权值区间 [l,r)[l, r),其对应的普通线段树的节点 vv 表示序列 [l1,r1)[l_1, r_1) 中一共有多少个在权值区间 [l,r)[l, r) 的树。
这样不难得到我们的查询算法,要查 [a,b][a, b] 的第 kk 大,如果权值线段树根的右儿子代表的线段树区间 [a,b][a, b] 的和为 ss,如果 ss 大于 kk,说明第 kk 大在右儿子代表的权值区间。否则在左儿子代表的权值区间上面。
修改也很好修改,只有一个区间加标记,如果要在 [a,b][a, b] 中加一个 cc,那么应该在外层线段树中将所有包含权值 cc 的节点对应的线段树的 [a,b][a, b] 区间全部 +1。
剩下唯一的问题就是空间,理论上需要 O(n2)O(n^2) 的空间,我们可以动态开点,未开的点给到 null,如果查询的时候走到 null,不需新建直接返回 0;如果修改的时候走到 null,那就新建节点,每次操作第一层最多影响 logn\log n 个节点,第二层最对影响 logn\log n 个节点,所以总空间复杂度是 O(mlog2n)O(m\log^2 n)

3.8 号新加入了一组嘿嘿嘿的数据,好多人挂了。注意到 n,m50000n, m \le 50000,那么最多可以加 25000000002500000000 个节点!爆了 int,解决办法是换成 unsigned int

代码

//  Created by Sengxian on 3/30/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 3110 树套树
#include <algorithm>
#include <iostream>
#include <cctype>
#include <cassert>
#include <cstring>
#include <cstdio>
#include <vector>
using namespace std;

inline int ReadInt() {
    int ch = getchar(), n = 0; bool flag = false;
    while (!isdigit(ch)) flag |= ch == '-', ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return flag ? -n : n;
}

typedef long long ll;

int n, m, vn, a, b, v;
#define mid ((l + r) / 2)
struct SegNode *null;
struct SegNode {
    SegNode *ls, *rs;
    ll sum, addv;
    SegNode(): ls(null), rs(null), sum(0), addv(0) {}
    inline void plus(int len, int x) {sum += len * x, addv += x;}
    inline void pushdown(int l, int r) {if (addv) ls->plus(mid - l, addv), rs->plus(r - mid, addv), addv = 0;}
    inline void pushup() {sum = ls->sum + rs->sum;}
    inline void maintain(int l, int r) {
        if (r - l == 1) sum = addv;
        else sum = ls->sum + rs->sum + addv * (r - l);
    }
};

void modify1D(SegNode* &o, int l, int r) {
    if (l >= b || r <= a) return;
    if (o == null) o = new SegNode();
    if (l >= a && r <= b) o->addv++;
    else modify1D(o->ls, l, mid), modify1D(o->rs, mid, r);
    o->maintain(l, r);
}

ll query1D(SegNode* &o, int l, int r, ll add = 0) {
    if (l >= b || r <= a) return 0;
    if (l >= a && r <= b) return o->sum + add * (r - l);
    return query1D(o->ls, l, mid, add + o->addv) + query1D(o->rs, mid, r, add + o->addv);
}

void build1D(SegNode* &o, int l, int r) {
    o = new SegNode();
    if (r - l > 1) build1D(o->ls, l, mid), build1D(o->rs, mid, r);
}

struct SegNode2D *null2D;
struct SegNode2D {
    SegNode2D *ls, *rs;
    SegNode *val;
    SegNode2D(): ls(null2D), rs(null2D), val(null) {}
}*root;

void modify2D(SegNode2D* &o, int l, int r) {
    if (o == null2D) o = new SegNode2D();
    if (r - l > 1) {
        if (v < mid) modify2D(o->ls, l, mid);
        else modify2D(o->rs, mid, r);
    }
    modify1D(o->val, 0, n);
}

ll Query2D(SegNode2D* o, int l, int r, int k) {
    if (r - l == 1) return l;
    ll s = query1D(o->rs->val, 0, n);
    if (k <= s) return Query2D(o->rs, mid, r, k);
    else return Query2D(o->ls, l, mid, k - s);
}

void init_null() {
    null = new SegNode(), null->ls = null, null->rs = null, null->sum = null->addv = 0;
    null2D = new SegNode2D(), null2D->ls = null2D, null2D->rs = null2D, null2D->val = null;
    root = null2D;
}

const int maxm = 50000 + 3;
vector<int> cs;
int opers[maxm], as[maxm], bs[maxm], vs[maxm];

int compress() {
    for (int i = 0; i < m; ++i) {
        opers[i] = ReadInt(), as[i] = ReadInt(), bs[i] = ReadInt(), vs[i] = ReadInt();
        if (opers[i] == 1) cs.push_back(vs[i]);
    }
    sort(cs.begin(), cs.end());
    cs.erase(unique(cs.begin(), cs.end()), cs.end());
    for (int i = 0; i < m; ++i)
        if (opers[i] == 1) vs[i] = lower_bound(cs.begin(), cs.end(), vs[i]) - cs.begin();
    return cs.size();
}

int main() {
    init_null();
    n = ReadInt(), m = ReadInt();
    vn = compress();
    for (int i = 0; i < m; ++i) {
        int oper = opers[i];
        a = as[i] - 1, b = bs[i], v = vs[i];
        if (oper == 1) {
            modify2D(root, 0, vn);
        } else if (oper == 2) {
            printf("%d\n", cs[Query2D(root, 0, vn, v)]);
        } else assert(false);
    }
    return 0;
}