$O(n\lg n)$
정수 배열 a[0...n-1]이 주어져 있다.
f(l,r): [l,r]에서 만들 수 있는 배열들의 가격 합
g(l,r): a[l...r] 가격
분할 정복을 적용해보자.
m=[(l+r)/2]이라 하면
i) l==r
f(l,r)=a[l]*a[l]
ii) l<r
f(l,r)=f(l,m)+f(m+1,r)+(*, l<=i<=m<j<=r 인 임의의 i, j에 대해 a[i...j]의 가격 합)
따라서 (*)을 $O(r-l+1)$에 구하면 전체 문제를 $O(n\lg n)$에 풀 수 있을 것이다.
m[i]: a[m...m+1-i] 최솟값 혹은 a[m+1...m+i] 최솟값
M[i]: a[m...m+1-i] 최댓값 혹은 a[m+1...m+i] 최댓값
x=m+1, y=m에서 시작해서 x=l, y=r이 될 때까지 한쪽을 1씩 증감시켜 f(x,y)를 계속 구해줄 것이다.
이때 증감시킨 한쪽 끝값 a[p]는 그전 구간 m[x...y]보다 작거나 같아야 한다.
한쪽을 구현하면 다른 쪽에도 그대로 적용가능하므로 편의상 p=x-1이고 m[p]<=m[x...y]라 하자.
f(p,y)=f(x,y)+g(p,m+1)+g(p,m+2)+...+g(p,y)이다.
f(x,y)는 이미 구했으므로 g(p,m+1)+g(p,m+2)+...+g(p,y)을 구하면 된다.
x를 1 감소시킨 다음 보면
j<=y인 j에 대해 g(x,j)=m[x]*max(M[x],M[j])*(j-x+1)이 된다.
그럼 M[x]>M[j]인 최대 j(=j')를 구할 생각을 할텐데 놀랍게도 j'은 알고리즘이 진행되는 동안 단조 증가하기 때문에 amortized O(1)에 구할 수 있다.
i) j<=j'
g(x,j)=m[x]*M[x]*(j-x+1)
g(x,m+1)+...+g(x,j')=m[x]*M[x]*{(m+2-x)~(j'-x+1)합}
ii) j>j'
g(x,j)=m[x]*M[j]*(j-x+1)
=m[x]*{ M[j]*(j-m)+M[j]*(m+1-x) }
g(x,j'+1)+...+g(x,y)는 M[j]*(j-m)와 M[j]의 누적합을 관리해서 빠르게 구할 수 있다.
#include<cstdio> #define mod int(1e9) const int MXN = 5e5; typedef long long ll; int n, res, a[MXN]; struct st { int cnt = 0, maxi = 0, s = 0, sl = 0; void add(int x) { if (maxi < x) maxi = x; s = (s + maxi) % mod; sl = (sl + (ll)maxi*++cnt) % mod; } }; void f(int l, int r) { if (l == r) { res = (res + (ll)a[l] * a[l]) % mod; return; } int m = (l + r) / 2, mx = a[m], my = a[m + 1]; st x, y, u, v; while (x.cnt + y.cnt <= r - l) { if (m + y.cnt == r || m - x.cnt >= l&&mx > my) { x.add(a[m - x.cnt]); while (v.cnt < y.cnt&&a[m + v.cnt + 1] < x.maxi) v.add(a[m + v.cnt + 1]); res = (res + (ll)(2 * x.cnt + 1 + v.cnt)*v.cnt / 2 % mod*mx%mod*x.maxi + (y.sl - v.sl + (ll)(y.s - v.s)*x.cnt% mod)*mx) % mod; if (m - x.cnt >= l&&mx > a[m - x.cnt]) mx = a[m - x.cnt]; } else { y.add(a[m + y.cnt + 1]); while (u.cnt < x.cnt&&a[m - u.cnt] < y.maxi) u.add(a[m - u.cnt]); res = (res + (ll)(2 * y.cnt + 1 + u.cnt)*u.cnt / 2 % mod*my%mod*y.maxi + (x.sl - u.sl + (ll)(x.s - u.s)*y.cnt % mod)*my) % mod; if (m + y.cnt < r&&my > a[m + y.cnt + 1]) my = a[m + y.cnt + 1]; } } f(l, m); f(m + 1, r); } int main() { scanf("%d", &n); for (int i = 0; i < n; i++) scanf("%d", a + i); f(0, n - 1); printf("%d", res); return 0; }
댓글 없음 :
댓글 쓰기