BZOJ 4503 - 两个串
Published on 2017-01-10描述
兔子们在玩两个串的游戏。给定两个字符串 和 ,兔子们想知道 在 中出现了几次,
分别在哪些位置出现。注意 T 中可能有 ?
字符,这个字符可以匹配任何字符。
分析
这个通配符很不好办,普通的 KMP 肯定是搞不定了。我们将 串倒过来,这样我们构造一个卷积:
其中 不存在的位置设为 ,通配符所在的位置也设置为 。
我们观察这个和式,发现 和 在字符串意义下相等,当且仅当 为 ,而由于通配符的值为 ,所以这个式子也能处理通配符。这样一来,对于一个位置 ,如果他是某个匹配的结束位置,那么 一定是 ,反之亦然。理解起来很容易,由于 ,而且所有项非负,那么自然所有项都是 ,显然有 ,注意超出的部分看作是通配符,所以也能匹配。
这个式子拆开后变为标准的卷积形式,可以用 FFT 优化,复杂度 。
总结:字符串匹配可以构造成卷积的形式,只需要将模式串倒过来,然后构造一个式子,使得两个位置匹配当且仅当式子的结果为 0。这个方法具有高度扩展性,所以可以轻松扩展到通配符的情况。
代码
// Created by Sengxian on 2017/1/10. // Copyright (c) 2017年 Sengxian. All rights reserved. // BZOJ 4503 FFT #include <bits/stdc++.h> using namespace std; typedef long long ll; const int MAX_N = 100000 + 3; char s[MAX_N], t[MAX_N]; const double pi = acos(-1.0); typedef complex<double> C; typedef vector<C> vc; typedef vector<double> vd; void FFT(vc &a, int oper = 1) { int n = a.size(); for (int i = 0, j = 0; i < n; ++i) { if (i > j) swap(a[i], a[j]); for (int l = n >> 1; (j ^= l) < l; l >>= 1); } for (int l = 1, ll = 2; l < n; l <<= 1, ll <<= 1) { double x = oper * pi / l; C omega = 1, omegan(cos(x), sin(x)); for (int k = 0; k < l; ++k, omega *= omegan) { for (int st = k; st < n; st += ll) { C tmp = omega * a[st + l]; a[st + l] = a[st] - tmp; a[st] += tmp; } } } if (oper == -1) for (int i = 0; i < n; ++i) a[i] /= n; } vd operator * (const vd &v1, const vd &v2) { int s = 1, ss = (int)v1.size() + (int)v2.size(); while (s < ss) s <<= 1; vc a(s, 0), b(s, 0); for (int i = 0; i < v1.size(); ++i) a[i] = v1[i]; for (int i = 0; i < v2.size(); ++i) b[i] = v2[i]; FFT(a), FFT(b); for (int i = 0; i < s; ++i) a[i] *= b[i]; FFT(a, -1); vd res(s); for (int i = 0; i < s; ++i) res[i] = a[i].real(); return res; } vd operator + (const vd &v1, const vd &v2) { int n = min(v2.size(), v1.size()); vd res(n); for (int i = 0; i < n; ++i) res[i] = v1[i] + v2[i]; return res; } int main() { scanf("%s%s", s, t); int n = strlen(s), m = strlen(t); reverse(t, t + m); for (int i = 0; i < m; ++i) if (t[i] == '?') t[i] = 0; vd a(n), b(n), sqrA(n), sqrB(n), cubeB(n); for (int i = 0; i < n; ++i) a[i] = s[i], sqrA[i] = s[i] * s[i]; for (int i = 0; i < n; ++i) b[i] = t[i], sqrB[i] = t[i] * t[i], cubeB[i] = t[i] * t[i] * t[i] + (i == 0 ? 0 : cubeB[i - 1]); vd t1 = sqrA * b, t2 = cubeB, t3 = a * sqrB; for (int i = 0; i < n; ++i) t3[i] = -2 * t3[i]; vd ans = t1 + t2 + t3; vector<int> pos; for (int i = 0; i < n; ++i) if (i - m + 1 >= 0 && fabs(ans[i]) < 0.3) pos.push_back(i - m + 1); printf("%u\n", pos.size()); for (int i = 0; i < (int)pos.size(); ++i) printf("%d\n", pos[i]); return 0; }