MST(최소비용 신장트리) 기초개념 및 구현실습

2019-05-14

.

그림, 실습코드 등 학습자료 출처 : https://github.com/ythwork

1. spanning tree (신장트리)

  • 그래프에서 모든 정점이 연결되어 있지만 사이클이 존재하지 않는 그래프를 말한다. 다시말해서 모든 정점을 모두 포함하는 트리를 신장 트리라고 한다.

  • 트리의 특수한 형태로 모든 정점들이 연결되어 있어야 하고 사이클이 있어서는 안된다, n개의 정점을 정확히 (n-1)개의 간선으로 연결된다.

  • 트리의 정점이 파생되어 나온 그래프와 같다

  • 그래프 관련 기초용어

1) 에지의 수 : 정점의 수 - 1

2) 무방향

u — v

(u,v) = (v,u)

3) 방향

u —> v

<u,v> != <v,u>

4) 비용이란 엣지를 이동할때의 가중치들의 합을 말한다.

2. minimum cost spanning tree(최소비용 신장트리)

1) 무방향 그래프이다.

2) 가중치 그래프이다.

3) 가중치 cost가 최소가 되야한다.

3. minimum cost spanning tree(최소비용 신장트리)를 구현하기 위한 알고리즘

  • 최소 비용신장트리를 구현하기 위한 알고리즘 = 탐욕 알고리즘

1) kruskal 알고리즘

2) prim 알고리즘

3) sollin 알고리즘

  • 탐욕 알고리즘이란

큰 문제에 대한 최적해를 찾기 위해서 그때그때마다 그 지역별로 최적해를 선택하면 그 지역해를 다 더했을때 전역 최적해에 근접한다는 이론이다.

지역적 최적 선택의 모음이므로 전역적인 최적해라는 보장이 없다.

  • 탐욕 알고리즘의 실패 예시

아래그림에서 최대 합을 구하는 문제

1

4. 탐욕 알고리즘이 잘 먹히기 위한 조건

1) greedy choice property

지역최적해를 선택해 나가면 전역 최적해에 도달 할 수 있다.

locally optimal choices lead to globally optimal choice

2) optimal substructure

전역적인 문제가 부분으로 쪼개도 그 부분에 대해서도 최적해를 포함한다.

문제에 대한 최적해가 부분 문제에 대한 최적해를 포함할 때

tip) 실제 현업에서 쓸때는 choice property를 만족한다는 것을 보장해야 탐욕알고리즘을 쓸 수 있다.

5. MST가 어떻게 choice property를 만족시키는가

  • 단순화가정

1) edge 가중치는 서로 다르다. 왜냐하면 에지가 같은 경우에 MST가 여러개 나오는 경우가 생길 수 있기 때문이다.

2) 그래프는 연결되어 있다.

3) 그래서 MST의 존재하며 유일하다는 가정

6. Cut 이란

2

컷 지점을 지나가는 것을 크로싱엣지라고 한다.

  • cut property

3

각 스테이지에서 가중치가 가장 작은것을 선택한다는 것은 컷 프로퍼티와 초이스 프로퍼티다.

(0,2)는 반드시 E(T)에 포함될 것이다.

7. Greedy MST algorithm

선택된 edge가 crossing edge가 되지 않는 cut을 찾은 다음에 crossing edge 중 가중치가 가장 작은 edge를 찾아 선택한다.

step 1) 0번부터 출발하는데 사용자 임의로 컷을 나눈다.

step 2) 컷을 나누고 다음 스테이지로 넘어가는데 가장 작은 2를 선택할 것이다.

step 3) 또 다시 컷을 임의로 설정하고 크로싱 에지중 가장 작은 것을 고른다.

이런식으로 진행한다.

4

결론적으로 위의 그림에서 9번 step과 같이 사이클이 없고, 코스트가 최소가 되는 루트를 찾았다.

8. MST 만드는 방법

1) connected로 만듬

2) 반드시 트리여야하고

3) 모든 엣지들의 가중치를 합쳤을때 반드시 최소가 되야 한다.

  • 그렇다면 크루스칼과 프림은 무슨차이일까

