BZOJ 4408 - [Fjoi 2016]神秘数

Published on 2017-01-11

题目地址

描述

一个可重复数字集合 SS 的神秘数定义为最小的不能被 SS 的子集的和表示的正整数。例如 S={1,1,1,4,13}S = \{1,1,1,4,13\}1177 均可以被表示,而 88 是最小的无法表示为集合 SS 的子集的和的数,故集合 SS 的神秘数为 88

现给定 n(n105)n(n \le {10}^5) 个正整数 m(m105)m(m\le {10}^5) 个询问,每次询问给定一个区间 [l,r](lr)[l,r](l\le r),求由 所构成的可重复数字集合的神秘数。

分析

遇到这种区间问题,先想想如何快速对一个序列解决,再将算法扩展到区间的询问。

考虑序列 的神秘数,不妨让它们从小到大排序,即 ,显然有 a1=1a_1 = 1,否则答案就是 11。我们考虑第二个数,第二个数只能在 {1,2}\{1, 2\} 中间选择,否则就凑不出来 22 了。按照这个思路,我们可以得到这样一个结论(设 sis_i 为前缀和)。

结论:假设前 ii 个数能够表示 [1,si][1, s_i] 的所有正整数,考虑加入的下一个数 ai+1a_{i + 1},如果 ai+1a_{i + 1} 大于 si+1s_i + 1,答案就是 si+1s_i + 1,否则前 ii 个数可以表示 [1,si+ai+1]=[1,si+1][1, s_i + a_{i + 1}] = [1, s_{i + 1}] 的所有正整数。

这个结论很好用归纳法证明,这里就不证明了。也就是说,如果整个序列都满足 si+1ai+1s_i + 1 \ge a_{i + 1},那么答案就是 sn+1s_n + 1,否则找到第一个不满足的位置 iisi+1s_i + 1 就是答案。

而第一个不满足的位置 ii,可以用如下的二分方法找到:

首先设定 i=1i = 1ii 的意义是对于 1j<i1 \le j < i 的所有 jj,都满足 sj+1aj+1s_j + 1 \ge a_{j + 1}。每次二分查找最大的 jj 满足 ajsi+1a_j\le s_i+1,如果 j=ij = i,那么 ii 就是第一个不满足的位置,否则令 i=ji = j,继续迭代。当 i=ni = n 时停止迭代。

这个算法的核心是对当前的 sis_i 找到最大的能“覆盖”的位置,从而不断的进行迭代。由于每迭代两次,当前 sis_i 至少翻倍(因为找到的 jj 满足 ajsi+1a_j \le s_i+1,下次迭代找到的数一定比 sis_i 大),所以最多只会迭代 O(logsn)O(\log s_n) 次,而二分的复杂度是 O(logn)O(\log n),这就得到时间复杂度上界的递归式:

T(n)=T(n2)+O(logn) T(n) = T(\frac n 2) + O(\log n)

不妨设 n=2kn = 2^k,递归式可以改写为:

T(2k)=T(2k1)+O(k) T(2^k) = T(2^{k - 1}) + O(k)

解出 T(n)=O(k2)T(n) = O(k^2),那么就有 T(n)=O(log2n)T(n) = O(\log^2 n)。这就意味着给定一个单调不降序列以及序列的前缀和,就能 O(log2n)O(\log^2 n) 求出答案。

接着扩展到区间询问,使用主席树能够求区间前 kk 小的数的前缀和,将迭代内「二分查找最大的 jj 满足 ajsi+1a_j\le s_i+1」改为主席树上的二分,则单次询问复杂度不变,仍为 O(log2n)O(\log^2 n)

总复杂度:O(nlogn+mlog2n)O(n\log n + m\log^2 n)

代码

//  Created by Sengxian on 2017/01/11.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  BZOJ 4408 二分 + 主席树
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 100000 + 3;
int n, a[MAX_N];
vector<int> values;

namespace PresidentTree {
    struct Node {
        Node *ls, *rs;
        int s, cnt;

        inline void pushUp() {
            this->s = ls->s + rs->s;
            this->cnt = ls->cnt + rs->cnt;
        }
    } pool[MAX_N * 50], *pit, *null, *root[MAX_N];

    void init() {
        pit = pool;
        null = pit++;
        null->ls = null->rs = null;
        null->s = null->cnt = 0;
    }

    #define mid (((l) + (r)) >> 1)
    Node *modify(const Node *o, int l, int r, int pos) {
        Node *v = new (pit++) Node(*o);
        if (r - l == 1) {
            v->s += values[pos], v->cnt++;
        } else {
            if (pos < mid) v->ls = modify(v->ls, l, mid, pos);
            else v->rs = modify(v->rs, mid, r, pos);
            v->pushUp();
        }
        return v;
    }

    int sum(const Node *lt, const Node *rt, int l, int r, int k) {
        if (r - l == 1) return k * values[l];
        int cnt = rt->ls->cnt - lt->ls->cnt;
        if (cnt >= k) return sum(lt->ls, rt->ls, l, mid, k);
        else return rt->ls->s - lt->ls->s + sum(lt->rs, rt->rs, mid, r, k - cnt);
    }

    int query(const Node *lt, const Node *rt, int l, int r, int s) {
        if (r - l == 1) return rt->cnt - lt->cnt;
        if (s < values[mid]) return query(lt->ls, rt->ls, l, mid, s);
        else return query(lt->rs, rt->rs, mid, r, s) + rt->ls->cnt - lt->ls->cnt;
    }
    #undef mid
}

using namespace PresidentTree;

void compress() {
    for (int i = 0; i < n; ++i) values.push_back(a[i]);
    sort(values.begin(), values.end());
    values.erase(unique(values.begin(), values.end()), values.end());
    for (int i = 0; i < n; ++i) a[i] = lower_bound(values.begin(), values.end(), a[i]) - values.begin();
}

void prepare() {
    init();
    compress();
    root[0] = null;
    for (int i = 0; i < n; ++i) {
        root[i + 1] = modify(root[i], 0, values.size(), a[i]);
    }
}

int solve(int L, int R) {
    int len = R - L;
    if (sum(root[L], root[R], 0, values.size(), 1) != 1) return 1;

    int pos = 0;
    while (pos + 1 != len) {
        int l = pos, r = len, s = sum(root[L], root[R], 0, values.size(), pos + 1);
        int newPos = query(root[L], root[R], 0, values.size(), s + 1) - 1;
        if (newPos == pos) return s + 1;
        else pos = newPos;
    }

    return sum(root[L], root[R], 0, values.size(), pos + 1) + 1;
}

int main() {
#ifdef DEBUG
    freopen("test.in", "r", stdin);
#endif
    n = readInt();
    for (int i = 0; i < n; ++i) a[i] = readInt();

    prepare();

    int q = readInt();
    while (q--) {
        int l = readInt() - 1, r = readInt();
        printf("%d\n", solve(l, r));
    }
    return 0;
}