BZOJ 4503 - 两个串

Published on 2017-01-10

描述

兔子们在玩两个串的游戏。给定两个字符串 S(S105)S(\vert S\vert \le {10}^5)T(TS)T(\vert T\vert \le \vert S\vert),兔子们想知道 TTSS 中出现了几次,
分别在哪些位置出现。注意 T 中可能有 ? 字符,这个字符可以匹配任何字符。

分析

这个通配符很不好办,普通的 KMP 肯定是搞不定了。我们将 TT 串倒过来,这样我们构造一个卷积:

ck=i+j=k(SiTj)2Tj c_k = \sum_{i + j = k}(S_i - T_j)^2\cdot T_j

其中 TjT_j 不存在的位置设为 00,通配符所在的位置也设置为 00

我们观察这个和式,发现 SiS_iTjT_j 在字符串意义下相等,当且仅当 (SiTj)2Tj(S_i - T_j)^2\cdot T_j00,而由于通配符的值为 00,所以这个式子也能处理通配符。这样一来,对于一个位置 kk,如果他是某个匹配的结束位置,那么 ckc_k 一定是 00,反之亦然。理解起来很容易,由于 ck=0c_k = 0,而且所有项非负,那么自然所有项都是 00,显然有 ,注意超出的部分看作是通配符,所以也能匹配。

这个式子拆开后变为标准的卷积形式,可以用 FFT 优化,复杂度 O(nlogn)O(n\log n)

总结:字符串匹配可以构造成卷积的形式,只需要将模式串倒过来,然后构造一个式子,使得两个位置匹配当且仅当式子的结果为 0。这个方法具有高度扩展性,所以可以轻松扩展到通配符的情况。

代码

//  Created by Sengxian on 2017/1/10.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  BZOJ 4503 FFT
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

const int MAX_N = 100000 + 3;
char s[MAX_N], t[MAX_N];

const double pi = acos(-1.0);
typedef complex<double> C;
typedef vector<C> vc;
typedef vector<double> vd;

void FFT(vc &a, int oper = 1) {
    int n = a.size();
    for (int i = 0, j = 0; i < n; ++i) {
        if (i > j) swap(a[i], a[j]);
        for (int l = n >> 1; (j ^= l) < l; l >>= 1);
    }
    for (int l = 1, ll = 2; l < n; l <<= 1, ll <<= 1) {
        double x = oper * pi / l;
        C omega = 1, omegan(cos(x), sin(x));
        for (int k = 0; k < l; ++k, omega *= omegan) {
            for (int st = k; st < n; st += ll) {
                C tmp = omega * a[st + l];
                a[st + l] = a[st] - tmp;
                a[st] += tmp;
            }
        }
    }
    if (oper == -1) for (int i = 0; i < n; ++i) a[i] /= n;
}

vd operator * (const vd &v1, const vd &v2) {
    int s = 1, ss = (int)v1.size() + (int)v2.size();
    while (s < ss) s <<= 1;
    vc a(s, 0), b(s, 0);
    for (int i = 0; i < v1.size(); ++i) a[i] = v1[i];
    for (int i = 0; i < v2.size(); ++i) b[i] = v2[i];
    FFT(a), FFT(b);
    for (int i = 0; i < s; ++i) a[i] *= b[i];
    FFT(a, -1);
    vd res(s);
    for (int i = 0; i < s; ++i) res[i] = a[i].real();
    return res;
}

vd operator + (const vd &v1, const vd &v2) {
    int n = min(v2.size(), v1.size());
    vd res(n);
    for (int i = 0; i < n; ++i) res[i] = v1[i] + v2[i];
    return res;
}

int main() {
    scanf("%s%s", s, t);
    int n = strlen(s), m = strlen(t);

    reverse(t, t + m);
    for (int i = 0; i < m; ++i) if (t[i] == '?') t[i] = 0;

    vd a(n), b(n), sqrA(n), sqrB(n), cubeB(n);
    for (int i = 0; i < n; ++i) a[i] = s[i], sqrA[i] = s[i] * s[i];
    for (int i = 0; i < n; ++i) b[i] = t[i], sqrB[i] = t[i] * t[i], cubeB[i] = t[i] * t[i] * t[i] + (i == 0 ? 0 : cubeB[i - 1]);

    vd t1 = sqrA * b, t2 = cubeB, t3 = a * sqrB;
    for (int i = 0; i < n; ++i) t3[i] = -2 * t3[i];
    vd ans = t1 + t2 + t3;

    vector<int> pos;
    for (int i = 0; i < n; ++i)
        if (i - m + 1 >= 0 && fabs(ans[i]) < 0.3) pos.push_back(i - m + 1);
    printf("%u\n", pos.size());
    for (int i = 0; i < (int)pos.size(); ++i)
        printf("%d\n", pos[i]);
    return 0;
}