BZOJ 3196 - Tyvj 1730 二逼平衡树

Published on 2016-03-31

题目地址

描述

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  1. 查询 kk 在区间内的排名。(若有相同的数,输出排名最小的)
  2. 查询区间内排名为 kk 的值。
  3. 修改某一位值上的数值。
  4. 查询 kk 在区间内的前驱(前驱定义为小于 xx,且最大的数)。
  5. 查询 kk 在区间内的后继(后继定义为大于 xx,且最小的数)。

样例输入

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

样例输出

2
4
3
4
9

分析

直接线段树套不旋转 Treap 了,tyvj 上最后两个点 TLE(据说只有带垃圾回收的权值线段树套区间线段树能过?)不管了。
我们用线段树套上 Treap,线段树每个节点 uu 对应的区间 [l,r)[l, r) 对应包含区间 [l,r)[l, r) 所有值的一棵以值为序的 Treap。接着考虑一下操作。

  1. kk 在区间 [a,b][a, b] 的排名:在线段树上分解为若干区间,然后在区间中查有多少个数大于 kk,累加和即可。复杂度 O(log2n)O(\log^2 n)
  2. 查第 kk 小:在线段树上分解为若干区间,由于区间被分解,无法按照平衡树的方式查询,那么二分答案,转化为判定性问题,若 tt 的排名是第一个大于 kk 的,那么 t1t - 1 便是答案。复杂度 O(log3n)O(\log^3 n)
  3. 修改某一位的值:在线段树中找到包含 pospos 的所有区间,在对应的 Treap 中改即可。复杂度 O(logn)O(\log n),常数是很大的。
  4. 查前驱:在线段树上分解为若干区间,在每个 Treap 里面找前驱,取最大的即可。复杂度 O(logn)O(\log n)
  5. 查后继:在线段树上分解为若干区间,在每个 Treap 里面找后继,取最小的即可。复杂度 O(logn)O(\log n)

说一下维护平衡树的几个技巧:首先是处理重复的值,对于所有值,我们都插入,找的时候用这两种方式处理:

//查询最后一个小于 x 的数在 o 中是第几大
int getKth(const Treap* o, int x) {
    if (o == null) return 0;
    return o->key >= x ? getKth(o->ls, x) : getKth(o->rs, x) + o->ls->s + 1;
}
//查询最后一个小于等于 x 的数在 o 中是第几大
int getKth1(const Treap* o, int x) {
    if (o == null) return 0;
    return o->key > x ? getKth1(o->ls, x) : getKth1(o->rs, x) + o->ls->s + 1;
}

然后就是由于 vv 在子区间的前驱后继可能没有,于是这样可以避免判断有没有:

int findKth(const Treap* o, int k) {
    if (k <= 0) return INT_MIN;
    if (k > o->s) return INT_MAX;
    int s = o->ls->s;
    if (k == s + 1) return o->key;
    else if (k <= s) return findKth(o->ls, k);
    return findKth(o->rs, k - s - 1);
}

kk 小于 0,只可能是前驱不合法,由于前驱是取 max,所以返回 ;后继就是反过来。

代码

树套树,未能 AC 55555,分块大法直接上了。

//  Created by Sengxian on 3/31/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 3196 分块
#include <algorithm>
#include <iostream>
#include <cctype>
#include <climits>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#include <queue>
#define mid (((l) + (r)) / 2)
using namespace std;

inline int ReadInt() {
    static int n, ch;
    static bool flag;
    n = 0, ch = getchar(), flag = false;
    while (!isdigit(ch)) flag |= ch == '-', ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return flag ? -n : n;
}

const int maxn = 50000 + 3, SIZE = 883;
int block[maxn / SIZE + 1][SIZE], a[maxn];
int n, m, b = 0, j = 0;

void init() {
    n = ReadInt(), m = ReadInt();
    for (int i = 0; i < n; ++i) {
        block[b][j] = a[i] = ReadInt();
        if (++j == SIZE) b++, j = 0;
    }
    for (int i = 0; i < b; ++i) sort(block[i], block[i] + SIZE);
    if (j) sort(block[b], block[b] + j);
}

inline int getRank(int l, int r, int v) {
    int lb = l / SIZE, rb = r / SIZE, k = 0;
    if (lb == rb) {
        for (int i = l; i <= r; ++i) if (a[i] < v) k++;
    }else {
        for (int i = l; i < (lb + 1) * SIZE; ++i) if (a[i] < v) k++;
        for (int i = rb * SIZE; i <= r; ++i) if (a[i] < v) k++;
        for (int i = lb + 1; i < rb; ++i) k += lower_bound(block[i], block[i] + SIZE, v) - block[i];
    }
    return k + 1;
}

