UVa 10827 - Maximum sum on a torus

Published on 2016-10-19

题目地址

描述

给出一个 n×n(n75)n\times n(n\le 75) 的矩阵,把第一行和最后一行粘一起,把第一列和最后一列粘一起,形成一个环面,求出这个环面中最大的子矩阵和。

分析

关于最大子矩阵和,可以枚举子矩阵的两行,然后将两行中间,相同列的数加起来,这样就变成了一行数,在这一行中求解最大子段和就行了。而最大子段和,可以说是基础算法了,过程如下:

int ans = INT_MIN, tmp = 0;
for (int i = 0; i < n; ++i) {
    relax(ans, tmp + a[i]);
    tmp = max(tmp + a[i], 0);
}

核心思想是,对于每个元素 aia_i,计算以它结尾的最大子段和的值。为了做到这个,我们需要求以 ai1a_{i - 1} 结尾的最大子段和(或者什么都不选),而这个值存在于 fi1f_{i - 1} 中(算法中直接滚动了 ff)。显然求出最大的 ff,算法就是正确的,我们用归纳法证明我们的算法能够求出最大的 ff

若序列从 00 开始,对于 f1=0f_{-1} = 0,所以归纳法的基础是正确的。
不妨设 fi1,i1f_{i - 1}, i \ge 1 是最大的,则已经求出了以 ai1a_{i - 1} 结尾的最大子段和,若想扩展到以 aia_i 结尾,那么要么选上 aia_i,要么干脆什么都不选,所以 fi=max(fi1+ai,0)f_i = \max(f_{i - 1} + a_i, 0),由于我们只关心 fif_i 的值而不关心到底怎么选的,所以这一步一定也是最优的,所以这个算法是正确的。

而此题这个矩阵是环面的,直观的想法是把它对称成 4 个,然后求新的矩阵的最大子矩阵(大小不超过 n×nn\times n),而正因为有大小的限制,复杂度仍然为 O(n4)O(n^4)Codeforces 724C - Ray Tracing 中的思路启发了我们,两边都对称,反而没办法做,所以我们只对称一边,即将矩形翻下去。

我们枚举两行(由于有 2n2n 行,枚举的时候两行的距离不能超过 nn),用刚刚讲的方法求两行之间的最大子矩阵,然后还需要处理环面的情况,这个预处理一下前缀最大值和和后缀最大值即可求解。

代码

//  Created by Sengxian on 2016/10/18.
//  Copyright (c) 2016年 Sengxian. All rights reserved.
//  UVa 10827 最大子矩阵
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
inline int ReadInt() {
    static int n, ch;
    static bool flag;
    n = 0, ch = getchar(), flag = false;
    while (!isdigit(ch)) flag |= ch == '-', ch = getchar();
    while (isdigit(ch)) n = n * 10 + ch - '0', ch = getchar();
    return flag ? -n : n;
}

const int MAX_N = 75 + 3;
int n, grid[MAX_N * 2][MAX_N], sum[MAX_N][MAX_N * 2];

inline void relax(int &a, const int &b) { if (b > a) a = b; }

inline int solve(int s, int t) {
    int ans = INT_MIN, tmp = 0;
    static int a[MAX_N];
    for (int i = 0; i < n; ++i) a[i] = sum[i][t] - sum[i][s];
    //不跨越
    for (int i = 0; i < n; ++i) {
        relax(ans, tmp + a[i]);
        tmp = max(tmp + a[i], 0);
    }
    //跨越
    static int ss[MAX_N], pre[MAX_N], post[MAX_N];
    for (int i = 0; i < n; ++i) ss[i + 1] = ss[i] + a[i];
    for (int i = 0; i < n; ++i) {
        pre[i] = ss[i + 1];
        if (i) relax(pre[i], pre[i - 1]);
    }
    for (int i = n - 1; i >= 0; --i) {
        post[i] = ss[n] - ss[i];
        if (i + 1 != n) relax(post[i], post[i + 1]);
    }
    for (int i = 0; i < n - 1; ++i) relax(ans, pre[i] + post[i + 1]);
    return ans;
}

int main() {
    int caseNum = ReadInt();
    while (caseNum--) {
        n = ReadInt();
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                grid[i][j] = grid[i + n][j] = ReadInt();
        for (int j = 0; j < n; ++j)
            for (int i = 0; i < 2 * n; ++i)
                sum[j][i + 1] = sum[j][i] + grid[i][j];
        int ans = INT_MIN;
        for (int i = 0; i < n; ++i)
            for (int j = i + 1; j - i <= n && j <= 2 * n; ++j)
                relax(ans, solve(i, j));
        printf("%d\n", ans);
    }
    return 0;
}