BZOJ 1195 - [HNOI2006]最短母串

Published on 2016-10-16

题目地址

描述

给定 n(n12)n(n\le 12) 个字符串 ,要求找到一个最短的字符串 TT,使得这 nn 个字符串都是 TT 的子串。

分析

很容易发现,我们要做的就是将字符串接起来,尽量让相邻的两个串有较多的公共部分。这就有了一个 DP 的模型,首先要记录一个集合 ss,表示选了哪些字符串,再怎么办?想了一下发现记录最后一个串是谁好像不够,因为在接上一个新的字符串的时候,有可能“完全吞掉最后一个串”(比如末尾是 a,而我们接一个 aa),这就要知道前面一个乃至更前的串是什么,算法到这好像陷入了僵局。

我们考虑什么时候才会出现“完全吞掉最后一个串”的情况,如果串 aa 吞掉了 bb,那么 bb 一定是 aa 的子串,实际上 bb 根本无需考虑,因为只要串 aa 在里面,bb 一定在里面!所以我们可以在一开始就剔除掉是别人的子串的字符串。
这样一来,记录最后一个串是谁就够了,设 dp[s][i]dp[s][i] 为选择的集合为 ss,最后一个串是 ii 的最小长度,枚举前一个串进行转移即可,注意要记录。

复杂度我们来推导一下,对于每个状态 dp[s][i]dp[s][i],要枚举在 ss 中的 iijj,而且还要进行长度最大为 50s50 \cdot \mid s \mid 的字符串比较,因此式子为:

i=0n(ni)i250i\sum_{i = 0}^n\binom n ii^2\cdot 50\cdot i

利用一些技巧,得到封闭形式为 50(n2n1+n(n1)2n2+n(n1)(n2)2n3)50\cdot(n \cdot 2^{n - 1}+ n(n-1)\cdot2^{n - 2} + n(n-1)(n-2)\cdot2^{n-3})n=12n = 12 时为 41074\cdot{10}^7 左右,可以承受。

代码

//  Created by Sengxian on 2016/10/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 1195 DP
#include <bits/stdc++.h>
using namespace std;

const int MAX_N = 12 + 3, MAX_LEN = 50 + 3, INF = 0x3f3f3f3f;
int n, f[MAX_N][MAX_LEN], val[MAX_N][MAX_N];
string s[MAX_N];
bool block[MAX_N];
string res[1 << MAX_N][MAX_N];

int dp[1 << MAX_N][MAX_N];

inline int cal(const string &a, const string &b, int f[]) {
    int c = 0;
    for (int i = 0; i < (int)a.length(); ++i) {
        while (c && a[i] != b[c]) c = f[c];
        if (a[i] == b[c]) ++c;
    }
    return c;
}

inline bool contain(const string &a, const string &b, int f[]) {
    int c = 0;
    for (int i = 0; i < (int)a.length(); ++i) {
        while (c && a[i] != b[c]) c = f[c];
        if (a[i] == b[c]) ++c;
        if (c == b.length()) return true;
    }
    return false;
}


inline string cat(const string &a, const string &b, int f[]) {
    int c = cal(a, b, f);
    return a + b.substr(c, b.length() - c);
}

void process(const string &s, int f[]) {
    f[0] = f[1] = 0;
    for (int i = 1, j = 0; i < (int)s.length(); ++i) {
        j = f[i];
        while (j && s[i] != s[j]) j = f[j];
        f[i + 1] = s[i] == s[j] ? j + 1 : 0;
    }
}

void unique() {
    for (int i = 0; i < n; ++i)
        for (int j = i + 1; j < n; ++j) {
            if (s[i] == s[j]) block[j] = true;
            else if (contain(s[i], s[j], f[j])) block[j] = true;
            else if (contain(s[j], s[i], f[i])) block[i] = true;
        }
    int cnt = 0;
    for (int i = 0; i < n; ++i)
        if (!block[i]) s[cnt++] = s[i];
    n = cnt;
}

int main() {
    cin >> n;
    for (int i = 0; i < n; ++i)
        cin >> s[i], process(s[i], f[i]);
    unique();
    for (int i = 0; i < n; ++i) process(s[i], f[i]);
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            val[i][j] = s[j].length() - cal(s[i], s[j], f[j]);
    for (int i = 0; i < n; ++i) dp[1 << i][i] = s[i].length(), res[1 << i][i] = s[i];
    int len = 0;
    for (int ss = 0; ss < (1 << n); ++ss) if (__builtin_popcount(ss) > 1)
        for (int i = 0; i < n; ++i) if ((ss >> i) & 1) {
            dp[ss][i] = INF;
            for (int j = 0; j < n; ++j) if (i != j && (ss >> j) & 1) {
                len = dp[ss ^ (1 << i)][j] + val[j][i];
                if (len > dp[ss][i]) continue;
                string st = res[ss ^ (1 << i)][j] + s[i].substr(s[i].length() - val[j][i], val[j][i]);
                if (len < dp[ss][i]) dp[ss][i] = len, res[ss][i] = st;
                else if (st < res[ss][i]) res[ss][i] = st;
            }
        }
    string ans = res[(1 << n) - 1][0];
    for (int i = 1; i < n; ++i) {
        string now = res[(1 << n) - 1][i];
        if (now.length() < ans.length()) ans = now;
        else if (now.length() == ans.length()) ans = min(ans, now);
    }
    cout << ans << endl;
    return 0;
}