Home Deep Graph Library (Pytorch)
Post
Cancel

Deep Graph Library (Pytorch)

지난 이야기: DGL을 활용해서 Graph Attention Network 구현해보기


Graph Attention Network를 구현하며 DGL (Deep Graph Library)를 사용해보았다. DGL에서 제공하는 API를 여러 부분에서 활용했었는데, 내부적으로 이것들이 어떻게 돌아가는지가 궁금했다. 그래서 오늘은 텐서를 흘려보내고 엣지 단위로 모델링을 수행하는 dgl의 기능들을 뜯어보려 한다.

Deep Graph Library API

DGL이 워낙 거대한 프로젝트라 오늘 확인해 볼 부분은 사실 아주아주 사소한 부분이다. API Reference 여기에 Apply / Batch / Sampler / Graph Store 등에 사용되는 함수들이 모두 모여있다. 전체적인 그림을 보기에는 지식의 깊이가 너무 얕은 이유로 우선 지난 번 사용해봤던 함수들을 살펴보자.

여기를 보면 Attention Layer를 포워드 시킬 때, (1) apply_edges / (2) update_all 이라는 함수들이 사용된다. 세세한 argument 들은 모두 무시하고 두 함수가 어떻게 작동하는지 이해해보자.


apply_edges

직관적으로 위의 함수의 경우에는 edge_attention 이라는 함수가 모든 엣지에 적용된다! 의 의미를 가질 것 같은데, 구체적으로 어떤 방식으로 적용되는 것일까?

1
2
3
4
g = dgl.DGLGraph()
g.add_nodes(3)   # [0,1,2] 3개의 노드 생성
g.add_edges([0, 1], [1, 2])   # [0 -> 1, 1 -> 2]  2개의 엣지 생성
g.edata['y'] = torch.ones(2, 1)   # 각 엣지에 값이 1인 y라는 피쳐 할당

g 그래프의 엣지가 가지고 있는 피쳐들


1
2
def double_feature(edges): return {'y': edges.data['y'] * 2}
g.apply_edges(func=double_feature, edges=1) # index 1의 엣지에 위의 함수를 적용해준다

1번 엣지의 값이 2로 증가했다


만약 기존에 존재하던 y 피쳐를 리턴하지 말고 x라는 것을 리턴했을 때는 어떻게 될까?

만약 edges = 1 이라고 index를 지정하지 않고 그냥 함수를 적용하면 어떻게 될까?

1
2
def new_double_feature(edges): return {'x': edges.data['y'] * 2}
g.apply_edges(func=new_double_feature)

기존 y피쳐는 그대로 있고 엣지에 x라는 피쳐가 새로 생성됬다


좋다. 그렇다면 지난 번 GAT에 적용했던 코드를 복습하여 확실히 이해해보자.

1
2
3
4
5
6
def edge_attention(self, edges):
    z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
    a = self.attn_fc(z2)
    return {'e': F.leaky_relu(a)}

self.g.apply_edges(self.edge_attention)

각 edge에는 src와 dst에 대한 클래스가 있고 그 안에 z라는 피쳐가 존재한다.

각 클래스의 z를 concat 한 후 attn_fc를 적용하고 relu 취한 결과물을 e라는 새로운 피쳐로 할당한다.

따라서 위의 함수를 실행한 후 g.edata[‘e’]를 수행하면 방금 할당된 결과물을 볼 수 있을 것이다.


update_all

위의 함수는 Reference에도 따로 예제가 없어서 그냥 그 흐름을 이해하는 것이 좋을 것 같다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def update_all(self,
               message_func="default",
               reduce_func="default",
               apply_node_func="default"):
        
	if message_func == "default":
    	message_func = self._message_func
	if reduce_func == "default":
    	reduce_func = self._reduce_func
	if apply_node_func == "default":
    	apply_node_func = self._apply_node_func
	assert message_func is not None
    assert reduce_func is not None

    with ir.prog() as prog:
    	scheduler.schedule_update_all(graph=AdaptedDGLGraph(self),
                                          message_func=message_func,
                                          reduce_func=reduce_func,
                                          apply_func=apply_node_func)
		Runtime.run(prog)

오! 흐름을 이해했다고 생각하고 Source 코드를 보니 다시 미궁에 빠졌다. ir.prog()가 의마하는 것은 런타임을 확인하기 위한 작업인 것 같고 scheduler 모듈을 통해 뭔가 해당 작업들을 병렬로 처리할 수 있게 해주나 보다. 물론 나의 추측일 뿐, with 아래의 코드가 어떻게 생겨먹은지는 잘 모르겠다. 패스!

This post is licensed under CC BY 4.0 by the author.

Graph Attention Networks (Pytorch)

grpah2vec (MLG 2017)