본문 바로가기

알고리즘

펜윅트리로 최소값구하기

펜윅트리란 어떤 배열의 특정구간 구간합과같은 쿼리가 많을경우 

별도의 특별한 트리구조를 만들어 이를통해 시간복잡도문제를 해결하기 위한 자료구조입니다.


펜윅트리말고도 이와같은 트리로 세그먼트트리도 있으나 

개인적으로 펜윅트리가 구현하기 깔끔하고 필요한 트리를 만들기위한 메모리를 절약할수 있으며 

속도도 왠만하면 더 빠른기대값을 얻을수 있습니다.


하지만 펜윅트리로는 최소값, 최대값같은 구간값을 일반적으로 구할수가 없습니다. 


구간합 같은경우 A - B 구간의 합을 구할때 1-B까지의 구간합 - 1-A까지의 구간합을 통해 계산하며 되지만 최대값 최소값같은경우 

펜윅트리의 구조상 생략되는 값이 존재하기때문입니다. 


펜윅트리는 어떤수를 2진수로 표현할때 1의 위치, 개수에따라 값을 구하는 방식으로


만들어지는 방식을 살펴보면 1부터4까지의 구간합을 구할떄 4를 2진수로하면 100인데 이뜻은 100 = 001 + 010 + 011 구간의 특정값을 저장한다는 뜻이다.

즉 1이 하나일경우 100은 001 ~ 011 구간의 특정값 1000은 0001 ~ 0111 구간의 특정값 을 저장한다는 뜻이다


이는 이러한값을 저장하는것은 2의보수연산AND연산을 통해 구할수 있습니다. while (i <= n)  i += (i & -i)


이러한 구조적특성을 통해 1부터5까지 구간의 특정값을 구한다면 5[101]는 트리의 101 + 100 구간을 조회하면 2번만에 구할수있습니다.


15[1111]는 1111 + 1110 + 1100 + 1000


이 연산작업또한 2의보수와 AND연산으로 구할수있씁니다. while (i > 0) i -= (i & -i)


하지만 최소값, 최대값을 구하는경우 예를들어 3부터 9구간의 최소값을 구한다면 9[1001] 는  min(tree[1000], tree[1001]) 인데 tree[1000]은 1부터 8까지의 

최소값이므로 필요없는 1~3구간이 포함된상태의 최소값이므로 값을 구할수가없습니다. 


하지만 펜윅트리를 2개 사용한다면 가능합니다 트리의 구조를 루트값이 0인구조, n인구조로 2개를 만들면

1000 100 10 같은 형태가 반대방향인 10 100 1000 같은 구조를 하나더 만들수있습니다

이 둘의 구조를 탐색하면 됩니다


이렇게되면 세그먼트트리를 구현하는게 더간단하게되지만 시간복잡도를 줄일수 있습니다.


다음은 백준사이트의 10868 구간 최소값을 구하는 문제를 두가지 형태로 풀어보고 각각 걸리는시간을 확인할수 있습니다.


코드는 다음과 같습니다. [Python]

# fenwick
import sys
read = sys.stdin.readline
MAX = 1000000001


def update(i, x):
    while i <= n:
        tree[i] = min(tree[i], x)
        i += (i & -i)


def update2(i, x):
    while i > 0:
        tree2[i] = min(tree2[i], x)
        i -= (i & -i)


def query(a, b):
    v = MAX

    prev = a
    curr = prev + (prev & -prev)
    while curr <= b:
        v = min(v, tree2[prev])
        prev = curr
        curr = prev + (prev & -prev)

    v = min(v, arr[prev])

    prev = b
    curr = prev - (prev & -prev)
    while curr >= a:
        v = min(v, tree[prev])
        prev = curr
        curr = prev - (prev & -prev)

    return v


n, m = map(int, read().split())
arr = [0] * (n+1)
tree = [MAX] * (n+2)
tree2 = [MAX] * (n+2)

for i in range(1, n+1):
    arr[i] = int(read())
    update(i, arr[i])
    update2(i, arr[i])

for i in range(m):
    a, b = map(int, read().split())
    print(query(a, b))


'알고리즘' 카테고리의 다른 글

트리의 최대 지름  (0) 2017.11.21
LIS 최장 증가 수열  (0) 2017.11.15
알고리즘 기본지식  (0) 2016.11.24
소수 구하기  (0) 2016.11.19
멱집합 구하기  (0) 2016.11.17