페이지

10550번: NORMA

https://www.acmicpc.net/problem/10550


$O(n\lg n)$

정수 배열 a[0...n-1]이 주어져 있다.
f(l,r): [l,r]에서 만들 수 있는 배열들의 가격 합
g(l,r): a[l...r] 가격

분할 정복을 적용해보자.
m=[(l+r)/2]이라 하면
i) l==r
f(l,r)=a[l]*a[l]
ii) l<r
f(l,r)=f(l,m)+f(m+1,r)+(*, l<=i<=m<j<=r 인 임의의 i, j에 대해 a[i...j]의 가격 합)
따라서 (*)을 $O(r-l+1)$에 구하면 전체 문제를 $O(n\lg n)$에 풀 수 있을 것이다.

m[i]: a[m...m+1-i] 최솟값 혹은 a[m+1...m+i] 최솟값
M[i]: a[m...m+1-i] 최댓값 혹은 a[m+1...m+i] 최댓값
x=m+1, y=m에서 시작해서 x=l, y=r이 될 때까지 한쪽을 1씩 증감시켜 f(x,y)를 계속 구해줄 것이다.
이때 증감시킨 한쪽 끝값 a[p]는 그전 구간 m[x...y]보다 작거나 같아야 한다.
한쪽을 구현하면 다른 쪽에도 그대로 적용가능하므로 편의상 p=x-1이고 m[p]<=m[x...y]라 하자.
f(p,y)=f(x,y)+g(p,m+1)+g(p,m+2)+...+g(p,y)이다.
f(x,y)는 이미 구했으므로 g(p,m+1)+g(p,m+2)+...+g(p,y)을 구하면 된다.
x를 1 감소시킨 다음 보면
j<=y인 j에 대해 g(x,j)=m[x]*max(M[x],M[j])*(j-x+1)이 된다.
그럼 M[x]>M[j]인 최대 j(=j')를 구할 생각을 할텐데 놀랍게도 j'은 알고리즘이 진행되는 동안 단조 증가하기 때문에 amortized O(1)에 구할 수 있다.
i) j<=j'
g(x,j)=m[x]*M[x]*(j-x+1)
g(x,m+1)+...+g(x,j')=m[x]*M[x]*{(m+2-x)~(j'-x+1)합}
ii) j>j'
g(x,j)=m[x]*M[j]*(j-x+1)
=m[x]*{ M[j]*(j-m)+M[j]*(m+1-x) }
g(x,j'+1)+...+g(x,y)는 M[j]*(j-m)와 M[j]의 누적합을 관리해서 빠르게 구할 수 있다.

#include<cstdio>
#define mod int(1e9)
const int MXN = 5e5;
typedef long long ll;
int n, res, a[MXN];
struct st {
    int cnt = 0, maxi = 0, s = 0, sl = 0;
    void add(int x) {
        if (maxi < x) maxi = x;
        s = (s + maxi) % mod;
        sl = (sl + (ll)maxi*++cnt) % mod;
    }
};
void f(int l, int r) {
    if (l == r) {
        res = (res + (ll)a[l] * a[l]) % mod;
        return;
    }
    int m = (l + r) / 2, mx = a[m], my = a[m + 1];
    st x, y, u, v;
    while (x.cnt + y.cnt <= r - l) {
        if (m + y.cnt == r || m - x.cnt >= l&&mx > my) {
            x.add(a[m - x.cnt]);
            while (v.cnt < y.cnt&&a[m + v.cnt + 1] < x.maxi) v.add(a[m + v.cnt + 1]);
            res = (res + (ll)(2 * x.cnt + 1 + v.cnt)*v.cnt / 2 % mod*mx%mod*x.maxi
                + (y.sl - v.sl + (ll)(y.s - v.s)*x.cnt% mod)*mx) % mod;
            if (m - x.cnt >= l&&mx > a[m - x.cnt]) mx = a[m - x.cnt];
        }
        else {
            y.add(a[m + y.cnt + 1]);
            while (u.cnt < x.cnt&&a[m - u.cnt] < y.maxi) u.add(a[m - u.cnt]);
            res = (res + (ll)(2 * y.cnt + 1 + u.cnt)*u.cnt / 2 % mod*my%mod*y.maxi
                + (x.sl - u.sl + (ll)(x.s - u.s)*y.cnt % mod)*my) % mod;
            if (m + y.cnt < r&&my > a[m + y.cnt + 1]) my = a[m + y.cnt + 1];
        }
    }
    f(l, m);
    f(m + 1, r);
}
int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) scanf("%d", a + i);
    f(0, n - 1);
    printf("%d", res);
    return 0;
}

댓글 없음 :

댓글 쓰기