BZOJ 4332 - [JSOI2012]分零食

Published on 2016-05-25

题目地址

描述

有 $n\,(1\le n\le 10^9)$ 个小朋友站成一排,你有 $m\,(1\le m\le 10000)$ 颗糖果,你要把所有糖果分给这些小朋友。一个合法的方案是:如果第 $i$ 个小朋友没有糖果,那么他之后的小朋友都没有糖果。给定 $a, b, c\,(a\le 4, b\le 300, c\le 100)$,如果一个小朋友分到了 $x$ 个糖果,那么他的欢乐程度是 $f(x) = ax^2 + bx + c$,没有分到糖果的小朋友的欢乐程度是 $1$。定义一种合法方案的欢乐程度是所有小朋友的欢乐程度的乘积。请你求所有合法方案的欢乐程度之和,答案对输入给出的模数取模。

分析

设 $g(i, j)$ 为前 $i$ 个小朋友分 $j$ 个糖果、且每个小朋友都分到糖果时的欢乐程度乘积之和,则 $g(0, 0) = 1$,转移方程为:

$$ g(i, j) = \sum_{1\le k \le j} g(i - 1, j - k)\,f(k) $$

明显是卷积形式。强调一下:这里把 $g(i)$ 看作关于糖果数 $j$ 的多项式,它有 $m + 1$ 项(只保留 $x^0$ 到 $x^m$ 的系数)。

发现 $g(0)$ 是单位元,所以 $g(n) = f^n$,可以用 FFT 在 $O(m\log m\log\min(n, m))$ 的时间内求出来($O(m\log m)$ 是一次乘法的时间)。但我们要求的不是 $g(n)$,而是 $\sum\limits_{1\le k\le n}g(k, m)$,这怎么处理呢?

定义 $p(i) = \sum_{1\le k\le i}g(k)$,考虑用类似快速幂的方法求 $p(n)$。由于 $f$ 的常数项为 $0$,当 $k > m$ 时 $g(k, m) = 0$,所以实际只需要求 $p(\min(n, m))$。先抛出一个结论:

$$ p(i) = \begin{cases} p\left(\frac i 2\right) + p\left(\frac i 2\right)g\left(\frac i 2\right) & i \bmod 2 = 0\\ p\left(\left\lfloor\frac i 2\right\rfloor\right) + p\left(\left\lfloor\frac i 2\right\rfloor\right)g\left(\left\lfloor\frac i 2\right\rfloor\right) + g(i) & \text{otherwise} \end{cases} $$

要证明上式,只需要证明 $g(\frac i 2 + k) = g(\frac i 2)g(k)$。由于 $g(n) = f^n$,所以对任意非负整数 $u, v$ 有 $g(u + v) = g(u)g(v)$,所以上式成立。
多项式乘法满足结合律,于是可以使用快速幂解决问题,复杂度是 $O(m\log m\log\min(n, m))$。
实现方面,为了降低复杂度,$g(\frac i 2)$ 这一项是不停地平方上来的。

P.S:时间卡的很死,建议手写复数。

代码

//  Created by Sengxian on 5/26/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 4332 FFT 快速幂
#include <bits/stdc++.h>
using namespace std;