1) 어떻게 컷을 선택할 것인가.

2) 어떻게 가중치가 가장 작은 에지를 찾은것일까.

에 따라서 이 두 알고리즘이 구분이 된다.

9. Disjoint set(분리집합)

트리의 일종으로 우리는 파이썬으로 집합을 구현하기 위해 set이 아니라 이 분리 집합을 이용할 것이다.

포레스트가 분리집합이라고 할 수 있다.

i 가 vertex(정점)이고

parent라는 배열이 있을때 parent[i]가 부모라는 의미다.

루트는 음수로 표현했다.

사이클이 생기는지 안생기는지 확인할때 find를 쓰고

사이클이 없으면 union으로 합칠것이다.

10

  • 파이썬 코드구현 실습
class DisjointSet:
    def __init__(self, vnum):
        self.parent=[-1 for _ in range(vnum)]

    def simple_find(self, i):
        while self.parent[i] >= 0:
            i=self.parent[i]
        return i

    def simple_union(self, i, j):
        self.parent[i]=j

    def collapsing_find(self, i):
        root=trail=lead=None
        #find the root
        root=i
        while self.parent[root] >= 0:
            root=self.parent[root]

        #make all nodes to children of the root
        trail=i
        while trail != root:
            lead=self.parent[trail]
            self.parent[trail]=root
            trail=lead

        return root

    def weighted_union(self, i, j):
        """
        paremeters i, j must be roots!
        if size[i] < size[j] then parent[i]=j
        """
        #abs(parent[i])=size[i]
        #temp_cnt is negative and the sum of size[i], size[j]
        temp_cnt=self.parent[i]+self.parent[j]

        #size[i] < size[j] : consider signs!!
        if self.parent[i] > self.parent[j]:
            self.parent[i]=j
            self.parent[j]=temp_cnt
        #size[i] > size[j]
        else:
            self.parent[j]=i
            self.parent[i]=temp_cnt

if __name__=="__main__":
    ds=DisjointSet(5)

    # ds.simple_union(4, 2)
    # ds.simple_union(0, 4)
    # ds.simple_union(1, 0)
    # ds.simple_union(3, 1)
    # print(ds.parent)

    # print(ds.simple_find(4), ds.simple_find(0),
    #     ds.simple_find(1), ds.simple_find(3))

    # print(ds.collapsing_find(3))
    # print(ds.parent)

    ds.simple_union(4, 2)
    ds.simple_union(0, 4)
    ds.simple_union(3, 1)
    ds.parent[2]=-3
    ds.parent[1]=-2

    ds.weighted_union(2, 1)
    print(ds.parent)
[4, 2, -5, 1, 2]

10. kruskal 알고리즘

step 1) edge를 가중치가 작은 것에서 큰 것 순으로 정렬

5

step 2) 트리에 에지를 하나씩 추가

6

같은 집합이 아니면 사이클이 아니라는 것을 알 수 있다.

7

step 3) 사이클이 생기면 아래 그림과 같이 추가하지 않는다.

8

step 4) 최소비용 신장 트리가 완성되면 E = V -1

9

  • 파이썬 코드구현 실습
class GNode:
    def __init__(self, vertex=None, weight=None):
        self.vertex=vertex
        self.weight=weight
        self.link=None

class Edge:
    def __init__(self, v1, v2, weight):
        self.v1=v1
        self.v2=v2
        self.weight=weight

