[웰노운] BIT
서론
BIT(Binary Indexed Tree)에 대해 알아보자.
우선 이 자료구조를 알기 전에 Segment Tree에 대해 알고 있다는 전제로 작성한다.
둘다 query와 update가 핵심이고, 시간복잡도도 같다.
CP에서 BIT라는 자료구조는 세그트리에 비해 구현에서 이점이 존재한다.
하지만, BIT를 사람들이 잘 안쓰는 이유가 존재하는데. 요약하면 세그트리가 만능이기 때문이다.
이유를 알아보자면
- 무엇보다 비트 연산에 대해 이해를 하고 있어야 한다.
- 만약 구간 업데이트를 동반할 경우, BIT를 사용하는 구조일 때 BIT를 두 개 써야하는 굉장히 Tricky한 구조로 설계해야한다.
정리하면, 결국 같은 문제에 대해서 어찌보면 다른 풀이를 여러 개 알고 있는 것이니,
물론 이것에 대한 이점도 있겠지만. 우리는 바쁜 대학생이므로 효율이 떨어지는 것이 사실이다.
예전에 종만북에서 공부했던 것을 복습하는 기념으로, 그리고 문제에서 구간 $[0..i]$를 묻는 쿼리에 대해서도 나와서 이 글을 작성한다.
본론
정의
펜윅트리 내부 배열 $A[i]$는 보통 원래 값 배열 $\mathrm{arr}[1\ldots n]$에 대해 다음과 같이 정의된다. (1-based 인덱싱).
\[A[i] = \sum_{j = i - (i \mathbin{\&} -i) + 1}^{\,i} \mathrm{arr}[j]\]정말 복잡해보이는 수식인데 말로 쉽게 설명해보자.
$i = a \times 2^r$, a는 2의 배수가 아님 이라고 정의하자.
$A[i]$란, index i를 포함한 바로 이전 $2^r$개의 원소의 합으로 정의하는 것이다.
비트 연산
BIT를 구현하기 위해서는, 결국 2진수의 가장 오른쪽 비트를 구하는 방법을 알아야한다.
단순히, 반복문을 통해서 구할 수 있겠지만, 가장 깔끔한 방법은 2의 보수법을 이용하는 것이다.
$a \mathbin{\&} -a$를 하면, i의 가장 오른쪽 비트만 1인 이진수를 구할 수 있다.
이를 이용하여 BIT를 구현한다.
공간복잡도
우선 공간복잡도부터 알아보자.
세그트리는 일종의 이진트리 구조를 사용하여 $4n$ 정도의 배열을 확보해야 하지만,
앞서 정의에 따라 BIT는 정확히 $n + 1$ 크기의 배열만 확보하면 된다.
update
만약, 어떤 idx에 k를 더하는 연산을 한다고 해보자.
예제를 이용해 설명한다.
$110012 = 25{10}$ idx를 update한다고 했을 때,
앞선 $arr$의 정의에 따라 $11010_2, 11100_2, 100000_2, 1000000_2…$를 update해야한다.
i += (i & -i) 이 코드를 이용해 구현할 수 있다.
query
이번에는 어떤 구간 $[0..i]$의 합을 알고 싶다고 하자.
위와 같이 i의 이진수 비트가 1인 정수의 $arr$ 합을 전부 구하면 된다.
i -= (i & -i) 이 코드를 이용해 구현할 수 있다.
findK
여기서는 응용이다.
만약 구간 $query(i) >= k$인 최소의 i를 찾고 싶다고 하자.
물론, 이분 탐색으로 이를 구할 수도 있겠지만, BIT의 정의에 따라 findK() 함수를 작성할 수 있다.
$n$의 가장 왼쪽 비트부터 맨 오른쪽 끝 비트까지 탐색을 해보는 것인데.
우선 $query(i) < k$를 만족하는 최대의 $i$를 구하는 문제로 바꿔보자.
당연하게도, $ < k $ 조건을 만족 할 수 있게 최대의 합을 구해주면 된다.
앞선 $arr$의 성질에 따라, $arr[2^r]$은 그 구간의 전체 합이다.
따라서 $< k$ 조건을 만족하며 $r$을 줄여보면 된다.
여기서, 얻은 idx에 $+ 1$을 더해주면 우리가 원하는 lower_bound 형태의 함수가 된다.
이를 구현하기 위해서, n의 가장 왼쪽 비트만 1인 수를 구하는 코드를 짜야하는데,
1 << (31 - __builtin_clz(n))이면 이를 구현할 수 있다.
의사 코드
위 내용을 구현한 c++ 코드이다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
struct BIT {
int n;
vector<int> a;
BIT(int n) : n(n), a(n + 1, 0) {}
void add(int i, int k) {
for (; i <= n; i += (i & -i))
a[i] += k;
}
int sum(int i) {
int ret = 0;
for (; i; i -= (i & -i))
ret += a[i];
return ret;
}
int findK(int k) {
int ret = 0;
int bitMask = 1 << (31 - __builtin_clz(n));
while (bitMask > 0) {
int next = ret + bitMask;
if (next <= n && a[next] < k) {
k -= a[next];
ret = next;
}
bitMask >>= 1;
}
return ret + 1;
}
};