페이지

4002번: 닌자배치

https://www.acmicpc.net/problem/4002


$O(n{lgn}^2)$

i번 닌자를 매니저로 정할 때, i번 닌자의 자손들 중 예산 안에서 월급이 작은 닌자부터 차례대로 고용하는 것이 최선이다.
이 때 고용된 닌자들 집합을 Ai라 하자.
i번 닌자의 자식들을 sj라 하면
언제나 Asj들을 만들 때 제외된 닌자들은 Ai에 존재할 수 없다.(귀류법으로 증명할 수 있다.)
따라서 자식들의 Asj 합집합 + i번 닌자에서 예산이 넘치는 만큼 월급이 큰 닌자부터 제외시킨 집합이 Ai가 된다.
값이 큰 원소부터 제외하는 연산은 우선순위 큐로 구현할 수 있고, 이들을 합칠 때 사이즈가 큰 큐쪽으로 합치면 총 O(nlgn)의 연산횟수에 문제를 해결할 수 있다.

#include<cstdio>
#include<algorithm>
#include<queue>
#include<vector>
using namespace std;
typedef long long ll;
const int MXN = 1e5;
int n, m, p[MXN + 1];
ll c[MXN + 1], l[MXN + 1], r;
priority_queue<int> pq[MXN + 1];
vector<int> adj[MXN + 1];
void f(int h) {
    p[h] = h;
    pq[h].push(c[h]);
    for (auto it : adj[h]) {
        f(it);
        if (pq[p[h]].size() < pq[p[it]].size()) swap(p[h], p[it]);
        while (!pq[p[it]].empty()) pq[p[h]].push(pq[p[it]].top()), pq[p[it]].pop();
        c[h] += c[it];
    }
    while (c[h] > m) c[h] -= pq[p[h]].top(), pq[p[h]].pop();
    r = max(r, (ll)pq[p[h]].size()*l[h]);
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1, x; i <= n; i++) {
        scanf("%d%lld%lld", &x, c + i, l + i);
        adj[x].push_back(i);
    }
    f(1);
    printf("%lld", r);
    return 0;
}



$O(nlgn)$

처음 아이디어는 위와 같다.
트리를 순회하면서 번호를 매기면 어떤 정점의 자손들이 모두 인접한 번호를 가지고 있게 할 수 있다.
세그먼트 트리를 이용해 해당 구역에서 월급이 가장 큰 닌자를 O(lgn)에 찾고 제거할 수 있다.


#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
const int MXN = 1e5;
int n, m, l[MXN + 1], p[MXN + 1], cnt;
ll c[MXN + 1], r;
pair<intint> tree[MXN * 4];
vector<int> adj[MXN + 1];
void insert(int h, int s, int e, int g, int x) {
    if (e < g || g < s) return;
    if (s == e) {
        tree[h] = { x,g };
        return;
    }
    insert(h * 2 + 1, s, (s + e) / 2, g, x);
    insert(h * 2 + 2, (s + e) / 2 + 1, e, g, x);
    tree[h] = max(tree[h * 2 + 1], tree[h * 2 + 2]);
}
pair<intint> find(int h, int s, int e, int gs, int ge) {
    if (ge < s || e < gs) return{ 0,0 };
    if (gs <= s&&e <= ge) return tree[h];
    return max(find(h * 2 + 1, s, (s + e) / 2, gs, ge), find(h * 2 + 2, (s + e) / 2 + 1, e, gs, ge));
}
void f(int h) {
    int s = cnt;
    p[h]++;
    insert(0, 0, n - 1, cnt++, c[h]);
    for (auto it : adj[h]) {
        f(it);
        c[h] += c[it];
        p[h] += p[it];
    }
    while (c[h] > m) {
        pair<intint> t = find(0, 0, n - 1, s, cnt - 1);
        insert(0, 0, n - 1, t.second, 0);
        c[h] -= t.first;
        p[h]--;
    }
    r = max(r, (ll)p[h] * l[h]);
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1, x; i <= n; i++) {
        scanf("%d%lld%d", &x, c + i, l + i);
        adj[x].push_back(i);
    }
    f(1);
    printf("%lld", r);
    return 0;
}

댓글 없음 :

댓글 쓰기