class Graph:
    def __init__(self):
        self.adjacency_list=[]
        self.edge_list=[]

        self.vertex_num=0

    def add_vertex(self, vnum=1):
        self.adjacency_list.extend([None for _ in range(vnum)])
        self.vertex_num+=vnum
    
    def __add_node(self, vertex, node):
        cur=self.adjacency_list[vertex]
        if not cur:
            self.adjacency_list[vertex]=node
        else:
            while cur.link:
                cur=cur.link
            cur.link=node

    def insert_edge(self, u, v, weight):
        unode=GNode(u, weight)
        vnode=GNode(v, weight)

        self.__add_node(u, vnode)
        self.__add_node(v, unode)

        self.edge_list.append(Edge(u, v, weight))

    def MST_kruskal(self):
        #최종적으로 만들어질 MST
        mst=Graph(); mst.add_vertex(self.vertex_num)        
        #분리집합 : 사이클 형성 검사를 할 정점 집합
        ds=DisjointSet(self.vertex_num)
        #가중치에 따라 에지를 정렬
        self.edge_list.sort(key=lambda e: e.weight)
        #mst에 속하는 에지의 수
        mst_edge_num=0
        #정렬된 에지 리스트에서 인덱스
        edge_idx=0
        
        #|TE| = |TV|-1이면 종료
        while mst_edge_num < self.vertex_num-1:
            #가중치가 작은 순서대로 에지를 가져온다
            edge=self.edge_list[edge_idx]
            
            #FIND(u) != FIND(v)이면 사이클을 형성하지 않는다
            # 즉 사이클을 형성하지 않는다면
            if ds.collapsing_find(edge.v1) != ds.collapsing_find(edge.v2):
                #TE=TE U {(u, v)}
                mst.insert_edge(edge.v1, edge.v2, edge.weight)
                #UNION(u, v)
                ds.weighted_union(ds.collapsing_find(edge.v1), ds.collapsing_find(edge.v2))
                mst_edge_num+=1
            edge_idx+=1

        return mst

    def print_edges(self):
        for edge in self.edge_list:
            print("({}, {}) : {}".format(edge.v1, edge.v2, edge.weight))
            
g=Graph()
g.add_vertex(6)

g.insert_edge(0, 1, 10)
g.insert_edge(0, 2, 2)
g.insert_edge(0, 3, 8)
g.insert_edge(1, 2, 5)
g.insert_edge(1, 4, 12)
g.insert_edge(2, 3, 7)
g.insert_edge(2, 4, 17)
g.insert_edge(3, 4, 4)
g.insert_edge(3, 5, 14)

mst=g.MST_kruskal()
mst.print_edges()
(0, 2) : 2
(3, 4) : 4
(1, 2) : 5
(2, 3) : 7
(3, 5) : 14

11. prim 알고리즘

1) 정점 하나를 가진 트리에서 시작한다.

TV = {v1} 여기서 TV는 트리의 정점

2) 트리 내의 정점 u와 트리 밖의 정점 v를 잇는 에지 중 최소 비용을 가진 (u,v)를 트리 에지로 만든다. TE = TE{(u,v)}

3) 트리 밖의 정점 v도 트리의 정점으로 만든다.

TV = TVu{v}

4) TV = V(G)와 같아지면 종료한다.

트리 내의 정점u와 트리 밖의 정점 v를 잇는 에지를 선택하므로 사이클은 형성되지 않는다.

11

12

  • prim 알고리즘 수도코드

13

  • Min heap

14

15

  • min heap을 이용한 prim 알고리즘 수도코드

16

  • 파이썬 코드구현 실습

먼저 Prim 구현을 위한 min heap 구현실습

class Element:
    def __init__(self, v, w, _from):
        self.w=w
        self.v=v
        self._from=_from

