HDOJ 4626 - Jinkeloid
Published on 2017-04-24描述
给定一个长度为 的字符串 (字符集为前 20 个小写字母)。另有 个询问,每次给定 个字符 ,问 有多少个子串满足 都出现了偶数次。
分析
我们把第 个字符看成一个二进制位 ,设 为 ,其中 。则子串 各个字符出现的奇偶性就是 。
对于询问字符 ,我们需要计算有多少个无序对 ,满足 中二进制位 为 。
我们考虑如果 中二进制位 的情况为 ,那么 中二进制位 的情况也应该为 ,才能保证得到 。
我们可以得到这样一个算法,枚举 的情况为 ,统计出有 个 ,那么贡献为 。枚举 的复杂度为 ,统计的复杂度为 ,这样单组询问就是 的,无法承受。
我们从优化统计 的数量入手,符合要求的 满足 的情况为 ,其他位都是任意的,我们设 (就是 的超集的 和),设 为 中 为 的位置,那么我们可以用容斥算出数量
有点难以解释,但是思路很清晰,因为我们只能固定一些位置为 ,计算这些位置的超集的 和;而不能固定一些位置为 (因为这样就有 种固定方法了,我们是没办法预处理的)。既然只能固定一些位置为 ,那我们就对固定为 的位置容斥一下,这样就能计算出固定 的方案了。
这样做的话,单次查询的复杂度降为 。
剩下的问题就是如何计算 ,暴力计算为 ,无法通过,我们考虑 表示固定前 位不动的情况下,考虑 的超集的 的和,那么,,转移为:
由于可以循环利用数组,程序可以简化到 3 行:
for (int s = 0; s < (1 << SIGMA_SIZE); ++s) g[s] = f[s]; for (int i = SIGMA_SIZE - 1; i >= 0; --i) for (int s = 0; s < (1 << SIGMA_SIZE); ++s) if (!(s >> i & 1)) g[s] += f[s ^ (1 << i)];
这样就能 求出超集,总的复杂度为:。
代码
// Created by Sengxian on 2017/04/24. // Copyright (c) 2017年 Sengxian. All rights reserved. // HDOJ 4626 子集 DP + 容斥原理 #include <bits/stdc++.h> using namespace std; typedef long long ll; inline int readInt() { int n = 0, ch = getchar(); while (!isdigit(ch)) ch = getchar(); while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar(); return n; } inline int readAlpha() { int c = getchar(); while (!isalpha(c)) c = getchar(); return c; } const int MAX_N = 100000 + 3, SIGMA_SIZE = 20; char str[MAX_N]; int n, f[1 << SIGMA_SIZE], popCount[1 << SIGMA_SIZE]; int main() { for (int i = 1; i < (1 << SIGMA_SIZE); ++i) popCount[i] = popCount[i >> 1] + (i & 1); int caseNum = readInt(); while (caseNum--) { scanf("%s", str); n = strlen(str); memset(f, 0, sizeof f), ++f[0]; for (int i = 0, t = 0; i < n; ++i) ++f[t ^= 1 << (str[i] - 'a')]; for (int i = SIGMA_SIZE - 1; i >= 0; --i) for (int s = 0; s < (1 << SIGMA_SIZE); ++s) if (!(s >> i & 1)) f[s] += f[s ^ (1 << i)]; int q = readInt(); while (q--) { int k = readInt(), s = 0, s1; ll ans = 0; for (int i = 0; i < k; ++i) s |= 1 << (readAlpha() - 'a'); s1 = s; do { int s2 = s ^ s1, s3 = s2, cnt = 0; do { cnt += ((popCount[s3] & 1) ? -1 : 1) * f[s1 | s3]; s3 = (s3 - 1) & s2; } while (s3 != s2); ans += (ll)cnt * (cnt - 1) / 2; s1 = (s1 - 1) & s; } while (s1 != s); printf("%lld\n", ans); } } return 0; }