HDOJ 5822 - color

Published on 2017-05-18

题目地址

描述

给定一个 n(n105)n(n\le {10}^5) 个点的基环树,用 m(m109)m(m\le {10}^9) 种颜色将这棵树染色,问有多少种本质不同的方案,答案对 109+7{10}^9 + 7 取模。

分析

由于基环树是一个环 + 一堆有根树,我们先考虑有根树怎么做。

dp(u)dp(u) 为给 uu 及其子树染色的本质不同的方案数。首先 uumm 种染色方案,我们要计算子树的本质不同的染色方案,问题就麻烦在同构的子树上面,根据树同构的套路,先求出 uu 的每个子树的 Hash 值,然后将 Hash 值相同的归为一类。显然不同类的方案数是可以直接用乘法原理的,我们考虑计算相同类的方案数,假设某一类有 aa 棵子树,均有 bb 种方案(显然同构的子树,方案数也是相同的),那么问题就转化为:有 aa 个元素,给每个元素一个 1b1\sim b 的编号,问有多少种本质不同的标号方法。根据最小表示法,我们只需要计算有多少个标号单调不降的序列即可。这个标号序列一定是一段一段的,每一段的标号都是相同的,我们枚举序列被分成的段数,那么方案数为

k=1a(a1k1)(bk) \sum_{k = 1}^a \binom {a - 1} {k - 1}\binom b k

根据恒等式

k(lm+k)(sn+k)=(l+slm+n),l0,integern,m \sum_k \binom l {m + k}\binom s {n + k} = \binom {l + s} {l - m + n}, l\ge 0, \text{integer}\;n, m

则和式化为

(a+b1a) \binom {a + b - 1} {a}

我们需要求 (a+b1a)mod109+7\binom {a + b - 1} {a} \bmod {10}^9 + 7 的结果,然而我们现在只能知道 bmod109+7b\bmod {10}^9 + 7 的结果,似乎没有办法继续做下去?考虑到 P=109+7P = {10}^9 + 7 是一个质数,我们使用 Lucas 定理

(a+b1a)((a+b1)modPamodP)(a+b1PaP)(modP) \binom {a + b - 1} {a} \equiv \binom {(a + b - 1)\bmod P} {a\bmod P} \binom {\left\lfloor \frac {a + b - 1} P\right\rfloor}{\left\lfloor \frac a P\right\rfloor}\pmod P

由于 an105a\le n\le {10}^5,所以 aP=0\left\lfloor \frac a P\right\rfloor = 0,所以 (a+b1PaP)=1\binom {\left\lfloor \frac {a + b - 1} P\right\rfloor}{\left\lfloor \frac a P\right\rfloor} = 1,那么我们可以直接将二项式系数的上指标对 PP 取模,算出来的答案是一样的。考虑到下指标非常小而上指标非常大,我们在计算组合数的时候,采用同行递推的方法计算,由于 aa 的和是 O(n)O(n) 的,所以计算组合数的复杂度也是 O(n)O(n) 的,考虑到对 Hash 值分类需要排序,DP 的总时间复杂度为 O(nlogn)O(n\log n)

现在的问题转化为:在一个长度为 ll 的环 cir\mathrm{cir} 上,每一个珠子有一个颜色(对应的子树的 Hash 值),每一个珠子又有若干种方案(相同颜色的珠子所有方案相同),问有多少种本质不同的方案。

环上同构,只用考虑旋转 x(0x<l)x(0\le x < l) 格这些置换就够了,本题中珠子有颜色,也就是说旋转 xx 格这一置换在置换群内,当且仅当旋转 xx 格之后,珠子的颜色能完全对上。使用 KMP 计算就能 O(n)O(n) 计算出所有的 xx

不难发现这样的置换群仍然是合法的,使用 Burnside 引理,我们只需计算出旋转 xx 格的不动点个数。旋转 xx 格会产生 gcd(l,x)\gcd(l, x) 个循环,同时 1,2,3,,gcd(l,x)1, 2, 3, \ldots, \gcd(l, x) 分别属于不同的循环,那么不动点个数为

i=1gcd(l,x)dp(cir(i)) \prod_{i = 1}^{\gcd(l, x)}dp(\mathrm{cir}(i))