class MinHeap:
    MAX_ELEMENTS=200
    def __init__(self):
        self.arr=[None for i in range(self.MAX_ELEMENTS)]
        self.heapsize=0
        #정점이 arr에 위치한 현재 인덱스
        self.pos=[None for i in range(self.MAX_ELEMENTS)]

    def is_empty(self):
        if self.heapsize==0:
            return True
        return False

    def is_full(self):
        if self.heapsize>=self.MAX_ELEMENTS:
            return True
        return False

    def __get_parent_idx(self, idx):
        return idx // 2

    def __get_left_child_idx(self, idx):
        return idx * 2

    def __get_right_child_idx(self, idx):
        return idx * 2 + 1

    def push(self, item):
        if self.is_full():
            raise IndexError("the heap is full!!")

        self.heapsize+=1
        cur_idx=self.heapsize

        #cur_idx가 루트가 아니고
        #item의 w가 cur_idx 부모의 키보다 작으면
        while cur_idx!=1 and item.w < self.arr[self.__get_parent_idx(cur_idx)].w:
            #리프 노드로 추가된 새로운 원소의 weight가 부모의 원소의 weight보다 
            #더 작으면 부모 원소를 한 칸 내린다 
            self.arr[cur_idx]=self.arr[self.__get_parent_idx(cur_idx)]
            #아래로 내려오는 부모 원소의 위치 인덱스 업데이트
            self.pos[self.arr[cur_idx].v]=cur_idx

            cur_idx=self.__get_parent_idx(cur_idx)

        self.arr[cur_idx]=item
        self.pos[item.v]=cur_idx

    def __get_smaller_child_idx(self, idx):
        if self.heapsize < self.__get_left_child_idx(idx):
            return None
        elif self.heapsize==self.__get_left_child_idx(idx):
            return self.__get_left_child_idx(idx)
        else:
            left_child=self.__get_left_child_idx(idx)
            right_child=self.__get_right_child_idx(idx)
            if self.arr[left_child].w < self.arr[right_child].w:
                return left_child
            else:
                return right_child

    def pop(self):
        if self.is_empty():
            return None

        #삭제된 후 반환될 원소
        rem_elem=self.arr[1]

        #맨 마지막에 위치한 원소
        temp=self.arr[self.heapsize]

        #루트에서 시작
        cur_idx=1
        smaller_child_idx=self.__get_smaller_child_idx(cur_idx)

        while smaller_child_idx and temp.w > self.arr[smaller_child_idx].w:
            #마지막 원소보다 weight가 큰 정점은 루트쪽으로 한칸 올라간다
            self.arr[cur_idx]=self.arr[smaller_child_idx]
            
            #이와 함께 루트쪽으로 올라간 정점의 현재 인덱스도 업데이트한다
            self.pos[self.arr[cur_idx].v]=cur_idx

            cur_idx=smaller_child_idx
            smaller_child_idx=self.__get_smaller_child_idx(cur_idx)
        
        self.arr[cur_idx]=temp
        self.pos[temp.v]=cur_idx

        self.heapsize-=1

        return rem_elem

    def top(self):
        if self.is_empty():
            return None

        return self.arr[1]

    #프림 알고리즘을 위해 추가된 함수
    def decrease_weight(self, new_elem):
        #업데이트될 정점의 현재 인덱스
        cur=self.pos[new_elem.v]

        #cur가 루트가 아니고 업데이트 될 원소의 weight가
        #부모 원소의 weight보다 작다면 부모 원소를 아래로 내리고
        #cur가 루트 쪽으로 올라간다
        while cur!= 1 and new_elem.w < self.arr[self.__get_parent_idx(cur)].w:
            #업데이트 될 원소의 weight가 부모 원소의 weight보다 작다면
            #부모 원소를 한 칸 아래로 내린다 
            self.arr[cur]=self.arr[self.__get_parent_idx(cur)]

            #내려온 원소의 위치 인덱스 업데이트
            self.pos[self.arr[cur].v]=cur    

            cur=self.__get_parent_idx(cur)

        self.arr[cur]=new_elem
        self.pos[new_elem.v]=cur

def print_heap(h):
    for i in range(1, h.heapsize+1):
        print("{}".format(h.arr[i].w), end="  ")
    print()

min heap 구현후 prim 알고리즘 구현실습

import math

class GNode:
    def __init__(self, vertex=None, weight=None):
        self.vertex=vertex
        self.weight=weight
        self.link=None

class Edge:
    def __init__(self, v1, v2, weight):
        self.v1=v1
        self.v2=v2
        self.weight=weight

