HNOI 2017 Day 1 题解

Published on 2017-04-19

单旋

分析

我们还是维护这一棵 Spaly,与暴力模拟不同的是,我们需要每次在 O(logm)O(\log m) 的时间内完成所有操作。

对于插入操作,不难发现一个性质:插入关键字 key\mathrm{key},则 key\mathrm{key} 要么在前驱的右儿子,要么在后继的左儿子,两个中有且仅有一个是空的,key\mathrm{key} 就插入在空的那个地方。我们只需要维护前驱后继,就能实现 O(logm)O(\log m) 快速插入,深度也可以直接维护。

对于将最小值旋到根的操作(最大值同理),不难发现只会一直执行 zig\mathrm{zig} 操作。设最小值的节点为 vv,那么 vv 是没有左子树的。找一番规律可以得到:将 vv 转到根,就是将 vv 的右子树接到 vv 的父亲的左子树上,然后将 vv 作为根,vv 的右儿子指向原来的根。这种情况下,深度的变化是:vv 的深度变为 11vv 原来的右子树深度不变,其他的节点深度 +1+1。这种情况下,为了维护深度,需要实现权值区间的深度加减(vv 的右子树代表的区间是 vv 的权值到 vv 的父亲的权值之间的一段),我们直接用任意一种平衡树打标记维护就好了,复杂度 O(logm)O(\log m)

删除操作,直接删掉根,所有的的节点深度 1-1,也在平衡树上打标记就好了。

复杂度:O(mlogm)O(m\log m)

代码

#include <algorithm>
#include <cassert>
#include <iostream>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <climits>
#include <cmath>
#include <vector>
using namespace std;

