UVa 11992 - Fast Matrix Operations(线段树模板)

Published on 2016-01-06

题目地址

描述

给一个总元素个数为 n(n1000000)n(n\le1000000) 的矩阵,矩阵长度为 r(r20)r(r\le20),宽度为 cc,矩阵初始值全为整数 00。定义一个子矩阵 (x1,y1,x2,y2)(x_{1}, y_{1}, x_{2}, y_{2})(包括边界),操作有三种:

  1. 将子矩阵的每个元素加上一个值 v(v>0)v(v > 0)
  2. 将子矩阵的每个元素赋值为 v(v0)v(v \ge 0)
  3. 查询子矩阵所有元素的和,最小值,最大值。

样例输入

4 4 8
1 1 2 4 4 5
3 2 1 4 4
1 1 1 3 4 2
3 1 2 4 4
3 1 1 3 4
2 2 1 4 4 2
3 1 2 4 4
1 1 1 4 3 3

样例输出

45 0 5
78 5 7
69 2 7
39 2 7

附加输入

1 2 8
1 1 1 1 2 1
3 1 1 1 2
1 1 1 1 1 1
3 1 1 1 2
2 1 2 1 2 3
3 1 1 1 2
1 1 1 1 2 3
3 1 1 1 2

附加输出

2 1 1
3 1 2
5 2 3
11 5 6

分析

因为行数不超过二十行,可以为每行建立一个线段树,则二维操作可由一维合并而来。
本题是线段树基本操作的汇总,同时有 setsetaddadd 操作,对于它们的考虑顺序是极为重要的,建议在头脑清醒的时候看,不然容易对线段树失去兴趣 -_-。
首先规定,若一个节点上同时有 setsetaddadd 标记,先执行 setset,再执行 addadd,这个的重要性在后面就会体现。
讨论执行 addadd 操作过程中会遇到的事:

  • 若区间不相交,直接退出。
  • 若待加区间完全包含当前区间,无论有没有 setset 标记(因为 setset 标记优先考虑),直接将此节点的 addadd 标记增加。
  • 若无任何标记,一切照常。
  • 若已有 addadd 标记,不必理会,因为 addadd 不存在相互影响的关系(但为了方便,我们还是传递到子树)。
  • 若有 setset 标记,将 setset 标记传递到子树,同时更新子树以及自己的信息。

接着讨论将节点 vvsetset 标记传递给儿子过程中会发生的事情:

  • 若儿子有 addadd 标记,清除其 addadd 标记:因为儿子的区间已经被强制 setsetaddadd 操作必然是过时的。
  • 若自己有 addadd 标记,将自己的 addadd 标记传递给儿子,自己的 addadd 标记清零(这一点接下来会说明)。

讨论执行 setset 操作过程中会遇到的事:

  • 若区间不相交,直接退出。
  • 若待加区间完全包含当前区间,将此节点的 addadd 标记清空,直接赋值 setset 标记,不管以前有没有 setset 标记。
  • 若无任何标记,一切照常。
  • 若已有 setset 标记,将 setset 标记传递到子树,同时更新子树以及自己的信息。
  • 若已有 addadd 标记,将 addadd 标记传递到子树,同时更新子树以及自己的信息。

总结一下,对于 setset 操作要清除 addadd 标记,对于 addadd 操作不清除 setset 标记,两种操作都要传递标签到子树,查询时要按照先 setsetaddadd 的顺序考虑。细节见代码。
很麻烦吧,自己用笔推一推,就会发现这是很自然的事情,只有自己理解了,线段树才写的对啊。

代码

//  Created by Sengxian on 1/6/16.
//  Copyright (c) 2015年 Sengxian. All rights reserved.
//  UVa 11992 线段树所有基本操作
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <deque>
#include <climits>
#include <cassert>
#include <set>
using namespace std;
const int maxr = 20 + 3;

inline void tension(int &a, const int b) {
    if(b < a) a = b;
}

inline void relax(int &a, const int b) {
    if(b > a) a = b;
}

