BZOJ 1878 - [SDOI2009]HH的项链

Published on 2016-08-05

题目地址

描述

n(n50000)n(n\le 50000) 个贝壳排成一排,编号 。每种贝壳的种类为 ai(0ai1000000)a_i(0\le a_i\le 1000000)。现在给出 m(m200000)m(m\le 200000) 个询问,每次询问在 [l,r][l, r] 中,包含了多少种不同的贝壳?

分析

这个题目棘手的地方在于,求的是不同贝壳的种类,也就是说,相同种类的贝壳只能算一次。如果用区间和的话,就会导致重复计算。

这是一道经典题,我就直接说思路了。
对于贝壳 ii,记录下一个种类是 aia_i 的位置 nexti\mathrm{next}_i,如果不存在,nexti=0\mathrm{next}_i = 0
用树状数组维护长度为 nn 的序列,初始值为 00,先将每种贝壳第一次出现的位置上的值 +1。
将询问离线,按照左端点从小到大排序,一开始令指针 j=0j = 0,遍历询问 [li,ri][l_i, r_i],对每个询问如下操作:

  1. 对于 k[j,li)k \in [j, l_i),若 nextk=0\mathrm{next}_k = 0,则不管,否则在树状数组中 nextk\mathrm{next}_k 位置 +1。
  2. j=lij = l_i
  3. 记树状数组的前缀和为 sum(i)\mathrm{sum}(i),则该询问的答案为 sum(ri)sum(li1)\mathrm{sum}(r_i) - \mathrm{sum}(l_i - 1)

正确性证明:由于是按照左端点递增处理询问,可以保证,处理询问 [li,ri][l_i, r_i] 时,所有贝壳种类如果在 [li,n][l_i, n] 中出现,那么其第一次出现的位置必定值为 11,而且该种类其后出现的位置值必定全为 00。这样就保证了在 [li,ri][l_i, r_i] 内的所有贝壳种类,仅在第一次出现的位置值为 11,因此可以使用区间求和来查询。
复杂度: O(nlogn)O(n\log n)

此外,这个题目还可以用莫队算法在 O(nn)O(n\sqrt n) 的时间内解决,这里不做过多的解释了。

代码

//  Created by Sengxian on 8/4/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 1878 离线 + 树状数组
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int ReadInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return n;
}

const int maxn = 50000 + 3, maxm = 200000 + 3;
int n, a[maxn], vs[maxn], pos[maxn], ans[maxm], s[maxn];
int nxt[maxn];

struct query {
    int l, r, id;
    inline bool operator < (const query &ano) const {
        return l < ano.l;
    }
}qs[maxm];

#define lowbit(x) ((x) & -(x))
inline void add(int p, int v) {
    v = 1;
    for (int i = p + 1; i <= n; i += lowbit(i)) s[i] += v;
}

inline int sum(int x) {
    int ret = 0;
    for (int i = x + 1; i > 0; i -= lowbit(i)) ret += s[i];
    return ret;
}

int main() {
    n = ReadInt();
    for (int i = 0; i < n; ++i) vs[i] = a[i] = ReadInt();
    sort(vs, vs + n);
    int tmp = unique(vs, vs + n) - vs;
    for (int i = 0; i < n; ++i) a[i] = lower_bound(vs, vs + tmp, a[i]) - vs;

    memset(pos, -1, sizeof pos);
    for (int i = n - 1; ~i; --i) {
        if (~pos[a[i]]) nxt[i] = pos[a[i]];
        pos[a[i]] = i;
    }

    for (int i = 0; i < tmp; ++i) if (~pos[i]) add(pos[i], vs[i]);
    int m = ReadInt();
    for (int i = 0; i < m; ++i)
        qs[i].l = ReadInt() - 1, qs[i].r = ReadInt() - 1, qs[i].id = i;
    sort(qs, qs + m);
    int l = 0;
    for (int i = 0; i < m; ++i) {
        while (l < qs[i].l) {
            if (nxt[l]) add(nxt[l], vs[a[nxt[l]]]);
            l++;
        }
        ans[qs[i].id] = sum(qs[i].r) - sum(qs[i].l - 1);
    }
    for (int i = 0; i < m; ++i) printf("%d\n", ans[i]);
    return 0;
}