BZOJ 4869 - [Shoi2017]相逢是问候

Published on 2017-04-27

题目地址

描述

Informatik verbindet dich und mich.
信息将你我连结。

B 君希望以维护一个长度为 n(n50000)n(n\le 50000) 的数组,这个数组的下标为从 11nn 的正整数。

一共有 m(m50000)m(m\le 50000) 个操作,可以分为两种:

  • 0 l r0 \ l \ r:表示将第 ll 个到第 rr 个数 (al,al+1,,ar)(a_l, a_{l+1}, \ldots, a_r) 中的每一个数 aia_i 替换为 caic^{a_i},即 ccaia_i 次方,其中 cc 是输入的一个常数,也就是执行赋值

    aicai a_i \leftarrow c^{a_i}
  • 1 l r1 \ l \ r:求第 ll 个到第 rr 个数的和,也就是输出:

    i=lrai\sum_{i = l}^r a_i

因为这个结果可能会很大,所以你只需要输出结果 modp\mathrm{mod}\;p 的值即可。

分析

根据 BZOJ 2749 提供的信息,n>2n > 2 时,φ(n)\varphi(n) 是偶数,也就是说 φ(n)n2\varphi(n) \le \frac n 2。则一个数 nn 经过 O(logn)O(\log n)nφ(n)n\leftarrow \varphi(n) 的变换就会变到 nn

再根据 UVa 10692 提供的信息可以知道,存在指数循环节公式

axaxmodφ(n)+φ(n)(modn)(xφ(n)) a^x\equiv a^{x \bmod \varphi(n) + \varphi(n)} \pmod n (x\ge \varphi(n))

一旦模数变为 11,意味着结果必然是 00,所以一个数经过至多 O(logn)O(\log n) 次修改操作就必然会变为一个常数。

于是我们得到一个算法,利用线段树维护区间和,同时维护一个区间的最小操作次数。设 nn 通过 nφ(n)n\leftarrow\varphi(n) 的变换变到 11 需要 ss 次,当执行修改操作的时候,如果当前区间的最小操作次数 >s> s 时,则修改没有意义,直接不修改;否则正常递归下去。根据势能分析,我们用 O(nlog2n)O(n\log^2 n) 的代价执行了 O(nlogn)O(n\log n) 次修改。

对于单点修改,使用与 UVa 10692 一模一样的方法即可,复杂度为 O(log2n)O(\log^2 n)。注意到本题中 cc 是一个定值,可以通过预处理的方法 O(1)O(1) 求出 cxmodp(x232)c^x\bmod p(x\le 2^{32})。具体做法是预处理 pow1(n)\mathrm{pow_1}(n) 表示 cnmodpc^n\bmod p 的值,处理到 6553665536;再预处理 pow2(n)\mathrm{pow_2}(n) 表示 cnc^n 平方 1616modp\mathrm{mod}\;p 的值,同样处理到 6553665536,则 cxmodpc^x\bmod p 可以通过分两段查表的方式快速算出:

pow1[b & 65535] * pow2[b >> 16] % mod

由于只有 O(logn)O(\log n) 个模数,对每个模数都预处理,那么复杂度为:O(6553616logn)O(65536 \cdot 16\log n)

总的复杂度:O(nlog2n+6553616logn)O(n\log^2n + 65536 \cdot 16\log n)

一点后话:指数循环节公式只在 xφ(n)x\ge \varphi(n) 时成立,在 UVa 10692 中,用试乘来判断是否 φ(n)\ge \varphi(n),我们在试乘的时候,是以上一层返回的取模后结果作为幂试乘,这样并不准确,应该使用上一层的答案进行试乘。但是放心,经过验证,这样做没有任何问题。因为如果 xφ(n)x\ge \varphi(n),那么 axna^x\ge n 只在 n=6n = 6 时不成立,经过验证,这个带来的一系列后续影响不会造成答案的错误,所以大可放心使用。

代码

//  Created by Sengxian on 2017/04/26.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  BZOJ 4869 数论 + 线段树
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 50000 + 3, MAX_LOG_N = 40, LIM = 350000;

int calc(int n) {
    int res = n, t = n;
    for (int i = 2; i * i <= n; ++i) if (t % i == 0) {
        res /= i, res *= i - 1;
        while (t % i == 0) t /= i;
    }
    if (t > 1) res /= t, res *= t - 1;
    return res;
}

