Codeforces 809E - Surprise me!

Published on 2017-05-28

题目地址

描述

给定一棵 n(n200000)n(n\le 200000) 个点的树,每个点的点权 aia_i 形成了一个 1n1\sim n 的排列 a1,a2,,ana_1, a_2, \ldots, a_n,请你求

1n(n1)i=1nj=1nφ(aiaj)dis(i,j) \frac 1 {n(n - 1)}\cdot \sum_{i = 1}^n\sum_{j = 1}^n\varphi(a_ia_j)\cdot\mathrm{dis}(i,j)

答案对 109+7{10}^9 + 7 取模。

分析

暴力是 O(n2)O(n^2) 的,我们必须得把式子里面的 φ(aiaj)\varphi(a_ia_j) 拆掉才能继续优化。有这样一个公式

φ(ab)=φ(a)φ(b)dφ(d),d=gcd(a,b) \varphi(ab) = \varphi(a)\varphi(b)\frac d {\varphi(d)},\quad d = \gcd(a, b)

很容易使用 φ\varphi 的展开形式证明。现在我们要求

i=1nj=1nφ(ai)φ(aj)gcd(a,b)φ(gcd(a,b))dis(i,j) \sum_{i = 1}^n\sum_{j = 1}^n\varphi(a_i)\varphi(a_j)\frac {\gcd(a, b)} {\varphi(\gcd(a, b))}\mathrm{dis}(i, j)

根据一般的套路,我们枚举 gcd\gcd

d=1ndφ(d)i=1nj=1nφ(ai)φ(aj)dis(i,j)[gcd(ai,aj)=d] \sum_{d = 1}^n\frac d {\varphi(d)}\sum_{i = 1}^n\sum_{j = 1}^n\varphi(a_i)\varphi(a_j)\mathrm{dis}(i, j)[\gcd(a_i, a_j) = d]

使用莫比乌斯反演,设

g(d)=i=1nj=1nφ(ai)φ(aj)dis(i,j)[gcd(ai,aj)=d] g(d) = \sum_{i = 1}^n\sum_{j = 1}^n\varphi(a_i)\varphi(a_j)\mathrm{dis}(i, j)[\gcd(a_i, a_j) = d]

再设

f(d)=i=1nj=1nφ(ai)φ(aj)dis(i,j)[dgcd(ai,aj)] f(d) = \sum_{i = 1}^n\sum_{j = 1}^n\varphi(a_i)\varphi(a_j)\mathrm{dis}(i, j)[d\mid\gcd(a_i, a_j)]

则有

g(d)=i=1ndμ(i)f(id) g(d) = \sum_{i = 1}^{\left\lfloor\frac n d\right\rfloor}\mu(i)f(i\cdot d)

如果能求出 f(n)f(n),那么 g(n)g(n) 就能用上面的式子 O(nlogn)O(n\log n) 算出来。考虑求 f(d)f(d),我们观察到 f(d)f(d) 中的条件是 dgcd(ai,aj)d\mid\gcd(a_i, a_j),也就是说只有 daid\mid a_iaia_i 才可能有贡献,而题目中的权值是 1n1\sim n 的排列,这启发我们对所有 daid\mid a_i 的点 ii 建立虚树。根据调和级数 Hn=11+12++1n=O(lnn)H_n = \frac 1 1 + \frac 1 2 + \cdots + \frac 1 n = O(\ln n) 的事实,我们可以知道总的点数是 O(nlnn)O(n\ln n) 的。我们考虑如何在虚树上求出答案。

和式中的距离不好搞,那我们就把距离写成:

dis(u,v)=dep(u)+dep(v)2dep(lca(u,v)) \mathrm{dis}(u, v) = \mathrm{dep}(u) + \mathrm{dep}(v) - 2\cdot \mathrm{dep}(\mathrm{lca}(u, v))

现在的和式变成了

其中 dep(i)\mathrm{dep}(i)dep(j)\mathrm{dep}(j) 的贡献可以直接计算,考虑如何计算 dep(i,j)\mathrm{dep}(i, j) 的贡献。

由于虚树包含了所有关键点的 LCA,我们直接在虚树上 DP 即可,对于一个点 vv,我们求所有 LCA 为点 vv 的点对 (i,j)(i, j) 的贡献,只需要记录一下 svs_v 表示 vv 的子树中的 φ(ai)\varphi(a_i) 之和,就能轻松求出贡献,复杂度是线性的。

总共会有 O(nlnn)O(n\ln n) 个关键点,瓶颈在于建立虚树。我们使用 RMQ O(nlogn)O(1)O(n\log n) \sim O(1) 求 LCA,并且在加入关键点的时候,按照 DFS 序加入(这样就无需在建立虚树时排序),就能做到在 O(nlnn)O(n\ln n) 的时间内建立虚树。

总复杂度 O(nlogn)O(n\log n)

代码

//  Created by Sengxian on 2017/05/26.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  Codeforces 809E 莫比乌斯反演 + 虚树 DP
#include <bits/stdc++.h>
#define len(x) ((int)x.size())
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 200000 + 3, MOD = 1000000007;
vector<int> G[MAX_N], factor[MAX_N], vec[MAX_N];
int n, a[MAX_N], id[MAX_N], dfn[MAX_N], dep[MAX_N];
int seq[MAX_N * 2], minDep[20][MAX_N * 2], logs[MAX_N * 2];
ll inv[MAX_N];

void dfs(int u, int fa) {
    static int ts = 0;
    for (int i = 0; i < len(factor[a[u]]); ++i)
        vec[factor[a[u]][i]].push_back(u);
    dfn[u] = ts, seq[ts++] = u;
    for (int i = 0; i < (int)G[u].size(); ++i) {
        int v = G[u][i];
        if (v == fa) continue;
        dep[v] = dep[u] + 1;
        dfs(v, u);
        seq[ts++] = u;
    }
}