struct SegmentTree {
    static const int maxNode = (1 << 20) * 2 + 100;
    int sum[maxNode], minValue[maxNode], maxValue[maxNode], addv[maxNode], setv[maxNode], n;
    void init(int _n) {
        n = 1;
        while(n < _n) n *= 2;
        for(int i = 0; i < n * 2; ++i) {
            sum[i] = minValue[i] = maxValue[i] = 0;
            setv[i] = -1; addv[i] = 0;
        }
    }
    int a, b, v;
    void pushdown(int k) {
        if(setv[k] >= 0) {
            setv[k * 2 + 1] = setv[k * 2 + 2] = setv[k];
            addv[k * 2 + 1] = addv[k * 2 + 2] = 0;
            setv[k] = -1;
        }
        if(addv[k] > 0) {
            addv[k * 2 + 1] += addv[k];
            addv[k * 2 + 2] += addv[k];
            addv[k] = 0;
        }
    }
    void maintain(int k, int l, int r) {
        sum[k] = minValue[k] = maxValue[k] = 0;
        if(setv[k] >= 0) {
            sum[k] = setv[k] * (r - l);
            minValue[k] = maxValue[k] = setv[k];
        }else if(r - l > 1) {
            sum[k] = sum[k * 2 + 1] + sum[k * 2 + 2];
            minValue[k] = min(minValue[k * 2 + 1], minValue[k * 2 + 2]);
            maxValue[k] = max(maxValue[k * 2 + 1], maxValue[k * 2 + 2]);
        }
        if(addv[k] > 0) sum[k] += addv[k] * (r - l); minValue[k] += addv[k]; maxValue[k] += addv[k];
    }
    void add(int k, int l, int r) {
        int mid = (l + r) / 2;
        if(r <= a || l >= b) return; //have no intersection
        if(l >= a && r <= b)
            addv[k] += v;
        else {
            pushdown(k);
            if(mid > a) add(k * 2 + 1, l, mid); else maintain(k * 2 + 1, l, mid); //left
            if(mid < b) add(k * 2 + 2, mid, r); else maintain(k * 2 + 2, mid, r); //right
        }
        maintain(k, l, r);
    }
    void Add(int l, int r, int _v) {
        a = l, b = r, v = _v;
        add(0, 0, n);
    }
    void set(int k, int l, int r) {
        int mid = (l + r) / 2;
        if(r <= a || l >= b) return;
        if(l >= a && r <= b) {
            addv[k] = 0;
            setv[k] = v;
        }else {
            pushdown(k);
            if(mid > a) set(k * 2 + 1, l, mid); else maintain(k * 2 + 1, l, mid);
            if(mid < b) set(k * 2 + 2, mid, r); else maintain(k * 2 + 2, mid, r);
        }
        maintain(k, l, r);
    }
    void Set(int l, int r, int _v) {
        a = l, b = r, v = _v;
        set(0, 0, n);
    }
    int _sum, _min, _max;
    void query(int k, int l, int r, int add) {

        int mid = (l + r) / 2;
        if(r <= a || l >= b) return;
        else if(setv[k] >= 0) {
            int val = add + addv[k] + setv[k];
            _sum += val * (min(r, b) - max(l, a));
            tension(_min, val);
            relax(_max, val);
        }else if(l >= a && r <= b) {
            //no set effect and totally in
            _sum += (r - l) * add + sum[k];
            tension(_min, minValue[k] + add);
            relax(_max, maxValue[k] + add);
        }else {
            query(k * 2 + 1, l, mid, add + addv[k]);
            query(k * 2 + 2, mid, r, add + addv[k]);
        }
    }
    void Query(int l, int r) {
        a = l, b = r;
        _sum = 0;
        _min = INT_MAX;
        _max = INT_MIN;
        query(0, 0, n, 0);
    }
    int Sum(int l, int r) {
        Query(l, r);
        return _sum;
    }
    int Min(int l, int r) {
        Query(l, r);
        return _min;
    }
    int Max(int l, int r) {
        Query(l, r);
        return _max;
    }
}Solver[maxr];

int r, c, m;

inline int ReadInt() {
    int x;
    scanf("%d", &x);
    return x;
}

int main() {
    while(scanf("%d%d%d", &r, &c, &m) == 3) {
        for(int i = 0; i < r; ++i) Solver[i].init(c);
        for(int i = 0; i < m; ++i) {
            int type = ReadInt();
            int x1 = ReadInt() - 1, y1 = ReadInt() - 1, x2 = ReadInt(), y2 = ReadInt();
            if(type == 1) {
                int v = ReadInt();
                for(int j = x1; j < x2; ++j) {
                    Solver[j].Add(y1, y2, v);
                }
            }else if(type == 2) {
                int v = ReadInt();
                for(int j = x1; j < x2; ++j)
                    Solver[j].Set(y1, y2, v);
            }else if(type == 3) {
                int sum = 0, _min = INT_MAX, _max = INT_MIN;
                for(int j = x1; j < x2; ++j) {
                    Solver[j].Query(y1, y2);
                    sum += Solver[j]._sum;
                    tension(_min, Solver[j]._min);
                    relax(_max, Solver[j]._max);
                }
                printf("%d %d %d\n", sum, _min, _max);
            }else assert(false);
        }
    }
    return 0;
}