inline int readInt() {
    int n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_M = 100000 + 3, INF = 0x3f3f3f3f;

int m;

struct Splay {
    Splay *fa, *lc, *rc;
    int key;

    Splay(int key = 0) : fa(NULL), lc(NULL), rc(NULL), key(key) {}
} pool[MAX_M], *pit = pool;

namespace Treap {
    struct Node *null;
    struct Node {
        Node *lc, *rc;
        int key, val;
        int s, dep, tag;
        Splay *node;

        Node(int key = 0, int dep = 0, Splay *node = NULL) : lc(null), rc(null), key(key), val(rand()), s(1), dep(dep), tag(0), node(node) {}

        inline void maintain() {
            s = lc->s + rc->s + 1;
        }

        inline void add(int x) {
            dep += x, tag += x;
        }

        inline void pushDown() {
            if (tag != 0) {
                lc->add(tag);
                rc->add(tag);
                tag = 0;
            }
        }

        int lowerCount(int key) {
            if (this == null) return 0;
            return this->key >= key ? lc->lowerCount(key) : rc->lowerCount(key) + lc->s + 1;
        }

        int upperCount(int key) {
            if (this == null) return 0;
            return this->key > key ? lc->upperCount(key) : rc->upperCount(key) + lc->s + 1;
        }

        Node *select(int k) {
            pushDown();
            if (lc->s + 1 == k) return this;
            if (lc->s >= k) return lc->select(k);
            return rc->select(k - lc->s - 1);
        }

        void print() {
            if (this == null) return;
            pushDown();
            lc->print();
            printf("(%d, %d) ", key, dep);
            rc->print();
        }
    } pool[MAX_M], *pit, *root;

    void init() {
        pit = pool;
        null = new (pit++) Node(), null->s = 0;
        root = null;
    }

    Node *merge(Node *a, Node *b) {
        if (a == null) return b;
        if (b == null) return a;
        a->pushDown(), b->pushDown();

        if (a->val < b->val) {
            a->rc = merge(a->rc, b);
            a->maintain();
            return a;
        } else {
            b->lc = merge(a, b->lc);
            b->maintain();
            return b;
        }
    }

    void split(Node *o, int k, Node *&l, Node *&r) {
        if (o == null) {
            l = r = null;
            return;
        }

        o->pushDown();
        if (o->lc->s >= k) {
            split(o->lc, k, l, r);
            o->lc = r, r = o;
        } else {
            split(o->rc, k - o->lc->s - 1, l, r);
            o->rc = l, l = o;
        }

        o->maintain();
    }

    inline void insert(int key, int dep, Splay *v = NULL) {
        Node *l, *r;
        split(root, root->lowerCount(key), l, r);
        root = merge(merge(l, new (pit++) Node(key, dep, v)), r);
    }

    inline void add(int lVal, int rVal, int val) {
        Node *l, *tmp, *target, *r;
        split(root, root->lowerCount(lVal), l, tmp);

        if (tmp == null || tmp->select(1)->key > rVal) {
            root = merge(l, tmp);
            return;
        }

        split(tmp, tmp->upperCount(rVal), target, r);
        target->add(val);
        root = merge(merge(l, target), r);
    }

    Node *pred(int x) {
        int k = root->lowerCount(x);
        if (k == 0) return NULL;
        Node *v = root->select(k);
        return v;
    }

    Node *succ(int x) {
        int k = root->upperCount(x);
        if (k == root->s) return NULL;
        Node *v = root->select(k + 1);
        return v;
    }

    Node *leftest() {
        return root->select(1);
    }

    Node *rightest() {
        return root->select(root->s);
    }

    void delMin() {
        Node *l, *r;
        split(root, 1, l, r);
        root = r;
    }

    void delMax() {
        Node *l, *r;
        split(root, root->s - 1, l, r);
        root = l;
    }

    void print() {
        root->print();
        cout << endl;
    }
}

Splay *root = NULL;

int insert(int key) {
    Splay *v = new (pit++) Splay(key);
    if (root == NULL) {
        root = v;
        Treap::insert(key, 1, v);
        return 1;
    } else {
        Treap::Node *pred = Treap::pred(key), *succ = Treap::succ(key);
        if (pred == NULL || pred->node->rc != NULL) {
            succ->node->lc = v, v->fa = succ->node;
            Treap::insert(key, succ->dep + 1, v);
            return succ->dep + 1;
        } else if (succ == NULL || succ->node->lc != NULL) {
            pred->node->rc = v, v->fa = pred->node;
            Treap::insert(key, pred->dep + 1, v);
            return pred->dep + 1;
        }
    }
}

int splayMin() {
    Treap::Node *leftest = Treap::leftest();
    int dep = leftest->dep;
    Splay *v = leftest->node;

    if (v->fa != NULL) {
        Treap::add(1, INF, 1);
        Treap::add(v->key, v->key, -dep);
        Treap::add(v->key + 1, v->fa->key - 1, -1);
        if (v->rc != NULL) v->rc->fa = v->fa;
        v->fa->lc = v->rc;
        v->fa = NULL, root->fa = v, v->rc = root;
        root = v;
    }

    return dep;
}

int splayMax() {
    Treap::Node *rightest = Treap::rightest();
    int dep = rightest->dep;
    Splay *v = rightest->node;

    if (v->fa != NULL) {
        Treap::add(1, INF, 1);
        Treap::add(v->key, v->key, -dep);
        Treap::add(v->fa->key + 1, v->key - 1, -1);
        if (v->lc != NULL) v->lc->fa = v->fa;
        v->fa->rc = v->lc;
        v->fa = NULL, root->fa = v, v->lc = root;
        root = v;
    }

    return dep;
}

void delMin() {
    Treap::delMin();
    root = root->rc;
    if (root) root->fa = NULL;
    Treap::add(1, INF, -1);
}

void delMax() {
    Treap::delMax();
    root = root->lc;
    if (root) root->fa = NULL;
    Treap::add(1, INF, -1);
}

int main() {
    m = readInt();

    Treap::init();
    for (int i = 0; i < m; ++i) {
        int c = readInt();
        if (c == 1) {
            printf("%d\n", insert(readInt()));
        } else if (c == 2) {
            if (root == NULL) {
                puts("0");
                continue;
            }
            printf("%d\n", splayMin());
        } else if (c == 3) {
            if (root == NULL) {
                puts("0");
                continue;
            }
            printf("%d\n", splayMax());
        } else if (c == 4) {
            printf("%d\n", splayMin());
            delMin();
        } else if (c == 5) {
            printf("%d\n", splayMax());
            delMax();
        }
    }

    return 0;
}

影魔

分析

我们用单调栈 O(n)O(n) 求出 lastPos(i)\mathrm{lastPos}(i) 表示上一个比 kik_i 大的位置,nextPos(i)\mathrm{nextPos}(i) 表示下一个比 kik_i 大的位置。

我们考虑位置 ii 做最大值时的贡献:对于 q1q_1,当且仅当选择点对 (lastPos(i),nextPos(i))(\mathrm{lastPos}(i), \mathrm{nextPos}(i)) 时,有 q1q_1 的贡献,其他情况要么两边不同时大于 kik_i,要么 kik_i 不做最大值。

接着考虑什么时候会有 q2q_2 的贡献,只有选择的点对为 (lastPos(i),j),i<j<nextPos(i)(\mathrm{lastPos}(i), j), i<j<\mathrm{nextPos(i)} 或者 (j,nextPos(i)),lastPos(i)<j<i(j, \mathrm{nextPos}(i)), \mathrm{lastPos}(i) < j < i 时才有 q2q_2 的贡献。我们把所有这些点对写到平面上,那么区间 [l,r][l, r] 的答案就是点 (l,r)(l, r) 右下方的点权和。我们需要快速计算点 (l,r)(l, r) 右下角的点权和,这是一个稍弱的二维数点问题,唯一棘手的地方是,q2q_2 的贡献不只是一个点,而是一条横着或者竖着的线段。我们可以离线所有询问,从下往上和右往左分别做一次扫描线 + 线段树,就可以处理横着和竖着的线段,具体的实现方法与矩形周长并类似。对于 q1q_1 的贡献,由于只是一个点,随便在哪个方向计算都可以。

复杂度:O((n+m)logn)O((n + m)\log n)

代码

//  Created by Sengxian on 2017/04/17.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  BZOJ 4826 离线 + 线段树
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 200000 + 3, MAX_M = 200000 + 3;

struct SegmentTree {
    static const int MAX_NODE = (1 << 18) * 2;

    #define ls (((o) << 1) + 1)
    #define rs (((o) << 1) + 2)
    #define mid (((l) + (r)) >> 1)

    int n;
    ll sum[MAX_NODE], tag[MAX_NODE];

    void init(int _n) {
        n = _n;
        memset(sum, 0, sizeof sum);
        memset(tag, 0, sizeof tag);
    }

    void giveTag(int o, int l, int r, ll val) {
        sum[o] += (r - l) * val;
        tag[o] += val;
    }

    void pushDown(int o, int l, int r) {
        if (tag[o]) {
            giveTag(ls, l, mid, tag[o]);
            giveTag(rs, mid, r, tag[o]);
            tag[o] = 0;
        }
    }

    ll query(int o, int l, int r, int a, int b) {
        if (r <= a || l >= b) return 0;
        if (l >= a && r <= b) return sum[o];
        pushDown(o, l, r);
        return query(ls, l, mid, a, b) + query(rs, mid, r, a, b);
    }

    void add(int o, int l, int r, int a, int b, int val) {
        if (r <= a || l >= b) return;
        if (l >= a && r <= b) giveTag(o, l, r, val);
        else {
            pushDown(o, l, r);
            add(ls, l, mid, a, b, val);
            add(rs, mid, r, a, b, val);
            sum[o] = sum[ls] + sum[rs];
        }
    }

    inline ll querySuffix(int pos) {
        return query(0, 0, n, pos, n);
    }

    inline ll queryPrefix(int pos) {
        return query(0, 0, n, 0, pos + 1);
    }

    inline void add(int l, int r, int val) {
        add(0, 0, n, l, r, val);
    }

    void print() {
        for (int i = 0; i < n; ++i)
            printf("%d ", query(0, 0, n, i, i + 1));
        cout << endl;
    }
} segmentTree;

int n, m, p1, p2, k[MAX_N], lastPos[MAX_N], nextPos[MAX_N];

void calc(int a[], int lastPos[]) {
    static int stk[MAX_N];
    int sz = 0;
    stk[sz++] = -1;

    for (int i = 0; i < n; ++i) {
        while (sz > 1 && a[stk[sz - 1]] < a[i]) --sz;
        lastPos[i] = stk[sz - 1];
        stk[sz++] = i;
    }
}

vector<pair<int, int> > query1[MAX_N], query2[MAX_N], add1[MAX_N], add2[MAX_N];
vector<int> add[MAX_N];

void prepare() {
    calc(k, lastPos);
    reverse(k, k + n);
    calc(k, nextPos);
    reverse(k, k + n);
    reverse(nextPos, nextPos + n);
    for (int i = 0; i < n; ++i) nextPos[i] = n - nextPos[i] - 1;

    for (int i = 0; i < n; ++i) {
        if (lastPos[i] != -1 && nextPos[i] != n)
            add[nextPos[i]].push_back(lastPos[i]);
        if (nextPos[i] != n) add1[nextPos[i]].push_back(make_pair(lastPos[i] + 1, i));
        if (lastPos[i] != -1) add2[lastPos[i]].push_back(make_pair(i + 1, nextPos[i]));
    }
}

ll ans[MAX_M];

void solve() {
    segmentTree.init(n);
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < (int)add[i].size(); ++j)
            segmentTree.add(add[i][j], add[i][j] + 1, p1);
        for (int j = 0; j < (int)add1[i].size(); ++j)
            segmentTree.add(add1[i][j].first, add1[i][j].second, p2);
        for (int j = 0; j < (int)query1[i].size(); ++j)
            ans[query1[i][j].second] += segmentTree.querySuffix(query1[i][j].first);
    }

    segmentTree.init(n);
    for (int i = n - 1; i >= 0; --i) {
        for (int j = 0; j < (int)add2[i].size(); ++j)
            segmentTree.add(add2[i][j].first, add2[i][j].second, p2);
        for (int j = 0; j < (int)query2[i].size(); ++j)
            ans[query2[i][j].second] += segmentTree.queryPrefix(query2[i][j].first);
    }

    for (int i = 0; i < m; ++i) printf("%lld\n", ans[i]);
}

