BZOJ 1425 - SGU 421 k-th Product

Published on 2016-04-12

题目地址

描述

问在给出的 nn 个数 中选 mm 个数乘积第 kk 大为多少。

样例输入

4 3 3
2 3 3 5

样例输出

30

样例解释

样例有 4 个可行的乘积,3*3*5=45,2*3*5=30(包含第 1 个 3),2*3*5=30(包含第 2 个 3),2*3*3=18。

数据规模

对于 15% 的数据,n15n\le15
对于 25% 的数据,k10k\le10
对于 40% 的数据,
对于 60% 的数据,ai0a_i\ge 0
对于 100% 的数据,1n,k10000,1m13,kC(n,m),ai1000001\le n, k\le 10000,1\le m\le 13,k\le C(n,m), \left| a_i \right|\le 100000

分析

裸暴力 15 分,直接 long long 搞。
首先要知道,正解肯定要高精乘,所以这个不再强调。
由于负数比较麻烦,我们先讨论没有负数的情况,也就是 60% 的数据。
从大到小排序后,应该是这个样子的:

a1a2a3a4a5...an1ana_1 \ge a_2 \ge a_3 \ge a_4 \ge a_5 \ge ... \ge a_{n - 1} \ge a_n

我们考虑取 5 个数的情况,首先找第 1 大乘积,它必定是 a1a2a3a4a5a_1a_2a_3a_4a_5。接着考虑第 2 大,一番思考后发现,只能把 a5a_5 换成 a6a_6,换其他的任何数都不会比把 a5a_5 换成 a6a_6 更优。那么第 2 大就是 a1a2a3a4a6a_1a_2a_3a_4a_6
接着考虑第 3 大,我们还是用替换法构造,这回可能有两个结果了,分别是 a4a_4 换成 a5a_5a6a_6 换成 a7a_7,不难发现其他的不会更优:

可以发现,对于一个当前最大的乘积:

ap1ap2ap3...apm(p1<p2<p3<...<pm)a_{p_1}a_{p_2}a_{p_3}...a_{p_m}(p_1 < p_2 < p_3 < ... < p_m)

只有当 pi+1<pi+1p_i + 1 < p_{i + 1} 时(i=mi = m 时是 pm+1<=np_m + 1 <= n),将 pip_i 替换成 pi+1p_i + 1 才有可能更优。因为总要替换掉一个数,而且不能造成重复选数。如果把 pip_i 换成 pi+2p_i + 2,要么 pi+1=pi+1p_{i + 1} = p_i + 1,那么不如由 pi+1p_{i + 1} 来替换 pi+2p_i + 2。要么 pi+1<pi+1p_i + 1 < p_{i + 1},那么不如把 pip_i 换成 pi+1p_i + 1,其余的部分证明是类似的。
那么对于正数的算法就很明确了,我们维护一个大根堆,一开始把 a1a2a3...ana_1a_2a_3...a_n 放进堆中。从堆中弹出最大的乘积,用刚刚的替换法把新的乘积替换出来,再放进堆中。这样 k1k - 1 次过后,堆顶就是第 kk 大乘积。
这样做有一个缺陷,就是有可能构造由两个不同的乘积构造出一个相同的乘积,比如 a1a2a4a6a_1a_2a_4a_6 既可以由 a1a2a3a6a_1a_2a_3a_6,也可由 a1a2a4a5a_1a_2a_4a_5 得到,这种情况我们可以对选择的所有数的下标 hash 解决。

接着考虑有负数的情况,也就是 100% 的数据:
设正数有 n1n_1 个,负数有 n2n_2 个,则 n1+n2=mn_1 + n_2 = m
有负数的麻烦地方在于,不知道怎么选负数,一不小心就乘积变成负数。但可以肯定的是,第 kk 大一定选取 [0,n2][0, n_2] 个负数。
我们首先把正负数分开,枚举选择 xx 个负数,分别算出选择 xx 个负数时第 kk 大,都放进大根堆中,这样就可以找到全局的第 kk 大了。注意到如果 xx 是奇数,那么负数的排序就要颠倒过来。
这样做的复杂度是 O(mklog(k+m)L)O(mk\log (k + m) * L)LL 为数位长度,也就是高精度消耗。

细节不可能一一说全,代码就是一切。

代码

//  Created by Sengxian on 4/12/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  bzoj 1425 堆,贪心,哈希
#include <algorithm>
#include <iostream>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
using namespace std;

inline int ReadInt() {
    static int n, ch;
    static bool flag;
    n = 0, ch = getchar(), flag = false;
    while (!isdigit(ch)) flag |= ch == '-', ch = getchar();
    while (isdigit(ch)) n = (n << 3) + (n << 1) + ch - '0', ch = getchar();
    return flag ? -n : n; 
}

