BZOJ 4565 - [Haoi2016]字符合并

Published on 2017-01-07

描述

有一个长度为 n(n300)n(n\le 300) 的 01 串,你可以每次将相邻的 k(k8)k(k\le 8) 个字符合并,得到一个新的字符并获得一定分数。得到的新字符和分数由这 kk 个字符确定。你需要求出你能获得的最大分数。

分析

我们首先需要知道一个事实:最终的串必然是要合并到不能合并为止。

也就是说,对于任意一个串,存在一个最优合并方案,使得剩下的字符个数在 [0,k)[0, k) 之内。这就可以得到一个性质:

性质:对于每一个固定长度的串,一定存在一个最优方案,使得剩下的字符数量为一个定值 ttt[0,k)t\in [0, k)

对于每一个长度,最终得到的长度可以线性地预处理出来。

for (int i = 0; i < K; ++i) len[i] = i;
for (int i = K; i <= n; ++i) len[i] = len[i - K + 1];

本题涉及到相邻字符的合并,那么不难想到一个区间 + 状压 DP 的模型出来。设 dp(l,r,s)dp(l, r, s) 为将区间 [l,r)[l, r) 合并为 ss 的最大得分。注意到这个 ss 是有长度的,而我们在程序中是用一个数来保存,而 00101 都会被保存为 11,如何避免混淆呢?

根据我们开头的性质,这个问题非常好解决。每一个固定长度的串,最后得到的长度是一定的,所以我们只需要取 ss 二进制位前面对应位数即可。

对于转移,我们可以知道,ss 的最后一位一定是由 [t,r)[t, r) 这一段区间得到的,所以我们枚举 tt 的位置即可。由于 [t,r)[t, r) 要求合并为一个字符,所以 [t,r)[t, r) 得满足最终合并的长度为 11 才行(由于我们并不能用状压的 ss 来区分最终合并到的长度,不判断长度就会造成不合法的转移)。最终合并为 1 的区间原长必然是 1+(k1)x,x01 + (k - 1) \cdot x, x\ge 0 的形式,枚举的时候每次减去 k1k - 1 即可。转移是标准的区间 DP 转移,就不写了。

对于得分的累加,只需要在 len(w)=1\mathrm{len}(w) = 1 的时候,DP 出长度为 kk 的最优得分,然后手动将其合并成一个字符即可。比较简单,可以参见代码。

复杂度:O(n2k2k)O(\frac {n^2} k \cdot 2 ^ k)

总结:本题精彩之处在于「对于一个定长区间,最后合并到的长度是一个定值」,这就可以大大优化 DP 的状态定义。由此得到:状态数过大的时候,要考虑从性质上下手。

代码

//  Created by Sengxian on 2017/1/7.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 4565 状压 DP
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 300 + 3, MAX_K = 8;
const ll INF = 0x3f3f3f3f3f3f3f3fLL;
char str[MAX_N];
int n, K, ch[MAX_N], score[1 << MAX_K];

ll dp[MAX_N][MAX_N][1 << MAX_K];
int len[MAX_N];

inline void relax(ll &a, const ll &b) {
    if (b > a) a = b;
}

void solve() {
    memset(dp, -0x3f, sizeof dp);
    for (int i = 0; i < n; ++i) dp[i][i][0] = 0;
    for (int i = 0; i < n; ++i) dp[i][i + 1][str[i]] = 0;

    for (int w = 2; w <= n; ++w) {
        if (len[w] == 1) {
            for (int i = 0; i + w <= n; ++i) {
                int j = i + w;    
                for (int s = 0; s < (1 << K); ++s)
                    for (int l = j - 1; l > i; l -= K - 1)
                        relax(dp[i][j][s], dp[i][l][s >> 1] + dp[l][j][s & 1]);

                static ll tmp[2];
                tmp[0] = -INF, tmp[1] = -INF;
                for (int s = 0; s < (1 << K); ++s) relax(tmp[ch[s]], dp[i][j][s] + score[s]);
                dp[i][j][0] = tmp[0], dp[i][j][1] = tmp[1];
            }
        } else {
            for (int i = 0; i + w <= n; ++i) {
                int j = i + w;
                for (int s = 0; s < (1 << len[w]); ++s)
                    for (int l = j - 1; l > i; l -= K - 1)
                        relax(dp[i][j][s], dp[i][l][s >> 1] + dp[l][j][s & 1]);
            }
        }
    }

    printf("%lld\n", *max_element(dp[0][n], dp[0][n] + (1 << len[n])));
}

int main() {
    n = readInt(), K = readInt();
    scanf("%s", str);
    for (int i = 0; i < n; ++i) str[i] -= '0';
    for (int i = 0; i < (1 << K); ++i) ch[i] = readInt(), score[i] = readInt();
    for (int i = 0; i < K; ++i) len[i] = i;
    for (int i = K; i <= n; ++i) len[i] = len[i - K + 1];
    solve();
    return 0;
}