BZOJ 4820 - [Sdoi2017]硬币游戏

Published on 2017-05-03

题目地址

描述

周末同学们非常无聊,有人提议,咱们扔硬币玩吧,谁扔的硬币正面次数多谁胜利。

大家纷纷觉得这个游戏非常符合同学们的特色,但只是扔硬币实在是太单调了。

同学们觉得要加强趣味性,所以要找一个同学扔很多很多次硬币,其他同学记录下正反面情况。

H \texttt{H} 表示正面朝上, 用 T \texttt{T} 表示反面朝上,扔很多次硬币后,会得到一个硬币序列。比如 HTT \texttt{HTT} 表示第一次正面朝上,后两次反面朝上。

但扔到什么时候停止呢?大家提议,选出 n(n300) n(n\le 300) 个同学, 每个同学猜一个长度为 m(m300) m(m\le 300) 的序列,当某一个同学猜的序列在硬币序列中出现时,就不再扔硬币了,并且这个同学胜利。为了保证只有一个同学胜利,同学们猜的 n n 个序列两两不同。

很快,n n 个同学猜好序列,然后进入了紧张而又刺激的扔硬币环节。你想知道,如果硬币正反面朝上的概率相同,每个同学胜利的概率是多少。

分析

本题可用类似 BZOJ 1444 的方法,建立 AC 自动机之后高斯消元,复杂度 O((nm)3)O((nm)^3),可以获得 40 分。

上述做法的瓶颈在于,存在太多中间状态,实际上单词末尾状态却只有 nn 个,造成了浪费。我们考虑将非单词末尾状态一起考虑,重新审视这道题。

设期望经过非单词末尾节点的次数为 x0x_0,经过序列 ii 的末尾节点的期望次数为 xix_i,根据 BZOJ 1444 中的讨论,期望经过次数 xix_i 就是第 ii 个同学获胜的概率。可以列出第一个方程

x1+x2++xn=1 x_1 + x_2 + \cdots + x_n = 1

我们考虑非单词末尾状态的转移,内部的转移是没有必要考虑的,我们考虑非单词末尾状态向单词末尾状态的转移。考虑第 ii 个序列,在非单词末尾状态后加上第 ii 个序列,可以到达第 ii 个序列,但是中途可能先走到另外一个末尾节点,我们考虑什么情况下会出现这种情况。

A=TTH,B=THTA = \texttt{TTH}, B = \texttt{THT},如果非单词末尾状态的最后两个字符是 TH\texttt{TH} 的话,在后面加上 AA 序列,会提前走到 BB 序列的单词末尾状态,这是因为 AA 的长度为 11 的前缀等于 BB 的长度为 11 的后缀。由于非单词末尾状态最后两个字符是出现的概率是均等的(并不会证明,感性理解下应该可以),我们固定了非单词末尾状态最后两个字符,所以这种情况出现的概率为 122\frac 1 {2^2}

不难推广到一般情况,对于 ii,枚举另外一个序列 jj,使用 KMP 计算出所有 ii 的长度为 kk 的前缀是 jj 的一段长度为 kk 的后缀的情况,并分别累加概率得到 Pi,j=k12mkP_{i, j} = \sum_k\frac {1} {2^{m - k}},就能列出方程

xi+i=1nPi,jxi=12mx0 x_i + \sum_{i = 1}^nP_{i, j}x_i = \frac {1} {2^m} x_0

现在有 n+1n + 1 个未知数和 n+1n + 1 个方程,高斯消元解方程即可,总复杂度 O(n2(n+m))O(n^2(n + m))

代码

//  Created by Sengxian on 2017/05/03.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  BZOJ 4820 KMP + 期望 + 高斯消元
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 300 + 3, MAX_M = 300 + 3;

typedef long double type;

void gaussJordan(int n, type a[][MAX_N]) {
    for (int i = 0; i < n; ++i) {
        int idx = i;
        for (int j = i + 1; j < n; ++j)
            if (fabs(a[idx][i]) < fabs(a[j][i])) idx = j;
        if (idx != i) for (int j = i; j <= n; ++j) swap(a[idx][j], a[i][j]);
        for (int j = 0; j < n; ++j) if (j != i) {
            type t = a[j][i] / a[i][i];
            for (int k = i; k <= n; ++k) a[j][k] -= t * a[i][k];
        }
    }
    for (int i = 0; i < n; ++i) a[i][n] /= a[i][i], a[i][i] = 1;
}

int n, m;
char str[MAX_N][MAX_M];
int f[MAX_N][MAX_M];
type pow2[MAX_M];

type calc(int a, int b) {
    int j = 0;
    for (int i = 0; i < m; ++i) {
        while (j && str[b][i] != str[a][j]) j = f[a][j];
        if (str[b][i] == str[a][j]) ++j;
    }

    if (a == b) j = f[a][j];

    type res = 0;
    while (j) {
        res += pow2[m - j];
        j = f[a][j];
    }

    return res;
}

int main() {
#ifdef DEBUG
    freopen("test.in", "r", stdin);
#endif
    n = readInt(), m = readInt();

    for (int i = 0; i < n; ++i) {
        scanf("%s", str[i]);
        for (int j = 1; j < m; ++j) {
            int k = f[i][j];
            while (k && str[i][j] != str[i][k]) k = f[i][k];
            f[i][j + 1] = str[i][j] == str[i][k] ? k + 1 : 0;
        }
    }

    pow2[0] = 1;
    for (int i = 1; i <= m; ++i) pow2[i] = pow2[i - 1] * 0.5;

    static type a[MAX_N][MAX_N];

    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            a[i][j] = calc(i, j);

    for (int i = 0; i < n; ++i) a[i][n] = -pow2[m], ++a[i][i];

    for (int i = 0; i < n; ++i) a[n][i] = 1;
    a[n][n + 1] = 1;

    gaussJordan(n + 1, a);

    for (int i = 0; i < n; ++i) printf("%.10Lf\n", a[i][n + 1]);

    return 0;
}