inline int Kth(int L, int R, int k) {
    int l = -1, r = 1e8 + 10;
    while (r - l > 1) {
        if (getRank(L, R, mid) > k) r = mid;
        else l = mid;
    }
    return r - 1;
}

inline void modify(int p, int v) {
    if (a[p] == v) return;
    int *B = block[p / SIZE], old = a[p], sz = (p / SIZE) == b ? j : SIZE;
    a[p] = v, p = lower_bound(B, B + sz, old) - B; //找 old 啊小伙子!!!
    B[p] = v;
    if (old < v) while (p + 1 < SIZE && B[p] > B[p + 1]) swap(B[p], B[p + 1]), p++;
    else while (p - 1 >= 0 && B[p] < B[p - 1]) swap(B[p], B[p - 1]), p--;
}

inline int pre(int l, int r, int v) {
    int lb = l / SIZE, rb = r / SIZE, val = INT_MIN, idx;
    if (lb == rb) {
        for (int i = l; i <= r; ++i) if (a[i] < v) val = max(val, a[i]);
    }else {
        for (int i = l; i < (lb + 1) * SIZE; ++i) if (a[i] < v) val = max(val, a[i]);
        for (int i = rb * SIZE; i <= r; ++i) if (a[i] < v) val = max(val, a[i]);
        for (int i = lb + 1; i < rb; ++i) {
            idx = lower_bound(block[i], block[i] + SIZE, v) - block[i];
            if (idx > 0) val = max(val, block[i][idx - 1]);    
        }
    }
    return val;
}

inline int post(int l, int r, int v) {
    int lb = l / SIZE, rb = r / SIZE, val = INT_MAX, idx;
    if (lb == rb) {
        for (int i = l; i <= r; ++i) if (a[i] > v) val = min(val, a[i]);
    }else {
        for (int i = l; i < (lb + 1) * SIZE; ++i) if (a[i] > v) val = min(val, a[i]);
        for (int i = rb * SIZE; i <= r; ++i) if (a[i] > v) val = min(val, a[i]);
        for (int i = lb + 1; i < rb; ++i) {
            idx = upper_bound(block[i], block[i] + SIZE, v) - block[i];
            if (idx < SIZE) val = min(val, block[i][idx]);
        }
    }
    return val;
}


int main() {
    init();
    int l, r;
    while (m--) {
        int opt = ReadInt();
        if (opt == 1) {
            l = ReadInt() - 1, r = ReadInt() - 1;
            printf("%d\n", getRank(l, r, ReadInt()));
        } else if (opt == 2) {
            l = ReadInt() - 1, r = ReadInt() - 1;
            printf("%d\n", Kth(l, r, ReadInt()));
        } else if (opt == 3) {
            l = ReadInt() - 1;
            modify(l, ReadInt());
        } else if (opt == 4) {
            l = ReadInt() - 1, r = ReadInt() - 1;
            printf("%d\n", pre(l, r, ReadInt()));
        } else if (opt == 5) {
            l = ReadInt() - 1, r = ReadInt() - 1;
            printf("%d\n", post(l, r, ReadInt()));
        }
    }
    return 0;
}

未能 AC 的树套树

//  Created by Sengxian on 3/30/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//    BZOJ 3196 线段树套 Treap
#include <algorithm>
#include <iostream>
#include <cctype>
#include <cassert>
#include <climits>
#include <cstring>
#include <cstdio>
#include <vector>
#include <ctime>
#define mid (((l) + (r)) / 2)
using namespace std;

inline int ReadInt() {
    int n = 0, ch = getchar(); bool flag = false;
    while (!isdigit(ch)) flag |= ch == '-', ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return flag ? -n : n;
}

const int maxn = 100000 + 10, INF = 0x3f3f3f3f;
struct Treap *null, *pit;
struct Treap {
    Treap *ls, *rs;
    int key, val, s;
    inline void maintain() {s = ls->s + rs->s + 1;}
    Treap(int key = 0, int val = rand()): ls(null), rs(null), key(key), val(val), s(1) {}
}pool[maxn * 20], *root[maxn], *stack[maxn];

Treap *newNode(int key = 0, int val = rand()) {
    pit = new Treap(key);
    pit->ls = null, pit->rs = null, pit->key = key, pit->val = val, pit->s = 1;
    return pit++;
}

typedef pair<Treap*, Treap*> Droot;
Treap* merge(Treap *a, Treap *b) {
    if (a == null) return b;
    if (b == null) return a;
    if (a->val < b->val) {
        a->rs = merge(a->rs, b);
        a->maintain();
        return a;
    }else {
        b->ls = merge(a, b->ls);
        b->maintain();
        return b;
    }
}