int main() {
    n = readInt(), m = readInt(), p1 = readInt(), p2 = readInt();
    for (int i = 0; i < n; ++i) k[i] = readInt();

    prepare();

    for (int i = 0; i < m; ++i) {
        int l = readInt() - 1, r = readInt() - 1;
        ans[i] = (r - l) * p1;
        query1[r].push_back(make_pair(l, i));
        query2[l].push_back(make_pair(r, i));
    }

    solve();

    return 0;
}

礼物

首先我们需要最小化的是

i=1n(xiyi+c)2 \sum_{i = 1}^n (x_i - y_i + c)^2

展开

i=1n(xiyi+c)2=i=1n(xiyi)2+2(xiyi)c+c2=i=1nxi22xiyi+yi2+2(xiyi)c+c2=i=1nxi2+i=1nyi22i=1nxiyi+nc2+2(i=1nxii=1nyi)c \begin{aligned} &\sum_{i = 1}^n (x_i - y_i + c)^2\\ =&\sum_{i = 1}^n (x_i - y_i)^2 + 2(x_i - y_i)c + c^2\\ =&\sum_{i = 1}^n x_i^2 - 2x_iy_i + y_i^2 + 2(x_i - y_i)c + c^2\\ =&\sum_{i = 1}^nx_i^2 + \sum_{i = 1}^ny_i^2-2\sum_{i = 1}^nx_iy_i+nc^2+2(\sum_{i = 1}^nx_i - \sum_{i = 1}^ny_i)c \end{aligned}