// Shorthand for long long; main() uses it to widen the quadratic term of f
// before taking the remainder.
typedef long long ll;
// Reads the next non-negative decimal integer from stdin.
//
// Any characters in front of the first digit are skipped; the character that
// terminates the digit run is consumed as well.
// Returns the parsed value, or 0 when the input ends before any digit is seen.
inline int ReadInt() {
    int ch = getchar();
    // isdigit(EOF) is false, so EOF must be tested explicitly here: without
    // the check this skip loop would never terminate on exhausted input.
    while (ch != EOF && !isdigit(ch)) ch = getchar();

    int value = 0;
    while (isdigit(ch)) {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value;
}

// Modulus for all coefficient arithmetic: read from the input in main() and
// consulted by mod() / mod1().
int modu;

// Reduces x in place modulo the global modu; for a positive modu the result
// lies in [0, modu). Values already in range are left untouched so the common
// case costs no division.
inline void mod(int &x) {
    if (x >= 0) {
        if (x >= modu) x %= modu;
        return;
    }
    // Negative input: shift the (non-positive) remainder back up.
    x = (x % modu + modu) % modu;
}

// Value-returning counterpart of mod(): returns x reduced modulo the global
// modu (in [0, modu) for a positive modu) without modifying the argument.
inline int mod1(const int &x) {
    if (x >= 0) return x < modu ? x : x % modu;
    // Negative input: shift the (non-positive) remainder back up.
    return (x % modu + modu) % modu;
}

namespace FFT {
    const double pi = acos(-1.0);
    struct complex {
        double a, b;
        complex(double a = 0, double b = 0): a(a), b(b) {}
        inline void init() {a = 0, b = 0;}
        inline complex operator + (const complex &ano) const {return complex(a + ano.a, b + ano.b);}
        inline complex operator - (const complex &ano) const {return complex(a - ano.a, b - ano.b);}
        inline complex operator * (const complex &ano) const {return complex(a * ano.a - b * ano.b, b * ano.a + a * ano.b);}
    };
    typedef complex C;
    typedef vector<C> vc;
    typedef vector<int> vi;

    vc a, b;

    void DFT(vc &a, int oper = 1) {
        int n = a.size();
        for (int i = 0, j = 0; i < n; ++i) {
            if (i > j) swap(a[i], a[j]);
            for (int l = n >> 1; (j ^= l) < l; l >>= 1);
        }
        for (int l = 1, ll = 2; l < n; l <<= 1, ll <<= 1) {
            double x = oper * pi / l;
            C omega = 1, omegan(cos(x), sin(x));
            for (int k = 0; k < l; ++k, omega = omega * omegan) {
                for (int st = k; st < n; st += ll) {
                    C tmp = omega * a[st + l];
                    a[st + l] = a[st] - tmp;
                    a[st] = a[st] + tmp;
                }
            }
        }
        if (oper == -1) for (int i = 0; i < n; ++i) a[i].a /= n;
    }

    vi& operator * (const vi &v1, const vi &v2) {
        int s = 1, ss = (int)v1.size() + (int)v2.size();
        while (s < ss) s <<= 1;
        a.resize(s), b.resize(s);
        for (int i = 0; i < s; ++i) a[i].init(), b[i].init();
        for (unsigned int i = 0; i < v1.size(); ++i) a[i] = v1[i];
        for (unsigned int i = 0; i < v2.size(); ++i) b[i] = v2[i];
        DFT(a), DFT(b);
        for (int i = 0; i < s; ++i) a[i] = a[i] * b[i];
        DFT(a, -1);
        static vi res;
        res.resize(v1.size());
        for (unsigned int i = 0; i < v1.size(); ++i) res[i] = mod1((int)round(a[i].a));
        return res;
    }

    void operator *= (vi &v1, const vi &v2) {
        int s = 1, ss = (int)v1.size() + (int)v2.size();
        while (s < ss) s <<= 1;
        a.resize(s), b.resize(s);
        for (int i = 0; i < s; ++i) a[i].init(), b[i].init();
        for (unsigned int i = 0; i < v1.size(); ++i) a[i] = v1[i];
        for (unsigned int i = 0; i < v2.size(); ++i) b[i] = v2[i];
        DFT(a), DFT(b);
        for (int i = 0; i < s; ++i) a[i] = a[i] * b[i];
        DFT(a, -1);
        for (unsigned int i = 0; i < v1.size(); ++i) v1[i] = mod1((int)round(a[i].a));
    }

    void operator += (vi &v1, const vi &v2) {
        for (unsigned int i = 0; i < v1.size(); ++i) mod(v1[i] += v2[i]);
    }

}

using namespace FFT;

// Input values, in the order main() reads them (modu sits between m and n).
int m;  // number of candies; polynomials keep m + 1 coefficients
int n;  // number of children
int O;  // quadratic coefficient of f
int S;  // linear coefficient of f
int U;  // constant coefficient of f
// f[i] = (O*i*i + S*i + U) % modu for 1 <= i <= m, and f[0] = 0.
vi f;

// Raises the polynomial v to the b-th power by binary exponentiation.
//
// Every intermediate product is truncated to v.size() coefficients and reduced
// by the FFT operators. Returns the truncated power; for b <= 0 the result is
// the multiplicative identity (1, 0, 0, ...), and for an empty v it is empty.
inline vi mod_powV(const vi &v, int b) {
    vi res(v.size(), 0), tmp = v;
    if (res.empty()) return res;  // nothing to index: res[0] would be out of bounds
    res[0] = 1;
    while (b > 0) {
        if (b & 1) res *= tmp;
        b >>= 1;
        // Square only while bits remain; squaring after the last bit would be
        // a full FFT product whose result is never used.
        if (b > 0) tmp *= tmp;
    }
    return res;
}

// Computes p(n) = f + f^2 + ... + f^n, every product truncated to f.size()
// coefficients and reduced by the FFT operators.
//
// Works by halving: with h = n / 2, p(2h) = p(h) + p(h) * f^h, plus one extra
// f^n term when n is odd. `ghalf` carries f^(current n) up the recursion so
// each level needs only a constant number of products.
//
// Returns a reference to a function-local static that the next call
// overwrites. For n <= 0 the result is the empty sum (all zeros).
inline vi& solve(int n) {
    static vi res, ghalf;
    if (n <= 0) {
        // Empty sum and f^0; without this case n / 2 never reaches 1 and the
        // recursion does not terminate.
        res.assign(f.size(), 0);
        ghalf.assign(f.size(), 0);
        if (!ghalf.empty()) ghalf[0] = 1;
        return res;
    }
    if (n == 1) return res = ghalf = f;
    solve(n / 2);
    res += res * ghalf;   // p(2h) = p(h) + p(h) * f^h
    ghalf *= ghalf;       // f^(2h)
    if (n & 1) {
        // f^n = f^(2h) * f is needed for ghalf anyway, so reuse it for the
        // extra term instead of recomputing the power from scratch.
        ghalf *= f;
        res += ghalf;
    }
    return res;
}

int main() {
    m = ReadInt(), modu = ReadInt(), n = ReadInt(), O = ReadInt(), S = ReadInt(), U = ReadInt();
    f = vi(m + 1, 0);
    for (int i = 1; i < m + 1; ++i) f[i] = ((ll)O * i * i + S * i + U) % modu;
    vi &res = solve(min(n, m));
    printf("%d\n", res[m]);
    return 0;
}