BZOJ 4650 - [Noi2016]优秀的拆分

Published on 2016-08-04

题目地址

描述

UOJ 传送门

分析

算法一:枚举 AABB\mathrm{AABB} 串的中心点,则如果记 pre(i)\mathrm{pre}(i) 为在 ii 前面,有多少个以 ii 结尾的 AAAA 串;post(i)\mathrm{post}(i) 为在 ii 后面,有多少个以 ii 开头的 AAAA 串,则我们的答案为:

0in2pre(i)×post(i+1)\sum_{0\le i\le n - 2}\mathrm{pre}(i)\times\mathrm{post}(i + 1)

其中求 pre\mathrm{pre} 以及 post\mathrm{post} 的方法很多,我们只需要一个工具判断两段串是否相等即可。
比较简单的做法是记 Li,jL_{i, j} 为后缀 i,ji, j 的 LCP(Longest Common Prefix,最长公共前缀),则 Li,j=[si=sj](Li+1,j+1+1)L_{i, j} = [s_i = s_j](L_{i + 1, j + 1} + 1),判断两个串相等,只需要判断两个串的起始点代表的后缀的 LCP 是否大于串的长度即可。
复杂度 O(n2)O(n^2),期望得分:95 分。

算法二:延续上一个算法的思路,算法一的瓶颈在于求 pre\mathrm{pre} 以及 post\mathrm{post},我们换一个思路,找出所有形如 AA\mathrm{AA} 的子串。
枚举串 AA\mathrm{AA} 的一半长度 lenlen(也就是只考虑长度为 2len2 * lenAA\mathrm{AA} 串),我们原串每隔长度 lenlen 设置一个关键点,则所有串 AA\mathrm{AA} 必定覆盖两个关键点,而且这两个关键点位于 A\mathrm{A} 的同一个位置,如下图:


接下来我们只考虑求 pre\mathrm{pre} 数组,因为求 post\mathrm{post} 的方法是大同小异的。

枚举相邻的两个关键点 i,i+1i, i + 1,则这一次枚举,将会影响 [(i+1)len,(i+2)len)[(i + 1) * len, (i + 2) * len) 内的 pre\mathrm{pre} 值,因为如果某个 AA\mathrm{AA} 串覆盖关键点 i,i+1i, i + 1 的话,其串末尾只可能落在 [(i+1)len,(i+2)len)[(i + 1) * len, (i + 2) * len) 里面,如下图:

设后缀 ii 与后缀 i+1i + 1 的 LCP 为 xx,前缀 ii 与前缀 i+1i + 1 的 LCS(Longest Common Suffix,最长公共后缀) 为 yy
x+y<lenx + y < len,不存在一个长度为 2len2 * lenAA\mathrm{AA} 串覆盖关键点 i,i+1i, i + 1
x+y=lenx + y = len,存在且仅存在一个长度为 2len2 * lenAA\mathrm{AA} 串覆盖关键点 i,i+1i, i + 1,这个串的末尾的坐标是 (i+1)len+x1(i + 1) * len + x - 1,如下图:

x+y>lenx + y > len,也就是说区间重叠了,这时可能有多个长度为 2len2 * lenAA\mathrm{AA} 串,如下图,淡绿色区间的点都是长度为 2len2 * lenAA\mathrm{AA} 串的末尾点,这个区间为 [(i+1)lenx+len,(i+1)len+y)[(i + 1) * len - x + len, (i + 1) * len + y),如下图:


我们来证一下,显然,我们证明区间的端点是 AA\mathrm{AA} 串的末尾点即可:
对于点 (i+1)lenx+len(i + 1) * len - x + len 作为末尾,这个 AA\mathrm{AA} 串的开头就是最前面,显然成立,如下图,红色的部分相等:

而对于点 (i+1)len+y1(i + 1) * len + y - 1,好像结论不是很显然了,我们还是要证红色的部分相等:

把串剥离出来考虑,可以发现,重叠的部分会导致两个串的首尾一段相等:

从而两个串都是灰色部分 + 绿色部分,相等!

也就是说,每次枚举会导致 [(i+1)len,(i+2)len)[(i + 1) * len, (i + 2) * len) 的一个子区间的 pre\mathrm{pre} + 1,我们差分,将 pre(i)\mathrm{pre}(i) 变为 pre(i)pre(i1)\mathrm{pre}(i) - \mathrm{pre}(i - 1),这样区间加变为两个单点修改,最后求一次前缀和即可。求 post\mathrm{post} 无非是找到 AA\mathrm{AA} 的开头的一段区间 + 1,容易类比求 pre\mathrm{pre} 的过程求出。
LCP + LCS 采用后缀数组 + ST 表实现,枚举到关键点以后单次计算 O(1)O(1),枚举长度为 L,共有 nL\frac n L 个关键点,枚举的总复杂度是 n1+n2+n3+=O(nlogn)\frac n 1 + \frac n 2 + \frac n 3 + \cdots = O(n\log n),所以总复杂度 O(Tnlogn)O(Tn\log n)