int primes[MAX_N], phi[MAX_N], mu[MAX_N], cnt = 0;
bool isNotPrime[MAX_N];

inline bool cmp(const int &i, const int &j) {
    return dep[i] < dep[j];
}

void prepare() {
    phi[1] = 1, mu[1] = 1;
    for (int i = 2; i <= n; ++i) {
        if (!isNotPrime[i]) primes[cnt++] = i, phi[i] = i - 1, mu[i] = -1;
        for (int j = 0; j < cnt && i * primes[j] <= n; ++j) {
            int t = i * primes[j];
            isNotPrime[t] = true;
            if (i % primes[j] == 0) {
                phi[t] = phi[i] * primes[j];
                mu[t] = 0;
                break;
            } else phi[t] = phi[i] * phi[primes[j]], mu[t] = -mu[i];
        }
    }

    for (int i = 1; i <= n; ++i)
        for (int j = i; j <= n; j += i)
            factor[j].push_back(i);

    dfs(0, -1);
    int len = n * 2 + 1;
    for (int i = 2; i <= len; ++i) logs[i] = logs[i >> 1] + 1;
    for (int i = 0; i < len; ++i) minDep[0][i] = seq[i];
    for (int w = 1; (1 << w) <= len; ++w)
        for (int i = 0; i + (1 << w) <= len; ++i)
            minDep[w][i] = min(minDep[w - 1][i], minDep[w - 1][i + (1 << (w - 1))], cmp);

    inv[1] = 1;
    for (int i = 2; i <= n; ++i) inv[i] = ((-(MOD / i) * inv[MOD % i] % MOD) + MOD) % MOD;
}

inline int query(int l, int r) {
    int w = logs[r - l];
    return min(minDep[w][l], minDep[w][r - (1 << w)], cmp);
}

inline int lca(int u, int v) {
    if (dfn[u] > dfn[v]) swap(u, v);
    return query(dfn[u], dfn[v] + 1);
}

inline int dis(int u, int v) {
    return dep[u] + dep[v] - 2 * dep[lca(u, v)];
}

struct Edge {
    Edge *next;
    int to;
    Edge(Edge *next = NULL, int to = 0) : next(next), to(to) {}
} pool[MAX_N], *pit = pool, *first[MAX_N];

inline void addEdge(int u, int v) {
    first[u] = new (pit++) Edge(first[u], v);
}

void build(vector<int> &vec) {
    static int stk[MAX_N];
    int k = len(vec), sz = 0;
    for (int i = 0; i < k; ++i) first[vec[i]] = NULL;

    stk[sz++] = 0, first[0] = NULL, pit = pool;
    for (int i = 0; i < k; ++i) {
        int u = vec[i], lca = ::lca(u, stk[sz - 1]);

        if (lca == stk[sz - 1]) stk[sz++] = u;
        else {
            while (sz - 2 >= 0 && dep[stk[sz - 2]] >= dep[lca]) {
                 addEdge(stk[sz - 2], stk[sz - 1]);
                sz--;
            }

            if (stk[sz - 1] != lca) {
                first[lca] = NULL;
                addEdge(lca, stk[--sz]);
                stk[sz++] = lca, vec.push_back(lca);
            }

            stk[sz++] = u;
        }
    }

    for (int i = 0; i < sz - 1; ++i) addEdge(stk[i], stk[i + 1]);
}

ll s[MAX_N];

int dfs2(int u, int d) {
    ll res1 = 0, res2 = 0;
    s[u] = a[u] % d == 0 ? phi[a[u]] : 0;
    for (Edge *e = first[u]; e; e = e->next) {;
        (res1 += dfs2(e->to, d)) %= MOD;
        (res2 += s[e->to] * s[u] % MOD) %= MOD;
        (s[u] += s[e->to]) %= MOD;
    }
    res2 = res2 * -4 % MOD * dep[u] % MOD;
    if (a[u] % d == 0) (res2 += (ll)phi[a[u]] * phi[a[u]] % MOD * -2 * dep[u] % MOD) %= MOD;
    return (res1 + res2 + MOD) % MOD;
}

int calc(int d) {
    ll res = 0, sum = 0;
    for (int i = 0; i < len(vec[d]); ++i) (sum += phi[a[vec[d][i]]]) %= MOD;
    for (int i = 0; i < len(vec[d]); ++i) (res += sum * phi[a[vec[d][i]]] % MOD * dep[vec[d][i]] % MOD) %= MOD;
    res = res * 2 % MOD;
    build(vec[d]);
    (res += dfs2(0, d)) %= MOD;
    return (res + MOD) % MOD;
}

int main() {
#ifdef DEBUG
    freopen("bigtest.in", "r", stdin);
#endif
    n = readInt();
    for (int i = 1; i <= n; ++i) a[i] = readInt(), id[a[i]] = i;
    for (int i = 0; i < n - 1; ++i) {
        int u = readInt(), v = readInt();
        G[u].push_back(v), G[v].push_back(u);
    }
    G[0].push_back(1), G[1].push_back(0);

    prepare();

    static ll f[MAX_N];

    ll res = 0;

    for (int i = 1; i <= n; ++i) f[i] = calc(i);
    for (int d = 1; d <= n; ++d) {
        ll g = 0;
        for (int i = 1; i * d <= n; ++i)
            (g += mu[i] * f[i * d] % MOD) %= MOD;
        (res += g * d % MOD * inv[phi[d]] % MOD) %= MOD;
    }

    res = (ll)res * inv[n] % MOD * inv[n - 1] % MOD;

    printf("%d\n", (int)res);

    return 0;
}