Codeforces 757F - Team Rocket Rises Again

Published on 2017-01-13

题目地址

描述

给定一个 n(n2×105)n(n\le 2\times {10}^5) 个点,m(m3×105)m(m\le 3\times {10}^5) 的带权无向图。请你选择一个点 vsv\neq s,使得在图中删掉点 vv 后,有尽可能多的点到 ss 的最短距离改变。

分析

首先从 ss 出发跑一遍单源最短路,记 ssvv 的最短距离为 dis(v)\mathrm{dis}(v),我们考虑什么情况下 ssvv 的最短距离改变。

新建一个图,考虑原图中的每条边 (u,v)(u, v),如果满足 dis(u)+cost(u,v)=dis(v)\mathrm{dis}(u) + \mathrm{cost}(u, v) = \mathrm{dis}(v),在新图中连边 (u,v)(u, v)。由于边权是正整数,这个新图显然是一个 DAG,而且是从 ss 出发的 DAG。这个 DAG 满足一个性质:删掉一个点之后,只要 ssvv 仍然有一条路径,那么 ssvv 的最短路就不会改变。

这种约束关系,让人回想起了 ZJOI 2012 灾难 一题中的「灭绝树」。

灭绝树是指,对于任意一个子树,若这个子树的根节点灭绝,那么子树中的所有点都会灭绝。如果我们定义本题的灭绝关系指的是「到 ss 的最短路改变」,那么只要将 DAG 转化为灭绝树,答案就是根节点最大的子树的大小。本题中灭绝树的意义是,对于任意一个子树,若这个子树的根节点被删除,那么子树中的所有点到根的最短路都会改变。

我们考虑使用增量法构出灭绝树,按照 DAG 的拓扑序不断加入点。灭绝树的树根为 ss,考虑如何加入节点 vv:找到 vv 在 DAG 中的所有前继,计算出所有前继的 LCA 为 uu,在灭绝树中连边 (u,v)(u, v)

正确性容易证明,当且仅当 vv 的前继都被摧毁,ssvv 的最短路才会改变,而只有摧毁 vv 的所有前继的 LCA,vv 的所有前继才会被摧毁。

使用 SPFA 做单源最短路,使用树上倍增寻找 LCA,复杂度:O(mk+mlogn)O(mk + m\log n)

代码

//  Created by Sengxian on 2017/01/13.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  Codeforces 757F 最短路 + 灭绝树
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 200000 + 3, MAX_M = 300000 + 3;

struct edge {
    edge *next;
    int to, cost;
    edge(edge *next = NULL, int to = 0, int cost = 0): next(next), to(to), cost(cost) {}
} pool[MAX_M * 2 + MAX_N], *pit = pool, *first[MAX_N], *newFirst[MAX_N];

int n, m, s;
ll dis[MAX_N];

inline bool tension(ll &a, const ll &b) {
    return b < a ? a = b, true : false;
}

void spfa(int s) {
    static int q[MAX_N];
    static bool inQ[MAX_N];
    int l = 0, r = 0;
    memset(dis, 0x3f, sizeof dis);
    dis[s] = 0, q[r++] = s;

    while (r - l >= 1) {
        int u = q[(l++) % n];
        inQ[u] = false;
        for (edge *e = first[u]; e; e = e->next)
            if (tension(dis[e->to], dis[u] + e->cost) && !inQ[e->to])
                inQ[e->to] = true, q[(r++) % n] = e->to;
    }
}

inline bool cmp(const int &i, const int &j) {
    return dis[i] < dis[j];
}

int logs[MAX_N];
int fa[MAX_N], anc[20][MAX_N], dep[MAX_N];

inline int lca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);

    for (int bit = logs[dep[u]]; bit >= 0; --bit)
        if (dep[u] - (1 << bit) >= dep[v])
            u = anc[bit][u];
    if (u == v) return u;

    for (int bit = logs[dep[u]]; bit >= 0; --bit)
        if (dep[u] - (1 << bit) >= 0 && anc[bit][u] != anc[bit][v])
            u = anc[bit][u], v = anc[bit][v];

    return fa[u];
}

int sz[MAX_N];

int dfs(int u) {
    sz[u] = 1;
    for (edge *e = newFirst[u]; e; e = e->next)
        sz[u] += dfs(e->to);
    return sz[u];
}

int solve() {
    static int ordID[MAX_N];
    for (int i = 0; i < n; ++i) ordID[i] = i;
    sort(ordID, ordID + n, cmp);

    for (int i = 2; i <= n; ++i) logs[i] = logs[i >> 1] + 1;

    dep[s] = 0, fa[s] = -1;
    for (int i = 0; i < n; ++i) {
        int u = ordID[i];
        fa[u] = -1;

        for (edge *e = first[u]; e; e = e->next)
            if (dis[e->to] + e->cost == dis[u]) {
                if (fa[u] == -1) fa[u] = e->to;
                else fa[u] = lca(fa[u], e->to);
            }

        if (fa[u] != -1) {
            dep[u] = dep[fa[u]] + 1;
            newFirst[fa[u]] = new (pit++) edge(newFirst[fa[u]], u); 

            anc[0][u] = fa[u];
            for (int w = 1; (1 << w) <= dep[u]; ++w)
                anc[w][u] = anc[w - 1][anc[w - 1][u]];
        }
    }

    dfs(s);
    int ans = 0;
    for (int i = 0; i < n; ++i) if (i != s) ans = max(ans, sz[i]);
    return ans;
}

int main() {
#ifdef DEBUG
    freopen("test.in", "r", stdin);
#endif
    n = readInt(), m = readInt(), s = readInt() - 1;
    for (int i = 0; i < m; ++i) {
        int u = readInt() - 1, v = readInt() - 1, w = readInt();
        first[u] = new (pit++) edge(first[u], v, w);
        first[v] = new (pit++) edge(first[v], u, w);
    }

    spfa(s);
    printf("%d\n", solve());
    return 0;
}