题目地址

描述

让我们继续 JC 和 DZY 的故事。
“你是我的小丫小苹果,怎么爱你都不嫌多!”
“点亮我生命的火,火火火火火!”

话说 JC 历经艰辛来到了城市 B,但是由于他的疏忽 DZY 偷走了他的小苹果!没有小苹果怎么听歌!他发现邪恶的 DZY 把他的小苹果藏在了一个迷宫里。JC 在经历了之前的战斗后他还剩下 hp(hp10000)\mathrm{hp}(\mathrm{hp}\le 10000) 点血。开始 JC 在 11 号点,他的小苹果在 n(n150)n(n\le 150) 号点。DZY 在一些点里放了怪兽。当 JC 每次遇到位置在 ii 的怪兽时他会损失 AiA_i 点血。当 JC 的血小于等于 00 时他就会被自动弹出迷宫并且再也无法进入。

但是 JC 迷路了,他每次只能从当前所在点出发等概率的选择一条道路走。所有道路都是双向的,一共有 m(m5000)m(m\le 5000) 条,怪兽无法被杀死。现在 JC 想知道他找到他的小苹果的概率。

分析

如果没做过这类题,建议先做 BZOJ 1444BZOJ 3143

我们考虑一个复杂度稍高的期望 DP ,设 E(i,j)E(i, j) 当前在点 ii,血量为 jj,经过这个状态的期望次数,那么由于走到 nn 点即停止,所以每个在 nn 点的状态的期望和数就是走到 nn 点的概率。答案就是:

1khpE(n,k) \sum_{1\le k\le \mathrm{hp}}E(n, k)

由于存在 Ai=0A_i = 0 的情况,并非是拓扑关系,所以使用高斯消元解方程,复杂度 O((nhp)3)O((n\cdot \mathrm{hp})^3),太高了。
考虑到血量是单调不增的,所以我们可以将状态按照血量分层,每层之间满足拓扑关系,这样每一层中高斯消元的复杂度就是 O(n3)O(n^3),总复杂度 O(hpn3)O(\mathrm{hp}\cdot n^3),还是不行。

我们考虑每一层中高斯消元的异同点,不难发现,每一次的系数矩阵是一样的,而不一样的只有常数矩阵。我们考虑预处理一次高斯消元,记录每个方程的最终的常数项是由原方程的的常数如何线性组合得到的。

看一个例子,比如我们有两个方程:

这两个方程不一样的就只有常数矩阵,一个是 [a,b]T[a, b]^T,一个是 [c,d]T[c, d]^T,我们考虑预处理一个方程的解,其中常数项代表常数矩阵的一个线性组合。比如一开始预处理的方程是这样的:

{x1x2=[1,0]x1+x2=[0,1] \begin{cases} x_1 - x_2 = [1, 0]\\ x_1 + x_2 = [0, 1] \end{cases}

其中,[1,0][1, 0] 表示如果我们的常数矩阵是 [c0,c1]T[c_0, c_1]^T 的话,那么这个值就是 c01+c10c_0 \cdot 1 + c_1 \cdot 0。显然这个东西是可以加减乘除的。我们解出这个方程得到:

{x1+0x2=[0.5,0.5]0x1+x2=[0.5,0.5] \begin{cases} x_1 + 0x_2 = [0.5, 0.5]\\ 0x_1 + x_2 = [-0.5, 0.5] \end{cases}

那么带入对应的常数矩阵 [a,b]T[a, b]^T,我们就能立即知道,解为:

{x0=a0.5+b0.5x1=a(0.5)+b0.5 \begin{cases} x_0 = a \cdot 0.5 + b \cdot 0.5\\ x_1 = a \cdot (-0.5) + b \cdot 0.5 \end{cases}

回到这个题,我们可以用 O(n3)O(n^3) 的时间内预处理一次高斯消元,对于计算每层,只需要预先计算出常数矩阵,然后按照我们解出来的解带入值即可。

复杂度:O(n3+hpn2)O(n^3 + \mathrm{hp} \cdot n ^ 2)

代码

//  Created by Sengxian on 2016/12/02.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 3640 分层期望 DP + 预处理高斯消元
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 150 + 3, MAX_M = 5000 + 3, MAX_HP = 10000 + 3;

struct edge {
    edge *next;
    int to;
    edge(edge *next = NULL, int to = 0): next(next), to(to) {}
} pool[MAX_M * 2], *pit = pool, *first[MAX_N], *rFirst[MAX_N];

int n, m, hp;
int c[MAX_N], deg[MAX_N];
double a[MAX_N][MAX_N];

struct data {
    double a[MAX_N];
    inline double& operator [] (const int &i) {
        return a[i];
    }
} datas[MAX_N];

void gauss_jordan(int n, double a[][MAX_N], data datas[]) {
    for (int i = 0; i < n; ++i) {
        int idx = i;
        for (int j = i + 1; j < n; ++j)
            if (fabs(a[j][i]) > fabs(a[idx][i])) idx = j;

        assert(fabs(a[idx][i]) > 1e-8);
        if (idx != i) {
            for (int j = i; j < n; ++j) swap(a[idx][j], a[i][j]);
            for (int j = 0; j < n; ++j) swap(datas[idx][j], datas[i][j]);
        }

        for (int j = 0; j < n; ++j) if (j != i) {
            double t = a[j][i] / a[i][i];
            for (int k = i; k < n; ++k) a[j][k] -= a[i][k] * t;
            for (int k = 0; k < n; ++k) datas[j][k] -= datas[i][k] * t;
        }
    }

    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            datas[i][j] /= a[i][i];
}

void prepare() {
    memset(a, 0, sizeof a);
    memset(datas, 0, sizeof datas);

    for (int i = 0; i < n; ++i) {
        a[i][i] = 1.0, datas[i][i] = 1.0;
        if (c[i] == 0) {
            for (edge *e = first[i]; e; e = e->next) if (e->to + 1 != n)
                a[i][e->to] -= 1.0 / deg[e->to];
        } 
    }

    gauss_jordan(n, a, datas);
}

double dp[MAX_HP][MAX_N];

void solve() {
    double ans = 0;

    for (int h = hp; h; --h) {
        static double constant[MAX_N];
        memset(constant, 0, sizeof constant);
        if (h == hp) constant[0] = 1.0;

        for (int i = 0; i < n; ++i) if (c[i] && h + c[i] <= hp)
            for (edge *e = first[i]; e; e = e->next) if (e->to + 1 != n)
                constant[i] += dp[h + c[i]][e->to] / deg[e->to];

        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                dp[h][i] += constant[j] * datas[i][j];

        ans += dp[h][n - 1];
    }

    printf("%.8f\n", ans);
}

int main() {
    n = readInt(), m = readInt(), hp = readInt();
    for (int i = 0; i < n; ++i) c[i] = readInt();
    for (int i = 0, u, v; i < m; ++i) {
        u = readInt() - 1, v = readInt() - 1;
        first[u] = new (pit++) edge(first[u], v), deg[u]++;
        if (u != v) first[v] = new (pit++) edge(first[v], u), deg[v]++;
    }

    prepare();
    solve();

    return 0;
}