0%

4405.统计子矩阵

4405.统计子矩阵

题目描述:

给出一个 $ N \times M $ 的矩阵 $ A $ ,统计多少个子矩阵中所有数的和不超过给定的整数 $ K $

数据范围:
$ 1\le N,M \le 500 \\ 0 \le A_{ij} \le 10^3 \\ 1\le K \le 2.5 \times 10^8 $

题解:

很容易想到直接二维前缀和暴力, $ O(N^4) $ 复杂度太高。

可以优化到 $ O(N^3) $

很多类似的问题都可以使用这种做法,直接枚举子矩阵的左右边界,然后就变成了一个竖着的一维数组,使用滑动窗口滑动保持窗口内部总和 $ <=k $ ,然后直接计数就行。

滑动窗口:首先窗口大小为1或0,然后随着遍历元素,移动窗口尾部,扩大窗口。中间需要加上移动窗口尾部或者头部的while循环操作。满足某种条件就一直移动,注意边界,可以使用单调队列之类的。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
#define ll long long
#define lll long long
#define PII pair<int, int>
namespace FAST_IO
{

inline char nextChar()
{
static char buf[1000000], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1000000, stdin), p1 == p2) ? EOF : *p1++;
}
#define getch getchar
template <class T>
inline void read(T &x)
{
T flag = 1;
x = 0;
char ch = getch();
while (ch < '0' || ch > '9')
{
if (ch == '-')
flag = -1;
ch = getch();
}
while (ch >= '0' && ch <= '9')
{
x = (x << 3) + (x << 1) + (ch ^ 48), ch = getch();
}
x *= flag;
}

template <class T, class... _T>
inline void read(T &x, _T &...y)
{
return read(x), read(y...);
}

inline void print128(lll x)
{
if (x < 0)
putchar('-'), x = -x;
if (x > 9)
print128(x / 10);
putchar(x % 10 + '0');
}

} // namespace FAST_IO
using namespace std;
using namespace FAST_IO;
const ll mod = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const ll INF_LL = 0x3f3f3f3f3f3f3f3f;
const double eps = 1e-5;
const int maxn = 5e2 + 10;
const int maxm = 1e5 + 10;
ll t, n, m, k;
ll a[maxn][maxn];
int main()
{
// #define COMP_DATA
#ifndef ONLINE_JUDGE
freopen("in.txt", "r", stdin);
#endif
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m >> k;
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= m; j++)
{
cin >> a[i][j];
a[i][j] = a[i - 1][j] + a[i][j - 1] + a[i][j] - a[i - 1][j - 1];
}
}
int ans = 0;
for (int i = 1; i <= m; i++)
{
for (int j = i; j <= m; j++)
{
for (int p = 1, q = 1; q <= n; ++q)
{
while (p <= q && a[q][j] - a[p - 1][j] - a[q][i - 1] + a[p - 1][i - 1] > k)
{
p++;
}
if (p <= q)
ans += q - p + 1;
}
}
}
cout << ans << endl;
return 0;
}