BZOJ 1444 - [Jsoi2009]有趣的游戏

Published on 2016-07-03

题目地址

描述

分析

我们要求的答案是每个人获胜的概率,实际上,如果我们把所有单词建立出 AC 自动机,我们要求的就是到达每个人对应的单词末尾节点的概率。

我们定义 xix_i 为一局比赛经过 ii 点的概率(不是停在,而是经过),则那么 xix_i 就等于所有(ii 的前驱节点的概率 * ii 上的字符出现的概率)的和,根据这个可以列出方程,等等,为什么我解出来的都是 00?原因是因为常数项都是 00,这是个 nn 元齐次方程组,存在全 0 解。怎么办?可以令 xi=1x_i = 1 啊,怎么还是不对?

为什么没法解呢,原因是游戏开始,我们在根节点,所以经过根节点的概率为 x0=1x_0 = 1,根节点会转移出去,然后可能会有若干个节点转移到根节点,但根节点的概率最大就是 11,我们没法处理再次从根节点出发根节点的情况!一个可行的方案是退而求其次,用时间来换取精度,进行多次转移,这样到达每个节点的概率就能趋近于真实的值。这是网上的一种做法,转移使用矩阵快速幂实现,复杂度为 O((nl)3logK)O((nl)^3\log K)KK 为转移次数。

本文介绍一种更优美的做法。

我们发现刚刚的瓶颈在于无法处理根节点的多次转移,不过有一种巧妙的方法可以解决。

考虑定义 xix_i 为一局比赛经过 ii 点的期望次数(玩一盘游戏期望能经过多少次 ii 点)。这样的好处是什么呢?显而易见的是由于单词末尾节点走到了,意味着游戏结束,所以经过单词末尾节点的概率就是经过单词末尾节点的期望次数。第二个好处是对于根节点的,由于没有了概率上限为 11 的限制,我们可以自由的转移,则我们可以得到初始值:

x0=1x_0 = 1

Pi,jP_{i,j}jj 转移到 ii 的概率(后文讲求法),根据全期望公式(ii 的期望 = ii 的前驱节点的期望 * ii 上的字符出现的概率),根节点的方程是:

其余的节点的方程是:

考虑到需要使用高斯消元,我们把未知数放到一边:

根节点方程的方程是:

其余的节点的方程是:

考虑求 Pi,jP_{i, j}。设 p[i]p[i] 为生成字符 ii 的概率,char(i)\mathrm{char}(i) 为节点 ii 上的字符。如果 xux_u 能转移到 xvx_v,那么 Pv,uP_{v, u} 加上 p[char(v)]p[\mathrm{char(v)}]。如果 uu 是单词结尾节点或者是能通过 fail\mathrm{fail} 到达单词结尾节点的点,表示游戏结束,不再转移,所以直接略去其转移即可。

由于我们一局比赛走到末尾就停止,所以所有末尾节点经过次数的的期望和一定是 11

使用高斯消元来解方程,复杂度 O((nl)3)O((nl)^3)

Trick:所有人不能赢需要特判,否则你输出的都是 nan

总结:本题中若要使用高斯消元,定义概率是不可取的,原因是无法反复转移,而定义期望次数,则能巧妙的避开没办法反复转移的问题,从而解决问题。
吐槽:网上的高斯消元做法确实能 AC,但是没有一个能解释清楚为什么根节点的转移要减,原因是它们混淆了本题中应该设的变量是期望次数而不是概率。

代码

//  Created by Sengxian on 7/2/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 1444 AC自动机 + 期望 + 高斯消元
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int ReadInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return n;
}

const int maxn = 10 + 3, maxLen = 10 + 3, maxm = 10 + 3, modu = 10007;
const double eps = 1e-10;
int strID[maxn];
namespace AhoCorasickAutomata {
    const int maxNode = maxn * maxLen, maxChar = 26;
    int charNum = 26;
    int ch[maxNode][maxChar], f[maxNode], sz = 1;
    bool danger[maxNode];
    void insert(char *str, int id) {
        int cur = 0;
        for (int i = 0; str[i]; ++i) {
            int c = str[i] - 'A';
            if (!ch[cur][c]) ch[cur][c] = sz++;
            cur = ch[cur][c];
        }
        danger[cur] = true;
        strID[id] = cur;
    }
    void getFail() {
        queue<int> q;
        for (int i = 0; i < charNum; ++i) {
            int u = ch[0][i];
            if (u) q.push(u);
        }
        while (!q.empty()) {
            int r = q.front(); q.pop();
            for (int c = 0; c < charNum; ++c) {
                int u = ch[r][c];
                if (!u) ch[r][c] = ch[f[r]][c];
                else {
                    f[u] = ch[f[r]][c];
                    danger[u] |= danger[f[u]];
                    q.push(u);
                }
            }
        }
    }
}

using namespace AhoCorasickAutomata;

int n, l, m;
double P[maxm];
char str[maxLen];

typedef double Matrix[maxNode][maxNode];
Matrix a;

bool gauss_jordan(int n, Matrix a) {
    for (int i = 0; i < n; ++i) {
        int idx = i;
        for (int j = i + 1; j < n; ++j) if (fabs(a[j][i]) > fabs(a[idx][i])) idx = j;
        if (fabs(a[idx][i]) <= eps) return false;
        if (idx != i) for (int j = i; j <= n; ++j) swap(a[i][j], a[idx][j]);
        for (int j = 0; j < n; ++j) if (i != j) {
            double tmp = a[j][i] / a[i][i];
            for (int k = n; k >= i; --k) a[j][k] -= a[i][k] * tmp;
        }
    }
    return true;
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("test.in", "r", stdin);
#endif
    n = ReadInt(), l = ReadInt(), m = charNum = ReadInt();
    for (int i = 0; i < m; ++i) {
        int p = ReadInt(), q = ReadInt();
        P[i] = (double)p / q;
    }
    int fail = 0;
    for (int i = 0; i < n; ++i) {
        scanf("%s", str);
        for (int j = 0; j < l; ++j)
            if (P[str[j] - 'A'] < eps) {fail++; break;}
        insert(str, i);
    }
    if (fail == n) {
        for (int i = 0; i < n; ++i) puts("0.00");
        return 0;
    }
    getFail();
    a[0][0] = -1.0, a[0][sz] = -1.0;
    for (int i = 0; i < sz; ++i) {
        if (i) a[i][i] = -1.0;
        if (danger[i]) continue; //末尾节点无需转移
        for (int c = 0; c < m; ++c) {
            int u = ch[i][c];
            a[u][i] += P[c];
        }
    }

    assert(gauss_jordan(sz, a));
    for (int i = 0; i < n; ++i) {
        double p = a[strID[i]][sz] / a[strID[i]][strID[i]];
        if (fabs(p) <= eps) puts("0.00");
        else printf("%.2lf\n", p);
    }
    return 0;
}