typedef long long ll;
typedef unsigned long long ull;
const int maxn = 10000 + 3;

struct BigInt {
    static const int maxDigit = 8;
    static const ll BASE = 1000000000000LL;
    ll digits[maxDigit];
    int len;
    inline void operator *= (const int v) {
        if (v == 0) {len = 0; return;}
        if (!len) return;
        for (int i = 0; i < len; ++i) digits[i] *= v;
        for (int i = 0; i < len; ++i) digits[i + 1] += digits[i] / BASE, digits[i] %= BASE;
        if (digits[len]) len++;
    }
    inline bool operator < (const BigInt &s) const {
        if (len != s.len) return len < s.len;
        for (int i = len - 1; i >= 0; --i)
            if (digits[i] != s.digits[i])
                return digits[i] < s.digits[i];
        return false;
    }
    inline bool operator > (const BigInt &s) const {
        if (len != s.len) return len > s.len;
        for (int i = len - 1; i >= 0; --i)
            if (digits[i] != s.digits[i])
                return digits[i] > s.digits[i];
        return false;
    }
    inline void printLine() const {
        if (len == 0) {puts("0"); return;}
        printf("%lld", digits[len - 1]);
        for (int i = len - 2; i >= 0; --i) printf("%012lld", digits[i]);
        putchar('\n');
    }
    inline void init() {
        memset(digits, 0, sizeof digits);
        len = 1;
        digits[0] = 1;
    }
    BigInt() {init();}
};

struct state {
    vector<int> c[2];
    BigInt product;
    bool operator < (const state &s) const {
        bool op1 = c[1].size() & 1, op2 = s.c[1].size() & 1;
        if (op1 == op2) {
            if (op1 == 0) return product < s.product;
            else return product > s.product;
        }else {
            if (op1 == 0) return false;
            else return true;
        }
    }
    ull hash() {
        ull val = 0;
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < c[i].size(); ++j)
                val = val * 999997 + c[i][j] + 1;
            val = val * 999997 + maxn + 10;
        }
        return val;
    }
};

int n, m, k, a[maxn], b[maxn], n1, n2;
priority_queue<state> Kth;

set<ull> s;
inline void go(int m1, int m2) {
    priority_queue<state> PQ;
    s.clear();
    state st;
    for (int j = 0; j < m1; ++j) st.c[0].push_back(j), st.product *= a[j];
    for (int j = 0; j < m2; ++j) st.c[1].push_back(j), st.product *= b[j];
    PQ.push(st), Kth.push(st);
    for (int i = 0; i < k - 1 && PQ.size(); ++i) {
        st = PQ.top(); PQ.pop();
        for (int x = 0; x < 2; ++x) {
            const vector<int> &c = st.c[x];
            int m = c.size();
            for (int j = 0; j < m; ++j) // 如果这一位 +1 不等于下一位,那么就可以推
                if ((j + 1 < m && c[j] + 1 != c[j + 1]) || (j + 1 == m && c[j] + 1 < (x == 0 ? n1 : n2))) {
                    state ns = st;
                    ns.c[x][j] = c[j] + 1, ns.product.init();
                    for (int t = 0; t < m1; ++t) ns.product *= a[ns.c[0][t]];
                    for (int t = 0; t < m2; ++t) ns.product *= b[ns.c[1][t]];
                    ull hash = ns.hash();
                    if (s.count(hash)) continue; //注意按照这种方式构造排列是有可能重复的!
                    PQ.push(ns), Kth.push(ns);
                    s.insert(hash);
                }
        }
    }
}

void solve() {
    //枚举选多少个正数
    for (int i = 0; i <= m; ++i) {
        if (i > n1 || (m - i) > n2) continue;
        if ((m - i) & 1) reverse(b, b + n2); // 如果奇数个负数,那么倒转回来!
        go(i, m - i);
        if ((m - i) & 1) reverse(b, b + n2); // 倒转回来
    }
    for (int i = 0; i < k - 1; ++i) Kth.pop();
    if (Kth.top().c[1].size() & 1) putchar('-');
    Kth.top().product.printLine();
}

int main() {
    n = ReadInt(), m = ReadInt(), k = ReadInt();
    n1 = 0, n2 = 0;
    for (int i = 0; i < n; ++i) {
        int x = ReadInt();
        if (x >= 0) a[n1++] = x;
        else b[n2++] = -x;
    }
    sort(a, a + n1), sort(b, b + n2);
    reverse(a, a + n1), reverse(b, b + n2);
    solve();
    return 0;
}