BZOJ 3238 - [Ahoi2013]差异

Published on 2016-07-04

题目地址

描述

给定一个长度为 n(n500000)n(n\le 500000) 的字符串,设 TiT_i 为从第 ii 个字符开始的后缀,请你求:

1i<jnlen(Ti)+len(Tj)2len(LCP(Ti,Tj))\sum_{1\le i<j\le n}\mathrm{len}(T_i) + \mathrm{len}(T_j) - 2 * \mathrm{len}(\mathrm{LCP}(T_i, T_j))

分析

对于前面两项,由于每个后缀在枚举的过程中出现 n1n - 1 次,那么可知和为 (n1)n(n+1)2(n - 1) * \frac {n(n + 1)} 2,问题转变为求:

1i<jnlen(LCP(Ti,Tj))\sum_{1\le i<j\le n} \mathrm{len}(\mathrm{LCP}(T_i, T_j))

我们考虑使用后缀数组求解,求出后缀数组对应的高度数组 height\mathrm{height},则两个后缀的 LCP(Ti,Tj)\mathrm{LCP}(T_i, T_j) 的长度应为 minheight[k],k[rank[i]+1,rank[j]]\min \mathrm{height}[k], k \in[\mathrm{rank}[i] + 1, \mathrm{rank}[j]],则上述所求式子转变为:

1i<jnminheight[k],k[i+1,j]\sum_{1\le i<j\le n}\min \mathrm{height}[k], k\in [i + 1, j]

这正好对应着 height\mathrm{height} 数组的所有区间,也就是问 height\mathrm{height} 数组的所有子串的最小值和。如果我们枚举起点,暴力扫描一边,复杂度应该是 O(n2)O(n^2) 的,还是无法承受,我们必须思考更优的做法。
最小值,还有区间,往往让人联想到单调性,我们如果枚举区间末尾 ii,那么可以发现,前面的 height[j]\mathrm{height}[j] 大于 height[i]\mathrm{height}[i] 的一律无效,只有小于 height[i]\mathrm{height}[i] 的才有效,我们可以构造一个单调递增的栈。

栈下标 0 1 2 3
stk[i] 0 2 4 10
height[stk[i]] 0 3 5 7

上表是一个以元素 10 结尾的单调栈,那么以元素 10 为区间末尾的答案为:

(stk[3]stk[2])×7+(stk[2]stk[1])×5+(stk[1]stk[0])×3(stk[3] - stk[2]) \times 7 + (stk[2] - stk[1]) \times 5 + (stk[1] - stk[0]) \times 3

发现可以利用前面的值,来 O(1)O(1) 计算。如果设 fif_i 为元素 ii 的答案的话,容易发现当 stk[i]stk[i] 是栈末尾时,有:

fstk[i]=fstk[i1]+(stk[i]stk[i1])×height(stk[i])f_{stk[i]} = f_{stk[i - 1]} + (stk[i] - stk[i - 1])\times \mathrm{height}(stk[i])

于是整个题就可以 O(n+nlogn)O(n + n\log n) 的求解了,追求理性愉悦的 DC3 预处理后缀数组的话,也可以是 O(n)O(n) 的。

这题非常类似于 HNOI 2016 Day2 T1,在那题中,是询问一个子区间的子串最小值和,必须快速回答询问,于是离线询问,用莫队将问题转变为 [i,j][i, j][i,j+1][i, j + 1] 的问题,仍然可以用单调栈来求解,看来,一个题目的 idea 可以被扩展,就要用相应的算法来解这个扩展,使之回到根本问题上来。

代码

//  Created by Sengxian on 7/4/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 3238 后缀数组
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int ReadInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return n;
}

const int maxn = 500000 + 3;

namespace SuffixArray {
    const int maxNode = maxn;
    int rank[maxNode], sa[maxNode], height[maxNode], n;
    char str[maxNode];
    void build_sa(const int m) {
        int rank1[maxNode], rank2[maxNode], tmpSA[maxNode], cnt[maxNode], i;
        n = strlen(str) + 1;
        str[n] = 0;
        memset(cnt, 0, sizeof cnt);
        for (i = 0; i < n; ++i) cnt[(int)str[i]]++;
        for (i = 1; i < m; ++i) cnt[i] += cnt[i - 1];
        for (i = 0; i < n; ++i) rank[i] = cnt[(int)str[i]] - 1;
        for (int l = 1; l < n; l <<= 1) {
            for (int i = 0; i < n; ++i)
                rank1[i] = rank[i], rank2[i] = i + l < n ? rank[i + l] : 0;
            memset(cnt, 0, sizeof cnt);
            for (i = 0; i < n; ++i) cnt[rank2[i]]++;
            for (i = 1; i < n; ++i) cnt[i] += cnt[i - 1];
            for (i = n - 1; i >= 0; --i) tmpSA[--cnt[rank2[i]]] = i;
            memset(cnt, 0, sizeof cnt);
            for (i = 0; i < n; ++i) cnt[rank1[i]]++;
            for (i = 1; i < n; ++i) cnt[i] += cnt[i - 1];
            for (i = n - 1; i >= 0; --i) sa[--cnt[rank1[tmpSA[i]]]] = tmpSA[i];
            bool unique = true;
            rank[sa[0]] = 0;
            for (i = 1; i < n; ++i) {
                rank[sa[i]] = rank[sa[i - 1]];
                if (rank1[sa[i]] == rank1[sa[i - 1]] && rank2[sa[i]] == rank2[sa[i - 1]]) unique = false;
                else rank[sa[i]]++;
            }
            if (unique) break;
        }
    }
    void getHeight() {
        int k = 0;
        for (int i = 0; i < n - 1; ++i) {
            if (k) --k;
            int j = sa[rank[i] - 1];
            while (str[i + k] == str[j + k]) k++;
            height[rank[i]] = k;
        }
    }
};

int n, *height, stk[maxn];
ll f[maxn];
char *str;

int main() {
    str = SuffixArray::str;
    height = SuffixArray::height + 1;
    scanf("%s", str);
    n = strlen(str);
    SuffixArray::build_sa('z' + 3);
    SuffixArray::getHeight();
    int sz = 0;
    ll ans = 0;
    stk[sz++] = 0;
    for (int i = 1; i < n; ++i) {
        while (sz && height[stk[sz - 1]] > height[i]) sz--;
        ans += f[i] = f[stk[sz - 1]] + (ll)(i - stk[sz - 1]) * height[i];
        stk[sz++] = i;
    }
    printf("%lld\n", (ll)(n - 1) * n * (n + 1) / 2 - 2 * ans);
    return 0;
}