Droot split(Treap *o, int k) {
    Droot d(null, null);
    if (o == null) return d;
    int s = o->ls->s;
    if (k <= s) {
        d = split(o->ls, k);
        o->ls = d.second;
        o->maintain();
        d.second = o;
    }else {
        d = split(o->rs, k - s - 1);
        o->rs = d.first;
        o->maintain();
        d.first = o;
    }
    return d;
}
Treap* build(int n, int *a) {
    Treap *root = new Treap(-INF, -INF);
    stack[0] = root; int sz = 1;
    for (int i = 0; i < n; ++i) {
        Treap *now = newNode(a[i]); int p = sz - 1;
        while (stack[p]->val > now->val) stack[p--]->maintain();
        now->ls = stack[p]->rs, stack[p]->rs = now;
        sz = p + 1;
        stack[sz++] = now;
    }
    while (sz) stack[--sz]->maintain();
    return root->rs;
}
//查询最后一个小于 x 的数在 o 中是第几大
int getKth(const Treap* o, int x) {
    if (o == null) return 0;
    return o->key >= x ? getKth(o->ls, x) : getKth(o->rs, x) + o->ls->s + 1;
}
//查询最后一个小于等于 x 的数在 o 中是第几大
int getKth1(const Treap* o, int x) {
    if (o == null) return 0;
    return o->key > x ? getKth1(o->ls, x) : getKth1(o->rs, x) + o->ls->s + 1;
}
int findKth(const Treap* o, int k) {
    if (k <= 0) return INT_MIN;
    if (k > o->s) return INT_MAX;
    int s = o->ls->s;
    if (k == s + 1) return o->key;
    else if (k <= s) return findKth(o->ls, k);
    return findKth(o->rs, k - s - 1);
}
void modifyTreap(Treap* &o, int x, int v) {
    Droot l = split(o, getKth(o, x)), r = split(l.second, 1);
    o = merge(l.first, r.second);
    l = split(o, getKth(o, v));
    o = merge(merge(l.first, newNode(v)), l.second);
}
int n, m, arr[maxn], t[maxn];

void buildTree(int o, int l, int r) {
    for (int i = 0; i < r - l; ++i)
        t[i] = arr[l + i];
    sort(t, t + r - l);
    root[o] = build(r - l, t);
    if (r - l > 1) buildTree(o * 2 + 1, l, mid), buildTree(o * 2 + 2, mid, r);
}

int a, b, v;
int getRank(int o, int l, int r) {
    if (l >= b || r <= a) return 0;
    if (l >= a && r <= b) return getKth(root[o], v);
    return getRank(o * 2 + 1, l, mid) + getRank(o * 2 + 2, mid, r);
}

void modify(int o, int l, int r) {
    modifyTreap(root[o], arr[a], v);
    if (r - l > 1) {
        if (a < mid) modify(o * 2 + 1, l, mid);
        else modify(o * 2 + 2, mid, r);
    }
}

int getPre(int o, int l, int r) {
    if (l >= b || r <= a) return INT_MIN;
    if (l >= a && r <= b) return findKth(root[o], getKth(root[o], v));
    return max(getPre(o * 2 + 1, l, mid), getPre(o * 2 + 2, mid, r));
}

int getPost(int o, int l, int r) {
    if (l >= b || r <= a) return INT_MAX;
    if (l >= a && r <= b) return findKth(root[o], getKth1(root[o], v) + 1);
    return min(getPost(o * 2 + 1, l, mid), getPost(o * 2 + 2, mid, r));
}

int main() {
    //freopen("test.in", "r", stdin);
    null = new Treap(), null->s = 0;
    n = ReadInt(), m = ReadInt();
    for (int i = 0; i < n; ++i) arr[i] = ReadInt();
    buildTree(0, 0, n);
    while (m--) {
        int opt = ReadInt();
        if (opt == 1) {
            a = ReadInt() - 1, b = ReadInt(), v = ReadInt();
            printf("%d\n", getRank(0, 0, n) + 1);
        } else if (opt == 2) {
            a = ReadInt() - 1, b = ReadInt();
            int l = -1, r = 1e8 + 10;
            int k = ReadInt();
            while (r - l > 1) {
                v = mid;
                if (getRank(0, 0, n) + 1 > k) r = mid;
                else l = mid;
            }
            printf("%d\n", r - 1);
        } else if (opt == 3) {
            a = ReadInt() - 1, v = ReadInt();
            modify(0, 0, n);
            arr[a] = v;
        } else if (opt == 4) {
            a = ReadInt() - 1, b = ReadInt(), v = ReadInt();
            printf("%d\n", getPre(0, 0, n));
        } else if (opt == 5) {
            a = ReadInt() - 1, b = ReadInt(), v = ReadInt();
            printf("%d\n", getPost(0, 0, n));
        } else assert(false);
    }
    return 0;
}