포스트

[종만북] 문제: NH

서론

문제 설명

$m(1 \leq m \leq 200)$개의 문자열 패턴이 주어진다. 패턴의 길이 $length(pattern)$은 최대 $10$이다.

길이 $n$의 문자열 $s$가 될 수 있는 경우의 수를 구하시오.

출력할 때는 (경우의 수) $\bmod 10007$을 출력한다.

문제 링크

문제 유형

아호 코라식을 구현하고, 동적 계획법을 이용해 문제를 풀어봅시다.

Aho-corasick

아호 코라식에 대한 설명은 이 글입니다.

아호 코라식에 대한 간략한 설명

앞서 문자열에서 어떤 특정한 하나의 패턴을 찾는 알고리즘을 배웠습니다.

바로 KMP 알고리즘입니다.

여기서 이런 생각이 들 수 있습니다. 어떤 문자열의 여러 가지의 패턴을 찾아야 한다면,

어떤 방식으로 찾아야 하는가.. 생각을 해야 합니다.

우리가 가진 생각으로 Naive하게 접근하면, 단순히 KMP를 패턴의 개수만큼 반복할 수 있습니다.

하지만 패턴이 너무 많아진다면, 이런 식으로는 프로그램이 제 시간 안에 동작할 수 없습니다.

따라서, 우리는 트라이(Trie)를 이용해 이것을 KMP 알고리즘처럼 failure function을 이용해 오토마타적으로 접근할 수 있습니다.

그러면 시간 복잡도를 $O((n + m)k)$에서 $O(n + m + k)$로 줄일 수 있습니다.

실패 함수를 이용하자

따라서 이 문제는 아호 코라식을 이용해서 트라이 구조를 만들고, 그다음에 BFS를 통해서 경우의 수를 계산하면 되겠습니다.

아호 코라식의 실패함수 $f(v)$는, $v$의 부모 노드 $p(v)$와 $p(v)$에서 $v$로 이어지는 문자 $c(v)$를 이용하여 다음과 같이 정의할 수 있습니다.

\[f(v) = \begin{cases} \text{root} & \text{if } p(v) \text{가 root인 경우},\\[6mm] \delta\bigl(f(p(v)),\, c(v)\bigr) & \text{그 외의 경우}, \end{cases}\]

여기서 $\delta(x, c)$는 노드 $x$에서 문자 $c$로의 전이(transition) 함수로, 다음과 같이 재귀적으로 정의됩니다.

\[\delta(x, c) = \begin{cases} x[c] & \text{if } x \text{에 문자 } c \text{로의 간선이 존재한다면},\\[6mm] \delta\bigl(f(x),\, c\bigr) & \text{if } x \text{가 root가 아니면서 간선이 없다면},\\[6mm] \text{root} & \text{if } x \text{가 root인 경우}. \end{cases}\]

이 수식은 실패 함수가 부모의 실패 링크를 통해 바텀-업 방식으로 결정됨을 보여줍니다.


알고리즘 구상

알고리즘 설계

  1. 먼저, 트라이에 대한 입력을 처리한다.
  2. 트라이에 대해 BFS로 실패 경로를 세팅한다.
  3. 이렇게 얻은 트라이에 대해 DP를 이용해 BFS를 수행하여 경우의 수를 계산한다.

시간 복잡도

트라이를 구성할 때 시간 복잡도는 $O(mL)$, $L =$ 패턴의 길이

트라이의 노드 수는 최대 $O(mL)$, 따라서 실패 경로 계산의 시간 복잡도는 $O(26 \times m \times L^2)$

DP를 이용해 BFS를 하는 것은 $O(26 \times n \times m \times L^2)$

정리하면, $O(nmL^2)$정도가 되겠습니다.


알고리즘 구현

구조체 정의

트라이의 노드를 표현할 struct를 정의합니다.

특이하게 각 노드에 대해 int id로 구별해야하기 때문에, static int를 선언합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
struct TrieNode {
    TrieNode *alphabet[26];
    TrieNode *fail;
    int id;
    bool terminal;
    static int counter;
    TrieNode() : fail(0), id(counter++), terminal(0) {
        memset(alphabet, 0, sizeof(alphabet));
    }
    ~TrieNode() {
        for (int i = 0; i < 26; i++) {
            if (alphabet[i]) delete alphabet[i];
        }
    }
};

insert()

TrieNode *root에 패턴을 입력하는 insert 함수를 작성합니다.

1
2
3
4
5
6
7
8
9
void insert() {
    TrieNode *ptr = root;
    for (char &c : s) {
        int k = c - 'a';
        if (!ptr->alphabet[k]) ptr->alphabet[k] = new TrieNode();
        ptr = ptr->alphabet[k];
    }
    ptr->terminal = 1;
}

calFailFunction()

아호 코라식의 실패 경로를 BFS로 파악하는 코드를 작성합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void calFailFunction() {
    queue<TrieNode *> q;
    root->fail = root; // 루트의 실패는 루트
    q.push(root);
    while (!q.empty()) {
        TrieNode *here = q.front();
        q.pop();
        for (int i = 0; i < 26; i++) {
            TrieNode *child = here->alphabet[i];
            if (!child) continue;
            if (here == root)
                child->fail = root;
            else {
                TrieNode *t = here->fail;
                while (t != root && !t->alphabet[i]) {
                    t = t->fail;
                }
                if (t->alphabet[i]) t = t->alphabet[i];
                child->fail = t;
            }
            child->terminal |= child->fail->terminal;
            q.push(child);
        }
    }
}

경우의 수를 구하는 함수

BFSDP를 이용해 각 $depth$에 대해 경우의 수를 구합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 각 노드는 고유한 id[0, 1000]를 가진다.
int DP[1001][101];
int function(TrieNode *node, int depth) {
    int &ret = DP[node->id][depth];
    if (ret != -1) return ret;
    if (node->terminal) return ret = 0;
    if (depth == 0) return ret = 1;
    ret = 0;
    for (int i = 0; i < 26; i++) {
        TrieNode *t = node;
        while (t != root && !t->alphabet[i])
            t = t->fail;
        if (t->alphabet[i]) t = t->alphabet[i];
        ret += function(t, depth - 1);
        ret %= 10007;
    }
    return ret;
}

main()

위 함수들을 이용해 값을 초기화하고 입/출력하는 함수입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
int main() {
    fastIO;
    int C, n, m;
    cin >> C;
    while (C--) {
        cin >> n >> m;
        TrieNode::counter = 0;
        root = new TrieNode();
        for (int i = 0; i < m; i++) {
            cin >> s;
            insert();
        }
        calFailFunction();
        memset(DP, -1, sizeof(DP));
        cout << function(root, n) << "\n";
        delete root;
    }
    return 0;
}

최종 코드

여기에서 확인하세요!

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.