class Graph:
    def __init__(self):
        self.adjacency_list=[]
        self.edge_list=[]

        self.vertex_num=0

    def add_vertex(self, vnum=1):
        self.adjacency_list.extend([None for _ in range(vnum)])
        self.vertex_num+=vnum
    
    def __add_node(self, vertex, node):
        cur=self.adjacency_list[vertex]
        if not cur:
            self.adjacency_list[vertex]=node
        else:
            while cur.link:
                cur=cur.link
            cur.link=node

    def insert_edge(self, u, v, weight):
        unode=GNode(u, weight)
        vnode=GNode(v, weight)

        self.__add_node(u, vnode)
        self.__add_node(v, unode)

        self.edge_list.append(Edge(u, v, weight))

    def get_min_v(self, w):
        _min=math.inf
        min_v=None
        for i in range(self.vertex_num):
            if w[i] < _min:
                _min=w[i]
                min_v=i
        return min_v

    def MST_prim1(self):
        #TE={}
        mst=Graph(); mst.add_vertex(self.vertex_num)
        #TV={}
        TV=set()

        w=[math.inf for _ in range(self.vertex_num)]
        _from=[None for _ in range(self.vertex_num)]

        w[0]=0
        #|TV| < |V|
        while len(TV) < self.vertex_num:
            v=self.get_min_v(w) ## 모든 버텍스를 하나하나 돌면서 작은거를 뽑아줌
            #TV <- TV U {v}
            TV.add(v)
            #TE <- TE U {(v, from[v])}
            if _from[v]!=None:
                mst.insert_edge(v, _from[v], w[v])
            #trick
            w[v]=math.inf
            #u adjacent to v
            u=self.adjacency_list[v]
            while u:
                if u.vertex not in TV and u.weight < w[u.vertex]:
                    w[u.vertex]=u.weight
                    _from[u.vertex]=v
                u=u.link
        return mst

    def MST_prim2(self):
        #최종적으로 만들어질 MST
        mst=Graph(); mst.add_vertex(self.vertex_num)
        #TV={} : MST 정점의 집합, 시작 노드부터 하나씩 채워나간다
        TV=set()

        #w_list : 각 정점의 w 값을 담아두기 위한 배열
        w_list=[None for _ in range(self.vertex_num)]
        #min heap에 w와 from을 가진 정점을 담아둔다
        #heap 초기화 : w->inf, from->None
        h=MinHeap()
        for i in range(1, self.vertex_num):
            w_list[i]=math.inf
            h.push(Element(i, math.inf, None))
        #시작 노드인 0은 w->0, from->None
        w_list[0]=0
        h.push(Element(0, 0, None))

        while not h.is_empty():
            #가중치가 가장 작은 에지 (from, v) : w
            #정보를 가진 정점 Element v
            v=h.pop()
            #TV에 정점을 추가
            TV.add(v.v)
            #TE에 에지 추가
            if v._from != None:
                mst.insert_edge(v.v, v._from, v.w)
            
            #TV에 정점이 추가되면 인접 정점 중 
            #트리 밖에 있는 정점에 대해 업데이트 시도
            #u는 새로 추가된 정점 v에 인접한 정점 노드
            u=self.adjacency_list[v.v]
            while u:
                #u가 트리 밖의 정점이고
                #기존 w 값보다 w(u, v)이 작다면 업데이트
                if u.vertex not in TV and u.weight < w_list[u.vertex]:
                    #w_list 업데이트
                    w_list[u.vertex]=u.weight
                    h.decrease_weight(Element(u.vertex, u.weight, v.v))
                u=u.link

        return mst

    def print_edges(self):
        for edge in self.edge_list:
            print("({}, {}) : {}".format(edge.v1, edge.v2, edge.weight))

if __name__=="__main__":
    g=Graph()
    g.add_vertex(6)

    g.insert_edge(0, 1, 10)
    g.insert_edge(0, 2, 2)
    g.insert_edge(0, 3, 8)
    g.insert_edge(1, 2, 5)
    g.insert_edge(1, 4, 12)
    g.insert_edge(2, 3, 7)
    g.insert_edge(2, 4, 17)
    g.insert_edge(3, 4, 4)
    g.insert_edge(3, 5, 14)

    #mst=g.MST_prim1()
    mst=g.MST_prim2()

    mst.print_edges()
(2, 0) : 2
(1, 2) : 5
(3, 2) : 7
(4, 3) : 4
(5, 3) : 14