后面关于 cc 的部分是一个二次函数,可以快速求出其在整点处的最小值。我们现在只需要最大化 i=1nxiyi\sum_{i = 1}^nx_iy_i,两个项链一共有 nn 种对齐方案,我们固定第一串项链,枚举第二串项链的起点 kk,则此时的答案为(令 yi+n=yiy_{i + n} = y_i

fk=i=1nxiyi+k f_k = \sum_{i = 1}^nx_iy_{i + k}

我们将 xx 翻转

fk=i=1nxni+1yi+k f_k = \sum_{i = 1}^nx_{n - i + 1}y_{i + k}

于是这就变成了一个卷积的形式,用长度为 nnxx 卷积长度为 2n2nyy,用 FFT 加速计算卷积即可做到 O(nlogn)O(n\log n),这样就能计算出每个起点的答案。总复杂度 O(nlogn)O(n\log n)

代码

//  Created by Sengxian on 2017/04/17.
//  Copyright (c) 2017年 Sengxian. All rights reserved.
//  BZOJ 4827 FFT
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int readInt() {
    static int n, ch;
    n = 0, ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return n;
}

const int MAX_N = 50000 + 3;

int n, m, x[MAX_N], y[MAX_N];

const double pi = acos(-1.0);
typedef complex<double> C;
typedef vector<C> vc;
typedef vector<double> vd;

void FFT(vc &a, int oper = 1) {
    int n = a.size();
    for (int i = 0, j = 0; i < n; ++i) {
        if (i > j) swap(a[i], a[j]);
        for (int l = n >> 1; (j ^= l) < l; l >>= 1);
    }
    for (int l = 1, ll = 2; l < n; l <<= 1, ll <<= 1) {
        double x = oper * pi / l;
        C omega = 1, omegan(cos(x), sin(x));
        for (int k = 0; k < l; ++k, omega *= omegan) {
            for (int st = k; st < n; st += ll) {
                C tmp = omega * a[st + l];
                a[st + l] = a[st] - tmp;
                a[st] += tmp;
            }
        }
    }
    if (oper == -1) for (int i = 0; i < n; ++i) a[i] /= n;
}

vd operator * (const vd &v1, const vd &v2) {
    int s = 1, ss = (int)v1.size() + (int)v2.size();
    while (s < ss) s <<= 1;
    vc a(s, 0), b(s, 0);
    for (int i = 0; i < (int)v1.size(); ++i) a[i] = v1[i];
    for (int i = 0; i < (int)v2.size(); ++i) b[i] = v2[i];
    FFT(a), FFT(b);
    for (int i = 0; i < s; ++i) a[i] *= b[i];
    FFT(a, -1);
    vd res(s);
    for (int i = 0; i < s; ++i) res[i] = round(a[i].real());
    return res;
}

int main() {
    n = readInt(), m = readInt();
    int sX = 0, sY = 0, sqrA = 0, sqrB = 0;
    for (int i = 0; i < n; ++i) sX += x[i] = readInt(), sqrA += x[i] * x[i];
    for (int i = 0; i < n; ++i) sY += y[i] = readInt(), sqrB += y[i] * y[i];

    int v1 = floor((double)(sX - sY) / -n), v2 = ceil((double)(sX - sY) / -n);
    ll ans = min(n * v1 * v1 + 2 * (sX - sY) * v1, n * v2 * v2 + 2 * (sX - sY) * v2) + sqrA + sqrB;

    vector<double> vec1, vec2;
    for (int i = 0; i < n; ++i) vec1.push_back(x[n - i - 1]);
    for (int i = 0; i < n * 2; ++i) vec2.push_back(y[i % n]);

    vector<double> res = vec1 * vec2;

    ans -= *max_element(res.begin() + n - 1, res.begin() + n - 1 + n) * 2;

    printf("%lld\n", ans);

    return 0;
}