$O(n{lgn}^2)$
i번 닌자를 매니저로 정할 때, i번 닌자의 자손들 중 예산 안에서 월급이 작은 닌자부터 차례대로 고용하는 것이 최선이다.
이 때 고용된 닌자들 집합을 Ai라 하자.
i번 닌자의 자식들을 sj라 하면
언제나 Asj들을 만들 때 제외된 닌자들은 Ai에 존재할 수 없다.(귀류법으로 증명할 수 있다.)
따라서 자식들의 Asj 합집합 + i번 닌자에서 예산이 넘치는 만큼 월급이 큰 닌자부터 제외시킨 집합이 Ai가 된다.
값이 큰 원소부터 제외하는 연산은 우선순위 큐로 구현할 수 있고, 이들을 합칠 때 사이즈가 큰 큐쪽으로 합치면 총 O(nlgn)의 연산횟수에 문제를 해결할 수 있다.
#include<cstdio> #include<algorithm> #include<queue> #include<vector> using namespace std; typedef long long ll; const int MXN = 1e5; int n, m, p[MXN + 1]; ll c[MXN + 1], l[MXN + 1], r; priority_queue<int> pq[MXN + 1]; vector<int> adj[MXN + 1]; void f(int h) { p[h] = h; pq[h].push(c[h]); for (auto it : adj[h]) { f(it); if (pq[p[h]].size() < pq[p[it]].size()) swap(p[h], p[it]); while (!pq[p[it]].empty()) pq[p[h]].push(pq[p[it]].top()), pq[p[it]].pop(); c[h] += c[it]; } while (c[h] > m) c[h] -= pq[p[h]].top(), pq[p[h]].pop(); r = max(r, (ll)pq[p[h]].size()*l[h]); } int main() { scanf("%d%d", &n, &m); for (int i = 1, x; i <= n; i++) { scanf("%d%lld%lld", &x, c + i, l + i); adj[x].push_back(i); } f(1); printf("%lld", r); return 0; }
$O(nlgn)$
처음 아이디어는 위와 같다.
트리를 순회하면서 번호를 매기면 어떤 정점의 자손들이 모두 인접한 번호를 가지고 있게 할 수 있다.
세그먼트 트리를 이용해 해당 구역에서 월급이 가장 큰 닌자를 O(lgn)에 찾고 제거할 수 있다.
#include<cstdio> #include<vector> #include<algorithm> using namespace std; typedef long long ll; const int MXN = 1e5; int n, m, l[MXN + 1], p[MXN + 1], cnt; ll c[MXN + 1], r; pair<int, int> tree[MXN * 4]; vector<int> adj[MXN + 1]; void insert(int h, int s, int e, int g, int x) { if (e < g || g < s) return; if (s == e) { tree[h] = { x,g }; return; } insert(h * 2 + 1, s, (s + e) / 2, g, x); insert(h * 2 + 2, (s + e) / 2 + 1, e, g, x); tree[h] = max(tree[h * 2 + 1], tree[h * 2 + 2]); } pair<int, int> find(int h, int s, int e, int gs, int ge) { if (ge < s || e < gs) return{ 0,0 }; if (gs <= s&&e <= ge) return tree[h]; return max(find(h * 2 + 1, s, (s + e) / 2, gs, ge), find(h * 2 + 2, (s + e) / 2 + 1, e, gs, ge)); } void f(int h) { int s = cnt; p[h]++; insert(0, 0, n - 1, cnt++, c[h]); for (auto it : adj[h]) { f(it); c[h] += c[it]; p[h] += p[it]; } while (c[h] > m) { pair<int, int> t = find(0, 0, n - 1, s, cnt - 1); insert(0, 0, n - 1, t.second, 0); c[h] -= t.first; p[h]--; } r = max(r, (ll)p[h] * l[h]); } int main() { scanf("%d%d", &n, &m); for (int i = 1, x; i <= n; i++) { scanf("%d%lld%d", &x, c + i, l + i); adj[x].push_back(i); } f(1); printf("%lld", r); return 0; }
댓글 없음 :
댓글 쓰기