BZOJ 4377 - [POI2015]Kurs szybkiego czytania

Published on 2017-01-31

给定 n,a,b,p(2n109,1a,b,p<n)n, a, b, p(2\le n\le {10}^9, 1\le a, b, p < n),其中 n,an, a 互质。定义一个长度为 nn0101 ,其中 ci=0c_i = 0 当且仅当 (ai+b)modn<p(a\cdot i+b) \bmod n < p

给定一个长为 m(m106)m(m\le {10}^6) 的小 0101 串,求出小串在大串中出现了几次。

分析

首先发现 (ai+b)modn(a\cdot i+b) \bmod n 是不可能相同的。原因是假设 ai+baj+b(modn)a\cdot i+b \equiv a\cdot j+b \pmod n,那么就有 aiaj(modn)a\cdot i \equiv a\cdot j \pmod n,因为 gcd(a,n)=1\gcd(a, n) = 1,所以 a1a^{-1} 存在,两边同乘 a1a^{-1} 得到 ij(modn)i \equiv j \pmod n,所以 i,ji, j 必然相等。

得到这个性质,可以发现惯常的匹配方法就派不上用场了,nn 高达 109{10}^9,所以不可能将 cic_i 写出来,我们得想别的办法。

f(i)=(ai+b)modnf(i) = (a\cdot i+b) \bmod n,由于 f(i+1)f(i)=af(i + 1) - f(i) = a,只要固定了 f(i)f(i),那么 都是知道的,也就是说只要固定了 f(i)f(i),那么 也是固定的,f(i)f(i) 是不是一个合法的匹配的开头也就是知道的。

由于 f(i)(0i<n)f(i)(0\le i < n) 两两不重复,于是我们只需要知道有多少个 f(i)f(i) 可以作为匹配的开头即可。

我们观察模式串 ,假设 f(i)f(i) 作为开头,那么每一位 pip_i 对于 f(i)f(i) 的值就有一个约束,比如 p0=0p_0 = 0,那么有约束 0f(i)<p0\le f(i) < p;再如 p3=1p_3 = 1,那么有约束 pf(i)+3a<np\le f(i) + 3a < n。所有这些约束的交集,就是所有可以作为开头的 f(i)f(i)(注意要排除掉 ),对于这种模意义下的不等式,我们考虑区间平移的方法处理,那么每一条约束就有可能变为 f(i)[l1,r1)f(i) \in [l_1, r_1)f(i)[l2,r2)f(i) \in [l_2, r_2),这是不容易求并集的。我们转而考虑计算所有不可行区间的并集,排序后扫描就能得到所有可行的取值区间。

最多有 4m4m 个不可行区间,总复杂度 O(mlogm)O(m\log m)

总结:本题需要观察到普通的匹配方法是没法用的,由于 f(i)f(i) 不重复,所以转而考虑哪些 f(i)f(i) 可以作为开头,构造出 f(i)f(i) 的取值约束,最后转化到了一个经典的区间求并问题上,非常妙。

代码

//  Created by Sengxian on 2017/1/29.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  BZOJ 4377 数学
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_M = 1000000 + 3;
int n, a, b, p, m;
char s[MAX_M];

typedef pair<int, int> range;
range ranges[MAX_M * 4];
int cnt = 0;

inline void insert(const range &r) {
    if (r.second - r.first >= 1)
        ranges[cnt++] = r;
}

inline void add(const range &r) {
    insert(range(0, r.first));
    insert(range(r.second, n));
}

inline void add(const range &r1, const range &r2) {
    insert(range(0, r1.first));
    insert(range(r1.second, r2.first));
    insert(range(r2.second, n));
}

int solve() {
    ranges[cnt++] = range(0, 0), ranges[cnt++] = range(n, n);
    sort(ranges, ranges + cnt);

    int ans = 0;
    for (int i = 0, j = 0; i < cnt; i = j) {
        int rightMost = ranges[i].second;
        while (j < cnt && ranges[j].first <= rightMost) rightMost = max(rightMost, ranges[j++].second);
        if (j != cnt) ans += ranges[j].first - rightMost;
    }

    return ans;
}

int main() {
    scanf("%d%d%d%d%d%s", &n, &a, &b, &p, &m, s);

    for (int i = 0, t = 0; i < m; ++i) {
        if (s[i] == '0') {
            if (t >= p) add(range(n - t, n + p - t));
            else add(range(0, p - t), range(n - t, n));
        } else {
            if (t <= p) add(range(p - t, n - t));
            else add(range(0, n - t), range(n + p - t, n));
        }
        (t += a) %= n;
    }

    for (int i = n - m + 1; i < n; ++i)
        ranges[cnt++] = range(((ll)a * i + b) % n, ((ll)a * i + b) % n + 1);

    printf("%d\n", solve());
    return 0;
}