포스트

[종만북] 문제: INSERTION

서론

문제 설명

삽입 정렬에서 각 원소의 inversion count가 주어질 때, 초기의 배열 값을 복원하시오.

문제 링크

inversion count

inversion count는 삽입 정렬과 같이 진행되는 정렬 알고리즘에서, 각 원소가 자신의 최종 위치에 도달하기 전 자신보다 큰 원소와 교환한 횟수를 의미합니다.

다른 이름으로는 레머 코드(Lehmer code)라고 합니다.

예를 들어, 배열이 $[5, 1, 4, 3, 2]$인 경우를 살펴보면:

  • 5: 처음 위치에서 이동하지 않았으므로 inversion count는 $0$
  • 1: $5$와 교환되어 한 칸 이동하므로 inversion count는 $1$
  • 4: $5$와 교환되어 한 칸 이동하므로 inversion count는 $1$
  • 3: 두 번 교환되어 두 칸 이동하므로 inversion count는 $2$
  • 2: 세 번 교환되어 세 칸 이동하므로 inversion count는 $3$

따라서, inversion count 배열은 $[0, 1, 1, 2, 3]$이 됩니다.

Reverse Engineering

문제 해결의 핵심은 정렬이 완료된 최종 상태(오름차순 배열)로부터 주어진 inversion count 정보를 사용해 원래 배열을 복원하는 역추적(Reverse Engineering) 기법에 있습니다.
즉, 최종 상태인 $[1, 2, 3, …, n]$ 배열에서 뒤쪽부터 inversion count 정보를 이용해 원래 위치의 원소를 하나씩 찾아 제거하면서 초기 배열을 복원해 나가는 방식입니다.

Reverse Engineering의 접근 방법

Naive Approach

  • 아이디어:
    단순하게 배열에서 특정 인덱스의 원소를 제거하고 남은 배열에서 다시 인덱스를 계산하는 방식으로 접근합니다.

  • 시간복잡도:
    매 제거 과정이 $O(n)$ 시간이 걸리므로, 전체적으로 $O(n^2)$의 시간복잡도를 가집니다.

  • 한계:
    $n$의 최대 크기가 50,000까지 가능하기 때문에, 이 방식은 시간 초과로 이어질 수 있습니다.

효율적인 방법 (자료구조 활용)

  • 아이디어:
    Order Statistic Tree, 세그먼트 트리, 혹은 Fenwick Tree(이진 인덱스 트리)를 사용하면 배열에서 $k$번째 원소를 $O(\log n)$에 찾고 삭제할 수 있습니다.

  • 시간복잡도:
    각 원소에 대해 $k$번째 원소를 찾는 연산이 $O(\log n)$이고, 총 $N$번 수행하므로 전체 시간복잡도는 $O(n \log n)$입니다.

트립

트립(Treap)TreeHeap의 합성어로, 각 노드가 키(key)우선순위(priority)를 가지며, 이진 탐색 트리(BST)의 성질과 힙(Heap)의 성질을 동시에 만족하는 자료구조입니다.

보통 각 노드의 우선순위는 무작위(random)로 할당되어, 평균적으로 균형 잡힌 트리를 유지할 수 있습니다.

트립은 다음과 같은 성질을 가지고 있습니다.

  • 이진 탐색 트리(BST) 성질:
    각 노드에 대해, 왼쪽 서브트리의 모든 노드는 현재 노드의 키보다 작고,
    오른쪽 서브트리의 모든 노드는 현재 노드의 키보다 큽니다.

  • 힙(Heap) 성질:
    각 노드의 우선순위는 보통 max-heap 혹은 min-heap의 성질을 따릅니다.
    예를 들어, max-heap의 경우 부모 노드의 우선순위가 자식 노드의 우선순위보다 크거나 같아야 합니다.

  • 랜덤화:
    노드의 우선순위를 무작위로 할당함으로써, 평균적으로 모든 연산(탐색, 삽입, 삭제)이 $O(\log n)$ 시간에 처리될 수 있도록 합니다. 증명은 스킵하겠습니다.

트립의 구현

트립의 한 노드를 구현하는 struct를 선언합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
struct Node {
    int key, priority, size;
    Node *left, *right;
    Node(const int &_key)
        : key(_key), priority(rand()), size(1), left(nullptr), right(nullptr) {}
    void setLeft(Node *_left) {
        left = _left;
        calcSize();
    }
    void setRight(Node *_right) {
        right = _right;
        calcSize();
    }
    void calcSize() {
        size = 1;
        if (left) size += left->size;
        if (right) size += right->size;
    }
}; // 트리의 element를 표현하는 Node 클래스