int pow1[MAX_LOG_N][1 << 16], pow2[MAX_LOG_N][1 << 16], cnt[MAX_LOG_N];
vector<int> mods;

inline int modPow(int b, int mod) {
    int idx = lower_bound(mods.begin(), mods.end(), mod) - mods.begin();
    if (b < cnt[idx]) return pow1[idx][b]; // < mod
    return ((ll)pow1[idx][b & 65535] * pow2[idx][b >> 16] % mod) + mod;
}

map<int, int> phi;
int n, m, p, c, a[MAX_N], s;

int calc(int a, int cnt, int mod) { // 如果答案 >= mod,返回 % mod + mod
    if (c == 1) return c % mod;
    if (cnt == 0) return a < mod ? a : a % mod + mod;
    int res = calc(a, cnt - 1, phi[mod]);
    return modPow(res, mod);
}

namespace SegmentTree {
    static const int MAX_NODE = (1 << 16) * 2;

    #define ls (((o) << 1) + 1)
    #define rs (((o) << 1) + 2)
    #define mid (((l) + (r)) >> 1)

    struct Node {
        int sum, minOperCnt;
    } nodes[MAX_NODE];

    int n, *a;

    inline Node merge(const Node &lhs, const Node &rhs) {
        Node res;
        res.sum = (lhs.sum + rhs.sum) % p;
        res.minOperCnt = min(lhs.minOperCnt, rhs.minOperCnt);
        return res;
    }

    void build(int o, int l, int r) {
        nodes[o].minOperCnt = 0;
        if (r - l == 1) nodes[o].sum = a[l];
        else {
            build(ls, l, mid), build(rs, mid, r);
            nodes[o] = merge(nodes[ls], nodes[rs]);
        }
    }

    void init(int _n, int *_a) {
        n = _n, a = _a;
        build(0, 0, n);
    }

    int query(int o, int l, int r, int a, int b) {
        if (r <= a || l >= b) return 0;
        if (l >= a && r <= b) return nodes[o].sum;
        return (query(ls, l, mid, a, b) + query(rs, mid, r, a, b)) % p;
    }

    void modify(int o, int l, int r, int a, int b) {
        if (r <= a || l >= b || nodes[o].minOperCnt > s) return;
        if (r - l == 1) nodes[o].sum = calc(SegmentTree::a[l], ++nodes[o].minOperCnt, p) % p;
        else {
            modify(ls, l, mid, a, b);
            modify(rs, mid, r, a, b);
            nodes[o] = merge(nodes[ls], nodes[rs]);
        }
    }

    int query(int l, int r) {
        return query(0, 0, n, l, r);
    }

    void modify(int l, int r) {
        modify(0, 0, n, l, r);
    }

    void print() {
        for (int i = 0; i < n; ++i)
            cout << query(i, i + 1) << ' ';
        cout << endl;
    }
};

void prepare() {
    int now = p;
    while (now != 1) mods.push_back(now), now = phi[now] = calc(now);
    phi[1] = 1, mods.push_back(now);
    s = phi.size();

    reverse(mods.begin(), mods.end());
    for (int i = 0; i < (int)mods.size(); ++i) {
        pow1[i][0] = 1;
        for (int j = 1; j < (1 << 16); ++j)
            pow1[i][j] = (ll)pow1[i][j - 1] * c % mods[i];
        for (int j = 0; j < (1 << 16); ++j) {
            pow2[i][j] = pow1[i][j];
            for (int k = 0; k < 16; ++k) pow2[i][j] = (ll)pow2[i][j] * pow2[i][j] % mods[i];
        }
        if (c != 1) {
            ll tt = 1, t = 0;
            while (tt < mods[i]) tt *= c, ++t;
            cnt[i] = t;
        }
    }
}

int main() {
    n = readInt(), m = readInt(), p = readInt(), c = readInt();
    for (int i = 0; i < n; ++i) a[i] = readInt();

    prepare();
    SegmentTree::init(n, a);

    for (int i = 0; i < m; ++i) {
        int op = readInt(), l = readInt() - 1, r = readInt();
        if (op == 0) {
            SegmentTree::modify(l, r);
        } else printf("%d\n", SegmentTree::query(l, r));
    }

    return 0;
}