HDOJ 4626 - Jinkeloid

Published on 2017-04-24

题目地址

描述

给定一个长度为 n(n105)n(n\le {10}^5) 的字符串 ss(字符集为前 20 个小写字母)。另有 Q(Q3×104)Q(Q\le 3\times {10}^4) 个询问,每次给定 k(1k5)k(1\le k\le 5) 个字符 c1,c2,,ckc_1, c_2, \ldots, c_k,问 ss 有多少个子串满足 c1,c2,,ckc_1, c_2, \ldots, c_k 都出现了偶数次。

分析

我们把第 i(0i<20)i(0\le i < 20) 个字符看成一个二进制位 ii,设 fif_is0s1si1s_0\otimes s_1\otimes\cdots\otimes s_{i - 1},其中 f0=0f_0 = 0。则子串 [l,r)[l, r) 各个字符出现的奇偶性就是 frflf_r\otimes f_l

对于询问字符 c1,c2,,ckc_1, c_2, \ldots, c_k,我们需要计算有多少个无序对 (l,r)(l, r),满足 flfrf_l\otimes f_r 中二进制位 c1,c2,,ckc_1, c_2, \ldots, c_k00

我们考虑如果 frf_r 中二进制位 c1,c2,,ckc_1, c_2, \ldots, c_k 的情况为 ss,那么 flf_l 中二进制位 c1,c2,,ckc_1, c_2, \ldots, c_k 的情况也应该为 ss,才能保证得到 00

我们可以得到这样一个算法,枚举 c1,c2,,ckc_1, c_2, \ldots, c_k 的情况为 ss,统计出有 mm ,那么贡献为 m(m1)2\frac {m(m - 1)} 2。枚举 ss 的复杂度为 2s2^s,统计的复杂度为 O(n)O(n),这样单组询问就是 O(n2s)O(n2^s) 的,无法承受。

我们从优化统计 的数量入手,符合要求的 fif_i 满足 c1,c2,,ckc_1, c_2, \ldots, c_k 的情况为 ss,其他位都是任意的,我们设 gs=sSfSg_s = \sum_{s\subseteq S}f_S(就是 ss 的超集的 ff 和),设 s1s_1ssc1,c2,,ckc_1, c_2, \ldots, c_k00 的位置,那么我们可以用容斥算出数量

有点难以解释,但是思路很清晰,因为我们只能固定一些位置为 11,计算这些位置的超集的 ff 和;而不能固定一些位置为 0/10/1(因为这样就有 3203^{20} 种固定方法了,我们是没办法预处理的)。既然只能固定一些位置为 11,那我们就对固定为 00 的位置容斥一下,这样就能计算出固定 0/10/1 的方案了。

这样做的话,单次查询的复杂度降为 O(3k)O(3^k)

剩下的问题就是如何计算 gsg_s,暴力计算为 O(320)O(3^{20}),无法通过,我们考虑 dp(i,s)dp(i, s) 表示固定前 ii 位不动的情况下,考虑 ss 的超集的 ff 的和,那么,dp(20,s)=fsdp(20, s) = f_s,转移为:

dp(i,s)={dp(i+1,s),isdp(i+1,s)+dp(i+1,s{i})otherwise dp(i, s) = \begin{cases} dp(i + 1, s),& i\in s\\ dp(i + 1, s) + dp(i + 1, s \cup \{i\}) & \mathrm{otherwise} \end{cases}

由于可以循环利用数组,程序可以简化到 3 行:

for (int s = 0; s < (1 << SIGMA_SIZE); ++s) g[s] = f[s];
for (int i = SIGMA_SIZE - 1; i >= 0; --i)
    for (int s = 0; s < (1 << SIGMA_SIZE); ++s) if (!(s >> i & 1))
        g[s] += f[s ^ (1 << i)];

这样就能 O(n2n)O(n2^n) 求出超集,总的复杂度为:O(n2n+q3k)O(n2^n + q3^k)

代码

//  Created by Sengxian on 2017/04/24.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  HDOJ 4626 子集 DP + 容斥原理
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    int n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

inline int readAlpha() {
    int c = getchar();
    while (!isalpha(c)) c = getchar();
    return c;
}

const int MAX_N = 100000 + 3, SIGMA_SIZE = 20;
char str[MAX_N];
int n, f[1 << SIGMA_SIZE], popCount[1 << SIGMA_SIZE];

int main() {
    for (int i = 1; i < (1 << SIGMA_SIZE); ++i) popCount[i] = popCount[i >> 1] + (i & 1);

    int caseNum = readInt();
    while (caseNum--) {
        scanf("%s", str);
        n = strlen(str);

        memset(f, 0, sizeof f), ++f[0];
        for (int i = 0, t = 0; i < n; ++i) ++f[t ^= 1 << (str[i] - 'a')];
        for (int i = SIGMA_SIZE - 1; i >= 0; --i)
            for (int s = 0; s < (1 << SIGMA_SIZE); ++s) if (!(s >> i & 1))
                f[s] += f[s ^ (1 << i)];

        int q = readInt();
        while (q--) {
            int k = readInt(), s = 0, s1;
            ll ans = 0;
            for (int i = 0; i < k; ++i) s |= 1 << (readAlpha() - 'a');

            s1 = s;
            do {
                int s2 = s ^ s1, s3 = s2, cnt = 0;

                do {
                    cnt += ((popCount[s3] & 1) ? -1 : 1) * f[s1 | s3];
                    s3 = (s3 - 1) & s2;
                } while (s3 != s2);

                ans += (ll)cnt * (cnt - 1) / 2;

                s1 = (s1 - 1) & s;
            } while (s1 != s);

            printf("%lld\n", ans);
        }
    }

    return 0;
}