QTREE 2 - Query on a tree II

Published on 2016-03-16

题目地址:https://www.spoj.com/problems/QTREE2/

描述

给你一棵有 $n\ (n\le 10000)$ 个节点的带边权树,现在要你支持两种操作:

  • DIST a b 询问从 $a$ 到 $b$ 的路径长度(路径上的边权之和)。
  • KTH a b c 询问从 $a$ 到 $b$ 的路径上第 $c$ 个点的编号(第 $1$ 个点即 $a$)。

分析

第二题算是比较简单的了,注意到没有修改操作,直接上树上倍增即可。
就是 KTH 的时候比较纠结,需要多次 swap,不过还是能轻松搞定的。

代码

//  Created by Sengxian on 3/16/16.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//    Spoj QTREE II
#include <algorithm>
#include <iostream>
#include <cassert>
#include <climits>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
using namespace std;

// Reads the next unsigned decimal integer from stdin.
// Any non-digit characters before it are skipped; a '-' sign is skipped too, so
// negative numbers are NOT supported (input values are assumed non-negative).
// Returns 0 if EOF is reached before any digit. Without the EOF check the skip
// loop spun forever, because getchar() keeps returning EOF.
inline int ReadInt() {
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) ch = getchar();
    int value = 0;  // named 'value' so it does not shadow the global vertex count 'n'
    while (ch != EOF && isdigit(ch)) {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value;
}

// Capacity of the per-vertex arrays: up to 10000 vertices plus a little slack.
const int maxn = 10000 + 3;

// One adjacency-list entry: the neighbouring vertex and the weight of the edge.
struct edge {
    int to;
    int cost;
    edge(int target, int weight) : to(target), cost(weight) {}
};

// Adjacency lists of the current tree (vertices are 0-based).
std::vector<edge> G[maxn];

// Number of vertices in the current test case.
int n = 0;
// Binary-lifting tables with levels 0..14 (2^14 = 16384 exceeds any depth here):
// ancestor[v][w] is the 2^w-th ancestor of v, dist[v][w] the summed weight to it.
int ancestor[maxn][14 + 1];
int dist[maxn][14 + 1];
// dep[v]: number of edges between v and the root.
int dep[maxn];

void dfs(int u, int fa) {
    ancestor[u][0] = fa, dep[u] = fa == -1 ? 0 : dep[fa] + 1;
    for (int i = 0; i < (int)G[u].size(); ++i) {
        edge &e = G[u][i];
        if (e.to != fa) {
            dist[e.to][0] = e.cost;
            dfs(e.to, u);
        }
    }
}

// Roots the tree at vertex 0 and fills the binary-lifting tables by doubling:
// the 2^w-th ancestor of i is the 2^(w-1)-th ancestor of its 2^(w-1)-th ancestor,
// and the distances add up the same way.
// Entry [i][w] is written only when dep[i] >= 2^w, i.e. when that ancestor exists.
// Other entries may still hold values from a previous test case, so readers must
// only use ancestor[i][w] / dist[i][w] when dep[i] >= 2^w.
void process() {
    dfs(0, -1);
    for (int w = 1; (1 << w) < n; ++w) {
        const int jump = 1 << w;
        for (int i = 0; i < n; ++i) {
            if (dep[i] < jump) continue;
            const int half = ancestor[i][w - 1];  // the 2^(w-1)-th ancestor of i
            ancestor[i][w] = ancestor[half][w - 1];
            dist[i][w] = dist[i][w - 1] + dist[half][w - 1];
        }
    }
}


int cnt = 0, cntA = 0;
int queryDist(int a, int b) {
    if (dep[a] < dep[b]) swap(a, b);
    int maxBit = log2(a + 0.5), dis = 0;
    cnt = 0, cntA = 0;
    for (int i = maxBit; i >= 0; --i)
        if (dep[a] - (1 << i) >= dep[b]) {
            dis += dist[a][i], cnt += 1 << i, cntA += 1 << i;
            a = ancestor[a][i];
        }
    if (a == b) return dis;
    for (int i = maxBit; i >= 0; --i)
        if (dep[a] - (1 << i) >= 0 && ancestor[a][i] != ancestor[b][i]) {
            dis += dist[a][i] + dist[b][i], cnt += 1 << (i + 1), cntA += 1 << i;
            a = ancestor[a][i], b = ancestor[b][i];
        }
    dis += dist[a][0] + dist[b][0], cnt += 2, cntA++;
    return dis;
}

// Returns the k-th vertex (1-based) on the path from a to b; k == 1 yields a.
// Assumes 1 <= k <= (number of vertices on the path). The program appears to rely
// on the input guaranteeing this; invalid k is not checked.
// Calls queryDist() only for its side effects on cnt / cntA.
int queryKth(int a, int b, int k) {
    if (k == 1) return a;
    int steps = k - 1;  // number of edges to walk away from a

    queryDist(a, b);

    // queryDist measured cntA from the deeper endpoint, so make a the deeper one.
    // Walking `steps` edges from the old a is the same as walking cnt - steps
    // edges from the other end.
    if (dep[a] < dep[b]) {
        swap(a, b);
        steps = cnt - steps;
    }
    // If the target lies beyond the LCA, it is on b's side. Restart from b so that
    // the walk is a pure climb towards the root.
    if (steps > cntA) {
        swap(a, b);
        steps = cnt - steps;
    }

    // Climb exactly `steps` edges by consuming the binary digits of steps from
    // the lowest bit upward.
    for (int bit = 0; steps > 0; ++bit, steps >>= 1)
        if (steps & 1) a = ancestor[a][bit];
    return a;
}

// Buffer for the operation name; the "%9s" below keeps every read within its 10 bytes.
char op[10];

// Reads the number of test cases. For each case it reads:
//   - the tree: n, then n-1 weighted edges "a b c" with 1-based vertices;
//   - DIST / KTH queries, until a token whose second character is 'O' (e.g. "DONE").
// Answers are printed one per line, followed by an empty line after each case.
int main() {
    int caseNum = ReadInt();
    while (caseNum--) {
        n = ReadInt();
        for (int i = 0; i < n; ++i) G[i].clear();
        for (int i = 0; i < n - 1; ++i) {
            const int f = ReadInt() - 1;
            const int t = ReadInt() - 1;
            const int c = ReadInt();
            G[f].push_back(edge(t, c));
            G[t].push_back(edge(f, c));
        }
        process();
        // Field width 9 prevents a long token from overflowing op.
        while (scanf("%9s", op) == 1 && op[1] != 'O') {
            if (op[0] == 'D') {
                // Read into locals: the evaluation order of function arguments is
                // unspecified, so two ReadInt() calls inside one argument list
                // may run in either order.
                const int a = ReadInt() - 1;
                const int b = ReadInt() - 1;
                printf("%d\n", queryDist(a, b));
            } else if (op[0] == 'K') {
                const int a = ReadInt() - 1;
                const int b = ReadInt() - 1;
                const int k = ReadInt();
                printf("%d\n", queryKth(a, b, k) + 1);
            } else {
                assert(false);  // unknown operation
            }
        }
        putchar('\n');
    }
    return 0;
}