BZOJ 3811 - 玛里苟斯

Published on 2016-12-11

题目地址

描述

魔法之龙玛里苟斯最近在为加基森拍卖师的削弱而感到伤心,于是他想了一道数学题。

SS 是一个可重集合,

等概率随机取 SS 的一个子集 ,计算出 AA 中所有元素的异或值 xx,求 xk(1k5)x^k(1\le k\le 5) 的期望。

保证答案不超过 2642^{64}

分析

我们对每一个 kk 分别考虑。

k=1k = 1 时,由于贡献是线性的,所以我们可以对每一个二进制位分别考虑。我们考虑二进制位 ii,如果在某一个 aja_j 中,存在二进制位 ii11,那么子集的异或和中二进制位 ii11 的概率为 12\frac 1 2。如果不存在这样的 aja_j,概率为 00。证明很容易,因为子集中有奇数个或者偶数个 aja_j 二进制位 ii11 的概率是一样的,而只有奇数个 aja_j 二进制位 ii11 才能满足异或和中二进制位 ii11

k=2k = 2 时,我们求的是期望的平方,即每个异或和的贡献为 ,写成和式就是

ijbjbi2i+j \sum_{i}\sum_j b_jb_i \cdot 2 ^ {i + j}

我们需要枚举两个二进制位,现在每个数变成了 (0/1,0/1)(0/1,0/1) 二元组,仅当异或后得到 (1,1)(1, 1),才会产生 2i+j2^{i + j} 的贡献。根据 k=1k = 1 的情况不难发现,有 14\frac 1 4 的概率得到 (1,1)(1, 1)。需要特判的是,如果所有数都是 (1,1)(1, 1) 或者 (0,0)(0, 0) 且至少有一个 (1,1)(1, 1),那么概率为 12\frac 1 2。如果所有数都是 (0,0)(0, 0),那么概率为 00

k3k \ge 3 时,由于答案不超过 2642^{64},所以每个数不超过 2222^{22},这些数的线性基不会超过 2222 个,所以我们可以考虑求出线性基,然后暴力枚举线性基的子集即可。

虽然答案不会溢出,但是中间过程是有可能是溢出的,为了防止溢出,在 k1k \neq 1 时,若除数为 2m2^m,那么我们在乘的过程中,将答案记录为 y=y2m2m+ymod2my = \lfloor \frac y {2^m}\rfloor \cdot 2^m + y \bmod 2^m 的形式,这样两个项都不会超过 2642^{64},可以计算了。

现在我们考虑输出小数的问题,当 k=1k = 1 时显然小数位要么是 00 要么是 0.50.5;而 k1k \neq 1 时这个结论仍然成立(虚心求证明),所以特判一下就好了。

代码

//  Created by Sengxian on 2016/12/10.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  BZOJ 3811 线性基
#include <bits/stdc++.h>
using namespace std;

typedef unsigned long long ll;
inline ll readLL() {
    static ll n;
    static int ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 100000 + 3, MAX_BASE = 23;
int n, K;
ll a[MAX_N], b[MAX_N];

void solve1() {
    ll res = 0;
    for (int i = 0; i < n; ++i) res |= a[i];
    printf("%llu", res / 2);
    if (res & 1) puts(".5");
    else puts("");
}

void solve2() {
    ll ans = 0, res = 0;
    for (int i = 0; i < 32; ++i)
        for (int j = 0; j < 32; ++j) {
            bool flag = false;
            for (int k = 0; k < n; ++k) if (a[k] >> i & 1) { flag = true; break; }
            if (!flag) continue;
            flag = false;
            for (int k = 0; k < n; ++k) if (a[k] >> j & 1) { flag = true; break; }
            if (!flag) continue;

            flag = false;
            for (int k = 0; k < n; ++k) if ((a[k] >> i & 1) != (a[k] >> j & 1)) { flag = true; break; }

            if (i + j - 1 - flag < 0) res++;
            else {
                if (!flag) ans += 1LL << (i + j - 1); // 1 / 2
                else ans += 1LL << (i + j - 1 - 1); // 1 / 4
            }
        }

    ans += res >> 1, res &= 1;
    printf("%llu", ans);
    if (res) puts(".5");
    else puts("");
}

void solve3() {
    vector<int> vec;
    for (int i = 0; i < n; ++i)
        for (int j = MAX_BASE; j >= 0; --j)
            if (a[i] >> j & 1) {
                if (b[j]) a[i] ^= b[j];
                else {
                    b[j] = a[i];
                    vec.push_back(a[i]);
                    break;
                }
            }

    int all = vec.size();
    ll ans = 0, res = 0;
    for (int i = (1 << all) - 1; i >= 0; --i) {
        int val = 0;
        for (int j = 0; j < (int)vec.size(); ++j) if (i >> j & 1) val ^= vec[j];

        ll a = 0, b = 1;
        for (int j = 0; j < K; ++j) {
            a *= val, b *= val;
            a += b >> all, b &= (1 << all) - 1;
        }

        ans += a, res += b;
        ans += res >> all, res &= (1 << all) - 1;
    }

    printf("%llu", ans);
    if (res) puts(".5");
    else puts("");
}

int main() {
    n = readLL(), K = readLL();
    for (int i = 0; i < n; ++i) a[i] = readLL();

    if (K == 1) solve1();
    else if (K == 2) solve2();
    else solve3();
    return 0;
}