UVa 11992 - Fast Matrix Operations(线段树模板)
Published on 2016-01-06描述
给一个总元素个数为 的矩阵,矩阵长度为 ,宽度为 ,矩阵初始值全为整数 。定义一个子矩阵 (包括边界),操作有三种:
- 将子矩阵的每个元素加上一个值 。
- 将子矩阵的每个元素赋值为 。
- 查询子矩阵所有元素的和,最小值,最大值。
样例输入
4 4 8 1 1 2 4 4 5 3 2 1 4 4 1 1 1 3 4 2 3 1 2 4 4 3 1 1 3 4 2 2 1 4 4 2 3 1 2 4 4 1 1 1 4 3 3
样例输出
45 0 5 78 5 7 69 2 7 39 2 7
附加输入
1 2 8 1 1 1 1 2 1 3 1 1 1 2 1 1 1 1 1 1 3 1 1 1 2 2 1 2 1 2 3 3 1 1 1 2 1 1 1 1 2 3 3 1 1 1 2
附加输出
2 1 1 3 1 2 5 2 3 11 5 6
分析
因为行数不超过二十行,可以为每行建立一个线段树,则二维操作可由一维合并而来。
本题是线段树基本操作的汇总,同时有 和 操作,对于它们的考虑顺序是极为重要的,建议在头脑清醒的时候看,不然容易对线段树失去兴趣 -_-。
首先规定,若一个节点上同时有 和 标记,先执行 ,再执行 ,这个的重要性在后面就会体现。
讨论执行 操作过程中会遇到的事:
- 若区间不相交,直接退出。
- 若待加区间完全包含当前区间,无论有没有 标记(因为 标记优先考虑),直接将此节点的 标记增加。
- 若无任何标记,一切照常。
- 若已有 标记,不必理会,因为 不存在相互影响的关系(但为了方便,我们还是传递到子树)。
- 若有 标记,将 标记传递到子树,同时更新子树以及自己的信息。
接着讨论将节点 的 标记传递给儿子过程中会发生的事情:
- 若儿子有 标记,清除其 标记:因为儿子的区间已经被强制 , 操作必然是过时的。
- 若自己有 标记,将自己的 标记传递给儿子,自己的 标记清零(这一点接下来会说明)。
讨论执行 操作过程中会遇到的事:
- 若区间不相交,直接退出。
- 若待加区间完全包含当前区间,将此节点的 标记清空,直接赋值 标记,不管以前有没有 标记。
- 若无任何标记,一切照常。
- 若已有 标记,将 标记传递到子树,同时更新子树以及自己的信息。
- 若已有 标记,将 标记传递到子树,同时更新子树以及自己的信息。
总结一下,对于 操作要清除 标记,对于 操作不清除 标记,两种操作都要传递标签到子树,查询时要按照先 后 的顺序考虑。细节见代码。
很麻烦吧,自己用笔推一推,就会发现这是很自然的事情,只有自己理解了,线段树才写的对啊。
代码
// Created by Sengxian on 1/6/16. // Copyright (c) 2015年 Sengxian. All rights reserved. // UVa 11992 线段树所有基本操作 #include <algorithm> #include <iostream> #include <cstring> #include <cstdio> #include <deque> #include <climits> #include <cassert> #include <set> using namespace std; const int maxr = 20 + 3; inline void tension(int &a, const int b) { if(b < a) a = b; } inline void relax(int &a, const int b) { if(b > a) a = b; } struct SegmentTree { static const int maxNode = (1 << 20) * 2 + 100; int sum[maxNode], minValue[maxNode], maxValue[maxNode], addv[maxNode], setv[maxNode], n; void init(int _n) { n = 1; while(n < _n) n *= 2; for(int i = 0; i < n * 2; ++i) { sum[i] = minValue[i] = maxValue[i] = 0; setv[i] = -1; addv[i] = 0; } } int a, b, v; void pushdown(int k) { if(setv[k] >= 0) { setv[k * 2 + 1] = setv[k * 2 + 2] = setv[k]; addv[k * 2 + 1] = addv[k * 2 + 2] = 0; setv[k] = -1; } if(addv[k] > 0) { addv[k * 2 + 1] += addv[k]; addv[k * 2 + 2] += addv[k]; addv[k] = 0; } } void maintain(int k, int l, int r) { sum[k] = minValue[k] = maxValue[k] = 0; if(setv[k] >= 0) { sum[k] = setv[k] * (r - l); minValue[k] = maxValue[k] = setv[k]; }else if(r - l > 1) { sum[k] = sum[k * 2 + 1] + sum[k * 2 + 2]; minValue[k] = min(minValue[k * 2 + 1], minValue[k * 2 + 2]); maxValue[k] = max(maxValue[k * 2 + 1], maxValue[k * 2 + 2]); } if(addv[k] > 0) sum[k] += addv[k] * (r - l); minValue[k] += addv[k]; maxValue[k] += addv[k]; } void add(int k, int l, int r) { int mid = (l + r) / 2; if(r <= a || l >= b) return; //have no intersection if(l >= a && r <= b) addv[k] += v; else { pushdown(k); if(mid > a) add(k * 2 + 1, l, mid); else maintain(k * 2 + 1, l, mid); //left if(mid < b) add(k * 2 + 2, mid, r); else maintain(k * 2 + 2, mid, r); //right } maintain(k, l, r); } void Add(int l, int r, int _v) { a = l, b = r, v = _v; add(0, 0, n); } void set(int k, int l, int r) { int mid = (l + r) / 2; if(r <= a || l >= b) return; if(l >= a && r <= b) { addv[k] = 0; setv[k] = v; }else { pushdown(k); if(mid > a) set(k * 2 + 1, l, mid); else maintain(k * 2 + 1, l, mid); if(mid < b) set(k * 2 + 2, mid, r); else maintain(k * 2 + 2, mid, r); } maintain(k, l, r); } void Set(int l, int r, int _v) { a = l, b = r, v = _v; set(0, 0, n); } int _sum, _min, _max; void query(int k, int l, int r, int add) { int mid = (l + r) / 2; if(r <= a || l >= b) return; else if(setv[k] >= 0) { int val = add + addv[k] + setv[k]; _sum += val * (min(r, b) - max(l, a)); tension(_min, val); relax(_max, val); }else if(l >= a && r <= b) { //no set effect and totally in _sum += (r - l) * add + sum[k]; tension(_min, minValue[k] + add); relax(_max, maxValue[k] + add); }else { query(k * 2 + 1, l, mid, add + addv[k]); query(k * 2 + 2, mid, r, add + addv[k]); } } void Query(int l, int r) { a = l, b = r; _sum = 0; _min = INT_MAX; _max = INT_MIN; query(0, 0, n, 0); } int Sum(int l, int r) { Query(l, r); return _sum; } int Min(int l, int r) { Query(l, r); return _min; } int Max(int l, int r) { Query(l, r); return _max; } }Solver[maxr]; int r, c, m; inline int ReadInt() { int x; scanf("%d", &x); return x; } int main() { while(scanf("%d%d%d", &r, &c, &m) == 3) { for(int i = 0; i < r; ++i) Solver[i].init(c); for(int i = 0; i < m; ++i) { int type = ReadInt(); int x1 = ReadInt() - 1, y1 = ReadInt() - 1, x2 = ReadInt(), y2 = ReadInt(); if(type == 1) { int v = ReadInt(); for(int j = x1; j < x2; ++j) { Solver[j].Add(y1, y2, v); } }else if(type == 2) { int v = ReadInt(); for(int j = x1; j < x2; ++j) Solver[j].Set(y1, y2, v); }else if(type == 3) { int sum = 0, _min = INT_MAX, _max = INT_MIN; for(int j = x1; j < x2; ++j) { Solver[j].Query(y1, y2); sum += Solver[j]._sum; tension(_min, Solver[j]._min); relax(_max, Solver[j]._max); } printf("%d %d %d\n", sum, _min, _max); }else assert(false); } } return 0; }