포스트

[종만북] 문제: FAMILYTREE

서론

문제 설명

트리에서, 노드 $u$, $v$간의 거리를 계산하시오.

트리의 노드의 개수는 $n(1 \leq n \leq 100000)$개이고, 거리를 계산하는 질문(Query)의 수는 $q(1 \leq q \leq 10000)$개이다.

문제 링크

아이디어

노드의 거리를 구하기 위해서는 우리는 먼저 BFS를 떠올릴 수 있습니다만, 공교롭게도, $(1 \leq n \leq 100000, 1 \leq q \leq 10000)$이므로 $O(nq)$시간이 되는 BFS는 사용할 수 없습니다.

그러면 무슨 방법이 있을까요? 이 문제를 풀기 위해서는 LCA를 구하는 알고리즘을 만들어야 합니다.

LCA, Lowest Common Ancestor

LCA(최소 공통 조상)란 트리에서 두 노드의 공통된 조상 중 가장 가까운 조상을 의미합니다.

LCA는 두 노드 간의 관계를 분석할 때 사용할 수 있습니다.

LCA를 구해보자

LCA를 구하는 방법은 대략 4가지 정도의 방법이 있습니다.

  1. 트리를 무방향 그래프로 생각하며, BFS로 노드 $u \to v$ 를 찾는다. (Naive approach)
  2. Binary Lifting을 이용합니다.
  3. Eular Tour를 수행한 트리를 이용하여 이를 RMQ(구간 최솟값 쿼리)문제로 바꿔 풉니다.
  4. Tarjan의 Union-find 알고리즘을 이용한다.

우리는 세그먼트 트리를 배웠으니, 3번 방법을 이용해서 풀어봅시다.

Eular Tour

Eular Tour는 트리 혹은 그래프를 DFS방식으로 순회하며, 노드를 방문하는 순서를 배열에 기록하는 방법입니다.

이를 이용해 트리를 평탄화하여, 구간 쿼리 문제(Range Query Problem)로 전환할 수 있습니다.

Range Minimum Query

Range Minimum Query는 배열의 특정 구간에서 최솟값을 찾기 위한 문제입니다.

단순히 모든 구간에 대해 최솟값을 찾을 수도 있지만, 세그먼트 트리희소 배열을 이용해 풀 수 있습니다.

근데 이거로 뭐함?

조금 고민해보면, Eular Tour로 트리를 평탄화 했을 때 다음과 같은 사실을 알 수 있습니다.

  1. 노드 $u \to v$일때, $u$와 $v$의 LCA를 지나갑니다.
  2. $u \to v$에서는 LCA보다 $depth$가 작은 부모 노드를 지나갈 수 없습니다.

이를 통해 알 수 있는 것은. Eular Tour의 노드 $u$의 첫 등장 $index$와, 노드 $v$의 첫 등장 $index$범위 안의 depth의 최솟값은 $depth(LCA(u, v))$라고 생각할 수 있습니다.

그러면 각 노드 $x$의 첫 등장 $index$을 $First[x]$로 정의한다면,

$Eular Tour[First[u], First[v]]$ 범위에 대해 RMQ를 구하면 되겠네요.

이 범위에 대해 최솟값을 구하는 것은 세그먼트 트리를 이용해서 풀면 되겠습니다.

이렇게 구한 최솟값을 $depth(LCA(u, v))$라고 하면, 노드 $u \to v$의 거리는 $depth(u) + depth(v) - 2 \times depth(LCA(u, v))$의 공식을 통해 구할 수 있습니다.


알고리즘 구상

알고리즘 설계

  1. 트리를 입력받아 노드에서 리프로 나가는 형태로 Tree를 설계합니다.
  2. 트리에 대해 DFS를 수행하며 다음과 같은 내용을 구합니다.
    • 각 노드의 $depth$를 구합니다.
    • Eular Tour의 노드들을 각 노드의 $depth$로 변형합니다.
    • Eular Tour에서 각 노드의 첫 등장 $index$를 구합니다.
  3. Eular Tour의 범위에 대해 최솟값을 구하는 세그먼트 트리를 전처리합니다.
  4. 세그먼트 트리에 대해 최솟값을 묻는 질문(Query)을 합니다.
  5. 4번 과정에서 구한 $depth(LCA(u, v))$를 이용해 값을 구합니다.

시간 복잡도

각각 트리 생성, DFS, 세그먼트 트리 생성, Query에 대해

$O(n) + O(n) + O(n) + q \times O(\log n) = O(n + q \log n)$이네요, 문제를 풀기에 충분해 보입니다.


알고리즘 구현

dfs()

2번 과정을 수행하는 알고리즘을 설계합니다.

$d[i]$: 각 노드 $i$의 $depth$를 저장하는 배열입니다.

$eularTour[i]$: eular Tour한 $node$의 $depth$를 저장하는 배열입니다.

$First[i]$: eular Tour에서 노드 i의 첫 등장 $index$, 알고리즘 수행 시마다 미리 -1로 초기화합니다.

1
2
3
4
5
6
7
8
9
10
11
12
int idx;
void dfs(int node, int depth) {
    if (first[node] == -1) first[node] = idx;
    eularTour[idx++] = depth;
    d[node] = depth;
    if (!tree[node].empty()) {
        for (int &i : tree[node]) {
            dfs(i, depth + 1);
            eularTour[idx++] = depth;
        }
    }
} // dfs 탐색, 깊이 계산, eular Tour에서 노드의 첫 번째 index가 위치하는 곳 탐색

segTreeInit()

3번 과정을 수행하는 알고리즘으로, $eularTour[nl, nr]$ 범위에 대해 세그먼트 트리를 구축합니다.

1
2
3
4
5
6
7
8
int segTreeInit(int node, int nl, int nr) {
    if (nl == nr) return seg[node] = eularTour[nl];
    int mid = (nl + nr) / 2;
    int left = segTreeInit(node * 2, nl, mid);
    int right = segTreeInit(node * 2 + 1, mid + 1, nr);
    return seg[node] = min(left, right);
}

query()

4번 과정을 수행하는 알고리즘으로, $eularTour[nl, nr]$ 범위에 대해 미리 구한 세그먼트 트리를 이용해 $eularTour[start, end]$ 범위의 최솟값이 무엇인지 묻습니다.

1
2
3
4
5
6
7
8
int query(int start, int end, int node, int nl, int nr) {
    if (nr < start || end < nl) return 100001;
    if (start <= nl && nr <= end) return seg[node];
    int mid = (nl + nr) / 2;
    return min(query(start, end, node * 2, nl, mid),
               query(start, end, node * 2 + 1, mid + 1, nr));
}

main()

1번 과정을 수행하고 위에서 만든 알고리즘을 전부 수행하는 main()입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
int main() {
    fastIO;
    int C;
    cin >> C;
    while (C--) {
        cin >> n >> q;
        tree.assign(n, {}); // tree 초기화는 이렇게 !!
        memset(first, -1, sizeof(int) * n);
        for (int i = 1; i < n; i++) {
            cin >> input;
            tree[input].push_back(i);
        }
        idx = 0;
        dfs(0, 0);
        segTreeInit(1, 0, 2 * n - 3);
        while (q--) {
            int a, b, aa, bb, lcaDepth;
            cin >> a >> b;
            aa = first[a], bb = first[b];
            if (aa > bb) swap(aa, bb);
            lcaDepth = query(aa, bb, 1, 0, 2 * n - 3);
            cout << d[a] + d[b] - 2 * lcaDepth << "\n";
        }
    }
    return 0;
}

최종 코드

여기에서 확인하세요!

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.