记录一个前缀积,就能 O(1)O(1) 计算出不动点个数。

过程中需使用 1n1\sim n 的逆元,线性递推 O(n)O(n) 计算即可。瓶颈在于 DP 时对子树的 Hash 值排序,总复杂度 O(nlogn)O(n\log n)

代码

//  Created by Sengxian on 2017/05/17.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  HDOJ 5822 树的同构 + Burnside 引理
#include <bits/stdc++.h>
using namespace std;

typedef unsigned long long ull;
typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 100000 + 3, MOD = 1000000007;
vector<int> G[MAX_N];
int n, m, fa[MAX_N], cir[MAX_N];
bool inCir[MAX_N];
ull a[MAX_N], b[MAX_N], h[MAX_N], hashVal[MAX_N * 2];
ll inv[MAX_N], dp[MAX_N], prefixProduct[MAX_N];

inline bool cmp(const int &i, const int &j) {
    return h[i] < h[j];
}

inline int C(int n, int k) {
    assert(n >= k);
    ll res = 1;
    for (int i = 0; i < k; ++i)
        res = res * (n - i) % MOD * inv[i + 1] % MOD;
    return res;
}

ull dfs(int u, int dep = 0) {
    ull &val = h[u] = 0;
    for (int i = 0; i < (int)G[u].size(); ++i) {
        int v = G[u][i];
        if (!inCir[v]) val += dfs(v, dep + 1);
        else G[u].erase(G[u].begin() + i), --i;
    }

    dp[u] = m;
    sort(G[u].begin(), G[u].end(), cmp);
    for (int i = 0, j = 0; i < (int)G[u].size(); i = j) {
        while (j < (int)G[u].size() && h[G[u][i]] == h[G[u][j]]) ++j;
        (dp[u] *= C(j - i - 1 + dp[G[u][i]], j - i)) %= MOD;
    }

    (val *= a[dep]) += b[dep];
    return val = val * val;
}

void getFail(int n, int *f, ull *str) {
    f[0] = f[1] = 0;
    for (int i = 1; i < n; ++i) {
        int j = f[i];
        while (j && str[i] != str[j]) j = f[j];
        f[i + 1] = str[i] == str[j] ? j + 1 : 0; 
    }
}

int solve(int n) {
    static int fail[MAX_N * 2];
    getFail(n, fail, hashVal);
    int ans = 0, cnt = 0;
    prefixProduct[0] = 1;
    for (int i = 0; i < n; ++i) prefixProduct[i + 1] = prefixProduct[i] * dp[cir[i]] % MOD;

    int j = 0;
    for (int i = 0; i < n * 2 - 1; ++i) {
        while (j && hashVal[i] != hashVal[j]) j = fail[j];
        if (hashVal[i] == hashVal[j]) ++j;
        if (j == n) (ans += prefixProduct[__gcd(n, i - n + 1)]) %= MOD, j = fail[j], ++cnt;
    }
    return ans * inv[cnt] % MOD;
}

int main() {
    for (int i = 0; i < MAX_N; ++i) a[i] = rand();
    for (int i = 0; i < MAX_N; ++i) b[i] = rand();
    inv[1] = 1;
    for (int i = 2; i < MAX_N; ++i) inv[i] = ((-(MOD / i) * inv[MOD % i] % MOD) + MOD) % MOD;

    int caseNum = readInt();
    while (caseNum--) {
        n = readInt(), m = readInt();
        static bool vis[MAX_N];
        for (int i = 0; i < n; ++i) G[i].clear(), vis[i] = false, inCir[i] = false;
        for (int i = 0; i < n; ++i) G[fa[i] = readInt() - 1].push_back(i);

        int cirSz = 0, u = 0;
        while (!vis[u]) {
            assert(u >= 0 && u < n);
            vis[u] = true;
            u = fa[u];
        }
        cir[cirSz++] = u;
        for (int i = fa[u]; i != u; i = fa[i]) cir[cirSz++] = i;
        for (int i = 0; i < cirSz; ++i) inCir[cir[i]] = true;
        for (int i = 0; i < cirSz; ++i) hashVal[i + cirSz] = hashVal[i] = dfs(cir[i]);

        printf("%d\n", solve(cirSz));
    }

    return 0;
}