BZOJ 4650 - [Noi2016]优秀的拆分
Published on 2016-08-04描述
分析
算法一:枚举 串的中心点,则如果记 为在 前面,有多少个以 结尾的 串; 为在 后面,有多少个以 开头的 串,则我们的答案为:
其中求 以及 的方法很多,我们只需要一个工具判断两段串是否相等即可。
比较简单的做法是记 为后缀 的 LCP(Longest Common Prefix,最长公共前缀),则 ,判断两个串相等,只需要判断两个串的起始点代表的后缀的 LCP 是否大于串的长度即可。
复杂度 ,期望得分:95 分。
算法二:延续上一个算法的思路,算法一的瓶颈在于求 以及 ,我们换一个思路,找出所有形如 的子串。
枚举串 的一半长度 (也就是只考虑长度为 的 串),我们原串每隔长度 设置一个关键点,则所有串 必定覆盖两个关键点,而且这两个关键点位于 的同一个位置,如下图:
接下来我们只考虑求 数组,因为求 的方法是大同小异的。
枚举相邻的两个关键点 ,则这一次枚举,将会影响 内的 值,因为如果某个 串覆盖关键点 的话,其串末尾只可能落在 里面,如下图:
设后缀 与后缀 的 LCP 为 ,前缀 与前缀 的 LCS(Longest Common Suffix,最长公共后缀) 为 。
若 ,不存在一个长度为 的 串覆盖关键点 。
若 ,存在且仅存在一个长度为 的 串覆盖关键点 ,这个串的末尾的坐标是 ,如下图:
若 ,也就是说区间重叠了,这时可能有多个长度为 的 串,如下图,淡绿色区间的点都是长度为 的 串的末尾点,这个区间为 ,如下图:
我们来证一下,显然,我们证明区间的端点是 串的末尾点即可:
对于点 作为末尾,这个 串的开头就是最前面,显然成立,如下图,红色的部分相等:
而对于点 ,好像结论不是很显然了,我们还是要证红色的部分相等:
把串剥离出来考虑,可以发现,重叠的部分会导致两个串的首尾一段相等:
从而两个串都是灰色部分 + 绿色部分,相等!
也就是说,每次枚举会导致 的一个子区间的 + 1,我们差分,将 变为 ,这样区间加变为两个单点修改,最后求一次前缀和即可。求 无非是找到 的开头的一段区间 + 1,容易类比求 的过程求出。
LCP + LCS 采用后缀数组 + ST 表实现,枚举到关键点以后单次计算 ,枚举长度为 L,共有 个关键点,枚举的总复杂度是 ,所以总复杂度 。
代码
// Created by Sengxian on 8/4/16. // Copyright (c) 2016年 Sengxian. All rights reserved. // BZOJ 4650 NOI 2016 D1T1 后缀数组 # pragma GCC optimize("O3") #include <bits/stdc++.h> using namespace std; typedef long long ll; const int maxn = 30000 + 3; int logs[maxn], pre[maxn], post[maxn], n; struct SuffixArray { static const int maxNode = maxn; int sa[maxNode], rank[maxNode], minHeight[15][maxNode], n; char str[maxNode]; inline void build_sa(int m = 'z' + 3) { static int tmpSA[maxNode], rank1[maxNode], rank2[maxNode], cnt[maxNode]; register int i; n = strlen(str) + 1, str[n] = 0; memset(cnt, 0, sizeof (int) * m); for (i = 0; i < n; ++i) cnt[(int)str[i]]++; for (i = 1; i < m; ++i) cnt[i] += cnt[i - 1]; for (i = 0; i < n; ++i) rank[i] = cnt[(int)str[i]] - 1; for (int l = 1; l < n; l <<= 1) { for (i = 0; i < n; ++i) rank1[i] = rank[i], rank2[i] = i + l < n ? rank[i + l] : 0; memset(cnt, 0, sizeof (int) * n); for (i = 0; i < n; ++i) cnt[rank2[i]]++; for (i = 1; i < n; ++i) cnt[i] += cnt[i - 1]; for (i = n - 1; ~i; --i) tmpSA[--cnt[rank2[i]]] = i; memset(cnt, 0, sizeof (int) * n); for (i = 0; i < n; ++i) cnt[rank1[i]]++; for (i = 1; i < n; ++i) cnt[i] += cnt[i - 1]; for (i = n - 1; ~i; --i) sa[--cnt[rank1[tmpSA[i]]]] = tmpSA[i]; bool unique = true; rank[sa[0]] = 0; for (i = 1; i < n; ++i) { rank[sa[i]] = rank[sa[i - 1]]; if (rank1[sa[i]] == rank1[sa[i - 1]] && rank2[sa[i]] == rank2[sa[i - 1]]) unique = false; else rank[sa[i]]++; } if (unique) break; } } inline void getHeight() { minHeight[0][0] = 0; for (int i = 0, j = 0, k = 0; i < n - 1; ++i) { if (k) --k; j = sa[rank[i] - 1]; while (str[i + k] == str[j + k]) k++; minHeight[0][rank[i]] = k; } for (int w = 1; (1 << w) <= n; ++w) for (int i = 0; i + (1 << w) <= n; ++i) minHeight[w][i] = min(minHeight[w - 1][i], minHeight[w - 1][i + (1 << (w - 1))]); } inline int query(int l, int r) { static int bit; bit = logs[r - l]; return min(minHeight[bit][l], minHeight[bit][r - (1 << bit)]); } inline int LCP(int l, int r) { l = rank[l], r = rank[r]; if (l > r) swap(l, r); return query(l + 1, r + 1); } }SA, rSA; inline int LCP(int i, int j) { return SA.LCP(i, j); } inline int LCS(int i, int j) { return rSA.LCP(n - i - 1, n - j - 1); } ll solve() { ll ans = 0; memset(pre, 0, sizeof (int) * (n + 1)); memset(post, 0, sizeof (int) * (n + 1)); for (int len = 1, x, y, l, r; (len << 1) <= n; ++len) for (int i = 0, j = len; j < n; i += len, j += len) if (SA.str[i] == SA.str[j]) { x = LCS(i, j), y = LCP(i, j), l = max(i - x + len, i), r = min(i + y, j); if (r - l >= 1) { pre[l + len]++, pre[r + len]--; post[l - len + 1]++, post[r - len + 1]--; } } for (int i = 1; i < n; ++i) pre[i] += pre[i - 1], post[i] += post[i - 1]; for (int i = 0; i < n - 1; ++i) ans += (ll)pre[i] * post[i + 1]; return ans; } int main() { logs[0] = logs[1] = 0; for (int i = 2; i < maxn; ++i) logs[i] = logs[i >> 1] + 1; int caseNum; scanf("%d", &caseNum); while (caseNum--) { scanf("%s", SA.str), n = strlen(SA.str); memcpy(rSA.str, SA.str, sizeof SA.str); reverse(rSA.str, rSA.str + n); SA.build_sa(), rSA.build_sa(), SA.getHeight(), rSA.getHeight(); printf("%lld\n", solve()); } return 0; }