10167 - 금광

플레인스위핑 + 세그먼트 트리 + maximum sub-array(maximum sub-sequence sum) 문제

O(N^3)

  • maximum sub-array 문제를 그리디하게 해결 할 경우 O(N) 이지만 (누적이 음수이면 배제)

  • y1, y2 에 대해서 각각 O(N) 을 수행하면 O(N^3) 이라서 문제를 해결 할 수 없다.

O(N^2*logN)

  • maximum subarray 문제를 분할정복으로 푸는 방법은 O(NlogN) 이다.
  • 그러나 세그먼트 트리를 이용해서 값을 갱신하게 되면, query는 O(1), update는 O(logN) 이다.
  • y1, y2를 지정하는데 각각 O(N) 이고 query/update 수행 시간은 O(logN) 이기 때문에
  • 결과적으로 O(N^2 * logN) 이 된다.

maximum subarray 문제를 분할정복법으로 풀기 + 세그먼트 트리 적용

  • 세그먼트 트리를 이용해서 분할 정복을 한다.
  • 일단 중앙값(mid) 기준으로 좌우로 나눈다고 가정할 때, 최대의 구간합은
  • max(좌측 최대 구간합, 우측 최대 구간합, 걸쳐있는 구간의 최대합) 이라고 생각할 수 있다.
  • 이를 O(1) 로 구하기 위해서는 다음과 같은 쿼리를 필요로 한다
    • tsum(i, j) : i~j 구간의 합
    • lsum(i~j) : i~j 일 때, i에서 시작한 좌측 구간에서의 최대합
    • rsum(i~j) : i~j 일 때, j에서 시작한 우측 구간에서의 최대합
    • maxsum(i~j) : i~j 일 때, 최대 구간 합 (실제 구하려는 값)
  • 쿼리를 만들기 위해 만들어야 하는 트리와 수식은 아래와 같다
    • tsum(i, j) : tsum(i, mid) + tsum(mid+1, j)
    • lsum(i, j) : max(lsum(i, mid), tsum(i, mid) + lsum(mid+1, j))
      • lsum(i, j) = max(좌측 자식의 lsum 값, 좌측의 구간합 + 우측의 좌측 구간 최대 합)
    • rsum(i, j) : max(rsum(mid+1, j), tsum(mid+1, j) + lsum(i, mid))
      • rsum(i, j) = max(우측 자식의 rsum 값, 우측의 구간합 + 좌측의 우측 구간 최대 합)
    • maxsum(i, j) : max(maxsum(i, mid), maxsum(mid+1, j), rsum(1, mid) + lsum(mid + 1, j))
      • maxsum(i, j) = max(좌측 자식 maxsum 값, 우측 자식 maxsum 값, 좌측 자식의 우측 구간 최대 + 우측 자식의 좌측 구간 최대)

maximum subarray 문제를 단편적으로 그리디하게 해결하면 O(N) 이 가장 빠르지만, 매번 갱신이라는 개념이 추가된다면 각각 O(N) 번 수행하는 것 보다. O(N) 에 대해서 O(logN) 수행하는 세그먼트 트리가 더 빠르다는 것이다.

하지만 구현 역시 만만치 않다.

#include <stdio.h>
#include <vector>
#include <algorithm>
#include <tuple>
#define SIZE (3000 << 2)

using namespace std;

typedef long long int lld;
typedef tuple<int, int, int> G;

lld tsum[SIZE];
lld lsum[SIZE];
lld rsum[SIZE];
lld maxsum[SIZE];
vector<G> g;
vector<int> yy;

lld query(int l, int r, int node, int lo, int hi) {
    if (r < lo || hi < l)
        return 0;
    if (l <= lo && hi <= r)
        return maxsum[node];
    int mid = (l + r) / 2;
    return max(query(l, mid, node * 2, lo, hi), query(mid + 1, r, node * 2 + 1, lo, hi));
}

void update(int l, int r, int node, int pos, int val) {
    if (r < pos || pos < l)
        return;
    if (l == r) {
        tsum[node] += val;
        lsum[node] += val;
        rsum[node] += val;
        maxsum[node] += val;
        return;
    }
    int mid = (l + r) / 2;
    update(l, mid, node * 2, pos, val);
    update(mid + 1, r, node * 2 + 1, pos, val);
    tsum[node] = tsum[node * 2] + tsum[node * 2 + 1];
    lsum[node] = max(lsum[node * 2], tsum[node * 2] + lsum[node * 2 + 1]);
    rsum[node] = max(rsum[node * 2 + 1], tsum[node * 2 + 1] + rsum[node * 2]);
    maxsum[node] = max(max(maxsum[node * 2], maxsum[node * 2 + 1]), rsum[node * 2] + lsum[node * 2 + 1]);
}

int main()
{
    int n, i, x, y, w, j, px;
    lld ans = -1e9 - 1;

    scanf("%d", &n);
    for (i = 0; i < n; i++) {
        scanf("%d%d%d", &x, &y, &w);
        g.push_back(G(x, y, w));
        yy.push_back(y);
    }

    sort(g.begin(), g.end());
    sort(yy.begin(), yy.end());

    for (i = 0; i < n; i++) {
        while (i + 1 < n && get<0>(g[i + 1]) == get<0>(g[i]))
            i++;
        x = get<0>(g[i]);

        fill(tsum, tsum + SIZE, 0);
        fill(lsum, lsum + SIZE, 0);
        fill(rsum, rsum + SIZE, 0);
        fill(maxsum, maxsum + SIZE, 0);

        for (j = i; j >= 0; j--) {
            if (i == n - 1)
                i = i;
            int x2 = get<0>(g[j]), y2 = get<1>(g[j]), w2 = get<2>(g[j]);
            int pos = lower_bound(yy.begin(), yy.end(), y2) - yy.begin();
            update(0, n - 1, 1, pos, w2);
            if (j - 1 >= 0 && get<0>(g[j - 1]) == get<0>(g[j]))
                continue;
            lld t = query(0, n - 1, 1, 0, n - 1);
            ans = max(ans, t);
        }
    }

    printf("%lld", ans);

    return 0;
}

results matching ""

    No results matching ""