$O(n\lg n)$
a[i]를 j번째로 옮길 때
i) j<=i
c는 s[i-1]-s[j-1]-(i-j)*a[i] 만큼 증가
정리하면
j*a[i]-s[j-1]+s[i-1]-i*a[i]
(기울기, y절편)=(j, -s[j-1]) 직선들을 가지고 convex hull trick을 써서 해결
a[i]는 i에 대한 단조 함수가 아니므로 직선들의 교점을 가지고 이분 탐색을 해야한다.
ii) j>=i
c는 -s[j]+s[i]+(j-i)*a[i] 만큼 증가
정리하면
j*a[i]-s[j]+s[i]-i*a[i]
(기울기, y절편)=(j, -s[j]) 직선들을 가지고 convex hull trick을 써서 해결
마찬가지로 이분 탐색
#include<cstdio> #include<algorithm> using namespace std; const int MXN = 2e5; typedef long long ll; typedef pair<double, double> line; int n, sz; ll a[MXN + 1], s[MXN + 1], maxi, c; line l[MXN]; double x[MXN - 1]; double cross(line i, line j) { return (i.second - j.second) / (j.first - i.first); } void push(line h) { while (sz > 1 && cross(l[sz - 2], h) > cross(l[sz - 1], h)) sz--; if (sz) x[sz - 1] = cross(l[sz - 1], h); l[sz++] = h; } ll query(double h) { int p = lower_bound(x, x + sz - 1, h) - x; return l[p].first*h + l[p].second; } int main() { scanf("%d", &n); for (int i = 1; i <= n; i++) { scanf("%lld", a + i); s[i] = s[i - 1] + a[i]; c += i*a[i]; } for (int i = 1; i <= n; i++) { push({ i, -s[i - 1] }); maxi = max(maxi, query(a[i]) + s[i - 1] - i*a[i]); } sz = 0; for (int i = n; i; i--) { push({ -i, -s[i] }); maxi = max(maxi, query(-a[i]) + s[i] - i*a[i]); } printf("%lld", c + maxi); return 0; }
$O(n)$
a[j]를 i번째로 옮길때
i) i<=j
(-a[j])*(-i)+s[j-1]-j*a[j]-s[i-1] 만큼 증가한다.
(기울기, y절편)=(-a[j], s[j-1]-j*a[j]) 직선들을 생각해보면 convex hull trick을 쓰기 위해서 -a[j]가 j에 대해 단조 증가 함수일 필요가 있다.
실제로 잘 생각해보면 -a[j]가 증가할 때만 직선을 추가해도 상관 없다.
결과적으로 위 풀이와는 다르게 -i 또한 단조 증가하므로 O(n)에 해결 가능.
ii) i>=j
a[j]*i-s[i]+s[j]-j*a[j] 만큼 증가한다.
마찬가지로
(기울기, y절편)=(a[j], s[j]-j*a[j]) 직선들 중 a[j]가 증가할 때만 직선을 추가해서 convex hull trick을 적용한다.
#include<cstdio> #include<algorithm> using namespace std; const int MXN = 2e5; typedef long long ll; typedef pair<double, double> line; int n, p, sz; long long a[MXN + 1], s[MXN + 1], c, t; line l[MXN]; double cross(line i, line j) { return (i.second - j.second) / (j.first - i.first); } void push(line x) { while (sz - p > 1 && cross(l[sz - 2], x) > cross(l[sz - 1], x)) sz--; l[sz++] = x; } ll query(int x) { while (sz - p > 1 && cross(l[p], l[p + 1]) < x) p++; return l[p].first*x + l[p].second; } int main() { scanf("%d", &n); for (int i = 1; i <= n; i++) { scanf("%lld", a + i); s[i] = s[i - 1] + a[i]; c += i*a[i]; } int maxi = -1e9, mini = 1e9; for (int i = n; i; i--) { if (mini>a[i]) { mini = a[i]; push({ -a[i],s[i - 1] - i*a[i] }); } t = max(t, query(-i) - s[i - 1]); } p = sz = 0; for (int i = 1; i <= n; i++) { if (maxi < a[i]) { maxi = a[i]; push({ a[i],s[i] - i*a[i] }); } t = max(t, query(i) - s[i]); } printf("%lld", c + t); return 0; }
댓글 없음 :
댓글 쓰기