代码

//  Created by Sengxian on 8/4/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 4650 NOI 2016 D1T1 后缀数组
# pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const int maxn = 30000 + 3;
int logs[maxn], pre[maxn], post[maxn], n;
struct SuffixArray {
    static const int maxNode = maxn;
    int sa[maxNode], rank[maxNode], minHeight[15][maxNode], n;
    char str[maxNode];
    inline void build_sa(int m = 'z' + 3) {
        static int tmpSA[maxNode], rank1[maxNode], rank2[maxNode], cnt[maxNode];
        register int i;
        n = strlen(str) + 1, str[n] = 0;
        memset(cnt, 0, sizeof (int) * m);
        for (i = 0; i < n; ++i) cnt[(int)str[i]]++;
        for (i = 1; i < m; ++i) cnt[i] += cnt[i - 1];
        for (i = 0; i < n; ++i) rank[i] = cnt[(int)str[i]] - 1;
        for (int l = 1; l < n; l <<= 1) {
            for (i = 0; i < n; ++i)
                rank1[i] = rank[i], rank2[i] = i + l < n ? rank[i + l] : 0;
            memset(cnt, 0, sizeof (int) * n);
            for (i = 0; i < n; ++i) cnt[rank2[i]]++;
            for (i = 1; i < n; ++i) cnt[i] += cnt[i - 1];
            for (i = n - 1; ~i; --i) tmpSA[--cnt[rank2[i]]] = i;
            memset(cnt, 0, sizeof (int) * n);
            for (i = 0; i < n; ++i) cnt[rank1[i]]++;
            for (i = 1; i < n; ++i) cnt[i] += cnt[i - 1];
            for (i = n - 1; ~i; --i) sa[--cnt[rank1[tmpSA[i]]]] = tmpSA[i];
            bool unique = true;
            rank[sa[0]] = 0;
            for (i = 1; i < n; ++i) {
                rank[sa[i]] = rank[sa[i - 1]];
                if (rank1[sa[i]] == rank1[sa[i - 1]] && rank2[sa[i]] == rank2[sa[i - 1]]) unique = false;
                else rank[sa[i]]++;
            }
            if (unique) break;
        }
    }
    inline void getHeight() {
        minHeight[0][0] = 0;
        for (int i = 0, j = 0, k = 0; i < n - 1; ++i) {
            if (k) --k;
            j = sa[rank[i] - 1];
            while (str[i + k] == str[j + k]) k++;
            minHeight[0][rank[i]] = k;
        }
        for (int w = 1; (1 << w) <= n; ++w)
            for (int i = 0; i + (1 << w) <= n; ++i)
                minHeight[w][i] = min(minHeight[w - 1][i], minHeight[w - 1][i + (1 << (w - 1))]);
    }
    inline int query(int l, int r) {
        static int bit;
        bit = logs[r - l];
        return min(minHeight[bit][l], minHeight[bit][r - (1 << bit)]);
    }
    inline int LCP(int l, int r) {
        l = rank[l], r = rank[r];
        if (l > r) swap(l, r);
        return query(l + 1, r + 1);
    }
}SA, rSA;

inline int LCP(int i, int j) {
    return SA.LCP(i, j);
}

inline int LCS(int i, int j) {
    return rSA.LCP(n - i - 1, n - j - 1);
}

ll solve() {
    ll ans = 0;
    memset(pre, 0, sizeof (int) * (n + 1));
    memset(post, 0, sizeof (int) * (n + 1));
    for (int len = 1, x, y, l, r; (len << 1) <= n; ++len)
        for (int i = 0, j = len; j < n; i += len, j += len) if (SA.str[i] == SA.str[j]) {
            x = LCS(i, j), y = LCP(i, j), l = max(i - x + len, i), r = min(i + y, j);
            if (r - l >= 1) {
                pre[l + len]++, pre[r + len]--;
                post[l - len + 1]++, post[r - len + 1]--;
            }
        }
    for (int i = 1; i < n; ++i) pre[i] += pre[i - 1], post[i] += post[i - 1];
    for (int i = 0; i < n - 1; ++i)
        ans += (ll)pre[i] * post[i + 1];
    return ans;
}

int main() {
    logs[0] = logs[1] = 0;
    for (int i = 2; i < maxn; ++i) logs[i] = logs[i >> 1] + 1;
    int caseNum; scanf("%d", &caseNum);
    while (caseNum--) {
        scanf("%s", SA.str), n = strlen(SA.str);
        memcpy(rSA.str, SA.str, sizeof SA.str);
        reverse(rSA.str, rSA.str + n);
        SA.build_sa(), rSA.build_sa(), SA.getHeight(), rSA.getHeight();
        printf("%lld\n", solve());
    }
    return 0;
}