10167 - 금광
플레인스위핑 + 세그먼트 트리 + maximum sub-array(maximum sub-sequence sum) 문제
O(N^3)
maximum sub-array 문제를 그리디하게 해결 할 경우 O(N) 이지만 (누적이 음수이면 배제)
y1, y2 에 대해서 각각 O(N) 을 수행하면 O(N^3) 이라서 문제를 해결 할 수 없다.
O(N^2*logN)
- maximum subarray 문제를 분할정복으로 푸는 방법은 O(NlogN) 이다.
- 그러나 세그먼트 트리를 이용해서 값을 갱신하게 되면, query는 O(1), update는 O(logN) 이다.
- y1, y2를 지정하는데 각각 O(N) 이고 query/update 수행 시간은 O(logN) 이기 때문에
- 결과적으로 O(N^2 * logN) 이 된다.
maximum subarray 문제를 분할정복법으로 풀기 + 세그먼트 트리 적용
- 세그먼트 트리를 이용해서 분할 정복을 한다.
- 일단 중앙값(mid) 기준으로 좌우로 나눈다고 가정할 때, 최대의 구간합은
- max(좌측 최대 구간합, 우측 최대 구간합, 걸쳐있는 구간의 최대합) 이라고 생각할 수 있다.
- 이를 O(1) 로 구하기 위해서는 다음과 같은 쿼리를 필요로 한다
- tsum(i, j) : i~j 구간의 합
- lsum(i~j) : i~j 일 때, i에서 시작한 좌측 구간에서의 최대합
- rsum(i~j) : i~j 일 때, j에서 시작한 우측 구간에서의 최대합
- maxsum(i~j) : i~j 일 때, 최대 구간 합 (실제 구하려는 값)
- 쿼리를 만들기 위해 만들어야 하는 트리와 수식은 아래와 같다
- tsum(i, j) : tsum(i, mid) + tsum(mid+1, j)
- lsum(i, j) : max(lsum(i, mid), tsum(i, mid) + lsum(mid+1, j))
- lsum(i, j) = max(좌측 자식의 lsum 값, 좌측의 구간합 + 우측의 좌측 구간 최대 합)
- rsum(i, j) : max(rsum(mid+1, j), tsum(mid+1, j) + lsum(i, mid))
- rsum(i, j) = max(우측 자식의 rsum 값, 우측의 구간합 + 좌측의 우측 구간 최대 합)
- maxsum(i, j) : max(maxsum(i, mid), maxsum(mid+1, j), rsum(1, mid) + lsum(mid + 1, j))
- maxsum(i, j) = max(좌측 자식 maxsum 값, 우측 자식 maxsum 값, 좌측 자식의 우측 구간 최대 + 우측 자식의 좌측 구간 최대)
maximum subarray 문제를 단편적으로 그리디하게 해결하면 O(N) 이 가장 빠르지만, 매번 갱신이라는 개념이 추가된다면 각각 O(N) 번 수행하는 것 보다. O(N) 에 대해서 O(logN) 수행하는 세그먼트 트리가 더 빠르다는 것이다.
하지만 구현 역시 만만치 않다.
#include <stdio.h>
#include <vector>
#include <algorithm>
#include <tuple>
#define SIZE (3000 << 2)
using namespace std;
typedef long long int lld;
typedef tuple<int, int, int> G;
lld tsum[SIZE];
lld lsum[SIZE];
lld rsum[SIZE];
lld maxsum[SIZE];
vector<G> g;
vector<int> yy;
lld query(int l, int r, int node, int lo, int hi) {
if (r < lo || hi < l)
return 0;
if (l <= lo && hi <= r)
return maxsum[node];
int mid = (l + r) / 2;
return max(query(l, mid, node * 2, lo, hi), query(mid + 1, r, node * 2 + 1, lo, hi));
}
void update(int l, int r, int node, int pos, int val) {
if (r < pos || pos < l)
return;
if (l == r) {
tsum[node] += val;
lsum[node] += val;
rsum[node] += val;
maxsum[node] += val;
return;
}
int mid = (l + r) / 2;
update(l, mid, node * 2, pos, val);
update(mid + 1, r, node * 2 + 1, pos, val);
tsum[node] = tsum[node * 2] + tsum[node * 2 + 1];
lsum[node] = max(lsum[node * 2], tsum[node * 2] + lsum[node * 2 + 1]);
rsum[node] = max(rsum[node * 2 + 1], tsum[node * 2 + 1] + rsum[node * 2]);
maxsum[node] = max(max(maxsum[node * 2], maxsum[node * 2 + 1]), rsum[node * 2] + lsum[node * 2 + 1]);
}
int main()
{
int n, i, x, y, w, j, px;
lld ans = -1e9 - 1;
scanf("%d", &n);
for (i = 0; i < n; i++) {
scanf("%d%d%d", &x, &y, &w);
g.push_back(G(x, y, w));
yy.push_back(y);
}
sort(g.begin(), g.end());
sort(yy.begin(), yy.end());
for (i = 0; i < n; i++) {
while (i + 1 < n && get<0>(g[i + 1]) == get<0>(g[i]))
i++;
x = get<0>(g[i]);
fill(tsum, tsum + SIZE, 0);
fill(lsum, lsum + SIZE, 0);
fill(rsum, rsum + SIZE, 0);
fill(maxsum, maxsum + SIZE, 0);
for (j = i; j >= 0; j--) {
if (i == n - 1)
i = i;
int x2 = get<0>(g[j]), y2 = get<1>(g[j]), w2 = get<2>(g[j]);
int pos = lower_bound(yy.begin(), yy.end(), y2) - yy.begin();
update(0, n - 1, 1, pos, w2);
if (j - 1 >= 0 && get<0>(g[j - 1]) == get<0>(g[j]))
continue;
lld t = query(0, n - 1, 1, 0, n - 1);
ans = max(ans, t);
}
}
printf("%lld", ans);
return 0;
}