key를 기준으로 노드를 나누는 split()함수를 선언합니다.

1
2
3
4
5
6
7
8
9
10
11
12
Nodepair split(Node *root, int key) {
    if (root == nullptr) return make_pair(nullptr, nullptr);
    // 루트의 키가 입력 키보다 큼
    if (root->key > key) {
        Nodepair ls = split(root->left, key);
        root->setLeft(ls.second);
        return Nodepair(ls.first, root);
    }
    Nodepair rs = split(root->right, key);
    root->setRight(rs.first);
    return Nodepair(root, rs.second);
} // k를 기준으로 트리 분리

트립에 노드를 삽입하는 insert()함수를 선언합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
Node *insert(Node *root, Node *node) {
    if (root == nullptr) return node;
    if (root->priority < node->priority) {
        Nodepair splitted = split(root, node->key);
        node->setLeft(splitted.first);
        node->setRight(splitted.second);
        return node;
    } else if (root->key < node->key)
        root->setRight(insert(root->right, node));
    else
        root->setLeft(insert(root->left, node));
    return root;
} // 삽입

두 노드 간 우선순위에 따라 합치는 함수 merge()함수를 선언합니다.

1
2
3
4
5
6
7
8
9
10
Node *merge(Node *a, Node *b) {
    if (a == nullptr) return b;
    if (b == nullptr) return a;
    if (a->priority < b->priority) {
        b->setLeft(merge(a, b->left));
        return b;
    }
    a->setRight(merge(a->right, b));
    return a;
} // 합치기

root에 key값을 가진 노드가 존재 시 그 노드를 지우는 함수 erase()를 선언합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Node *erase(Node *root, int key) {
    if (root == nullptr) return nullptr;

    if (root->key == key) {
        Node *ret = merge(root->left, root->right);
        delete root;
        return ret;
    }
    if (root->key > key)
        root->setLeft(erase(root->left, key));
    else
        root->setRight(erase(root->right, key));
    return root;
} // 지우기

앞서 작성한 코드를 바탕으로 k번째 index에 있는 노드를 반환하는 함수 findKth()를 선언합니다.

1
2
3
4
5
6
7
8
Node *findKth(Node *root, int k) {
    int l = root->left ? root->left->size : 0;
    if (k < l + 1)
        return findKth(root->left, k);
    else if (k == l + 1)
        return root;
    return findKth(root->right, k - l - 1);
} // k번째 수 찾기

findKth()의 함수와 비슷하게, key 값보다 작은 값 들의 개수를 세는 함수 countLessThan같은 함수를 작성하는 것도 가능합니다.

1
2
3
4
5
6
7
8
int countLessThan(Node *root, int key) {
    if (root == nullptr) return 0;
    if (root->key >= key) {
        return countLessThan(root->left, key);
    }
    int ls = root->left ? root->left->size : 0;
    return ls + 1 + countLessThan(root->right, key);
}

알고리즘 구상

알고리즘 설계

트립을 구현했으니 자료구조를 이용해서 작성만 하면 됩니다.

  1. 트립을 이용해 1부터 n개의 수를 삽입합니다.
  2. k번째 값을 얻고, 지우는 과정을 n번 반복합니다.
  3. 2번 과정에서 얻은 값들을 출력합니다. (역으로 출력)

알고리즘 구현

main()

트립 자료구조를 이용하여 삽입과 제거 연산 이후 답을 출력합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
int main() {
    fastIO;
    int C;
    cin >> C;
    while (C--) {
        cin >> n;
        for (int i = 0; i < n; i++) {
            cin >> lehmer[i];
        }
        Node *root = nullptr;
        for (int i = 1; i <= n; i++) {
            root = insert(root, new Node(i));
        }
        for (int i = n - 1; i >= 0; i--) {
            Node *k = findKth(root, i + 1 - lehmer[i]);
            answer[i] = k->key;
            root = erase(root, k->key);
        }
        for (int i = 0; i < n; i++) {
            cout << answer[i] << " ";
        }
        cout << "\n";
    }
    return 0;
} // 트립 연습 문제, 펜웍 트리로 풀 수 있다.

최종 코드

여기에서 확인하세요!

알고리즘 개선 방안

세그먼트 트리

트립을 구현하면서 문제를 풀어보았지만, 사실 세그먼트 트리를 통해 더 간단히 풀 수 있습니다. 원리는 비슷합니다.

세그먼트 트리 풀이에 대해서는 이 코드를 참조해주세요.

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.