
[CS224W] NetworkX, PyTorch Geometric Tutorial

by isdawell 2022. 10. 7.

 

 

1️⃣  NetworkX Tutorial


 

🔹 NetworkX

 

https://networkx.org/

 

NetworkX — NetworkX documentation

NetworkX is a Python package for the creation, manipulation, and study of the structure, dynamics, and functions of complex networks.

networkx.org

 

•  Software for complex networks 

•  Structured and unstructured data can be loaded into a network.

•  Many different types of networks can be created.

•  Analyze network structure, build network models, design new network algorithms, draw networks, and much more.

 

 

🌠 Library import

 

import networkx as nx 

 

 

🔹 NetworkX graph types

•  Four graph types

Class (constructor)   Type         Self-loops allowed   Parallel (multiple) edges allowed
Graph                 undirected   Yes                  No
DiGraph               directed     Yes                  No
MultiGraph            undirected   Yes                  Yes
MultiDiGraph          directed     Yes                  Yes
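To make the "parallel edges" column concrete, here is a minimal plain-Python sketch. The classes `SimpleGraphSketch` and `MultiGraphSketch` are hypothetical toy models written for this illustration; they are not NetworkX's actual implementation. They only mimic the behaviour in the table: a simple graph keeps one edge per node pair, while a multigraph keeps every edge it is given.

```python
from collections import defaultdict


class SimpleGraphSketch:
    """Toy undirected graph: at most one edge per node pair (hypothetical class)."""

    def __init__(self):
        self.adj = defaultdict(dict)  # adj[u][v] -> attribute dict of the edge

    def add_edge(self, u, v, **attrs):
        # Adding an existing (u, v) pair updates its attributes instead of
        # creating a second edge.
        data = self.adj[u].get(v, {})
        data.update(attrs)
        self.adj[u][v] = data
        self.adj[v][u] = data

    def number_of_edges(self):
        # Each undirected edge is stored under both endpoints, so count
        # unordered pairs. A self-loop (u, u) collapses to a one-element set.
        return len({frozenset((u, v)) for u in self.adj for v in self.adj[u]})


class MultiGraphSketch:
    """Toy undirected multigraph: parallel edges are kept (hypothetical class)."""

    def __init__(self):
        self.adj = defaultdict(lambda: defaultdict(list))  # adj[u][v] -> list of edges

    def add_edge(self, u, v, **attrs):
        edges = self.adj[u][v]
        edges.append(dict(attrs))
        self.adj[v][u] = edges  # both directions share the same list

    def number_of_edges(self):
        pairs = {frozenset((u, v)): len(edges)
                 for u in self.adj for v, edges in self.adj[u].items()}
        return sum(pairs.values())


simple, multi = SimpleGraphSketch(), MultiGraphSketch()
for g in (simple, multi):
    g.add_edge(0, 1, weight=0.5)
    g.add_edge(0, 1, weight=0.9)  # same node pair again
    g.add_edge(2, 2)              # self-loop, allowed in both

print(simple.number_of_edges(), multi.number_of_edges())
```

In the simple graph the second `add_edge(0, 1, ...)` call overwrites the weight, whereas the multigraph stores both edges side by side.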

 

 

 

🔹 Graph

•  NetworkX provides several classes for storing different kinds of graphs, such as undirected and directed graphs. It also provides classes for creating multigraphs.

# Create an undirected graph
G = nx.Graph()
print(G.is_directed())

# Create a directed graph
H = nx.DiGraph()
print(H.is_directed())

# Add a graph-level attribute
G.graph['Name'] = 'bar'

print(G.graph)

 

 

 

🔹 Node

•  NetworkX makes it easy to add nodes to a graph : add_node, add_nodes_from

•  add_node : add a single node with attributes

# Add one node (with attributes)

G.add_node(0, feature = 5, label = 0)  # node ID, then its attributes (feature, label)

# Print the attributes of node 0

node_0_attr = G.nodes[0]
print('node 0 has the attributes {}'.format(node_0_attr))

 

 

•  add_nodes_from : add several nodes with attributes

# ⭐ Add several nodes with attributes

G.add_nodes_from([
    (1, {'feature' : 1, 'label' : 1}),
    (2, {'feature' : 2, 'label' : 2})
]) # each entry is (node, attribute dictionary)

# ⭐ Print every node with its attributes : access them through .nodes

for node in G.nodes(data=True) :
  print(node)  # data=True also returns each node's attributes

# ⭐ Print the number of nodes : .number_of_nodes()

num_nodes = G.number_of_nodes()
print('G has {} nodes'.format(num_nodes))

 

 

🔹 Edge

•  NetworkX makes it just as easy to add edges to a graph

•  add_edge : add a single edge with attributes

# Add one edge with weight 0.5

G.add_edge(0, 1, weight = 0.5)

# Print the attributes of edge (0,1), i.e. the edge connecting node 0 and node 1

edge_0_1_attr = G.edges[(0,1)]
print('Edge (0,1) attributes {}'.format(edge_0_1_attr))

•  add_edges_from : add several edges with attributes

# ⭐ Add several weighted edges

G.add_edges_from([
    (1, 2, {'weight' : 0.3}),
    (2, 0, {'weight' : 0.1})
])

# ⭐ Print every edge in the graph

for edge in G.edges() :
  print(edge)

# ⭐ Print the number of edges
num_edges = G.number_of_edges()

print('G has {} edges'.format(num_edges))

 

 

 

🔹 Visualization

•  nx.draw(graph, with_labels = True)

nx.draw(G, with_labels = True) # also show the node labels

 

 

🔹 Degree, neighbor

•  Degree : the number of edges connected to a node → G.degree[n]

•  neighbor : the adjacent nodes → G.neighbors(n)

node_id = 1

# Check the degree (number of connected edges) of node 1 : G.degree[n]
print('node {} has degree {}'.format(node_id, G.degree[node_id]))

# Print the neighbors of node 1 : G.neighbors(n)
for neighbor in G.neighbors(node_id) :
  print('node {} has node {} as a neighbor'.format(node_id, neighbor))
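As a rough illustration of what degree and neighbors mean, the sketch below rebuilds the same three-node triangle (edges 0-1, 1-2, 2-0) as a plain adjacency dictionary, without NetworkX. The variable names are illustrative only.

```python
# The three undirected edges added to G earlier in this post.
edges = [(0, 1), (1, 2), (2, 0)]

# Build an adjacency dictionary: node -> set of neighboring nodes.
adjacency = {}
for u, v in edges:
    adjacency.setdefault(u, set()).add(v)
    adjacency.setdefault(v, set()).add(u)

# The degree of a node is the number of its neighbors (no self-loops here).
degree = {node: len(neighbors) for node, neighbors in adjacency.items()}

node_id = 1
print(degree[node_id], sorted(adjacency[node_id]))  # prints: 2 [0, 2]

# Every edge contributes to the degree of both endpoints,
# so the degrees sum to twice the number of edges.
assert sum(degree.values()) == 2 * len(edges)
```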

 

 

🔹 Other functions

•  NetworkX provides many other useful methods.

•  Example : PageRank

num_nodes = 4

# Create a new path graph and convert it to a directed graph

G = nx.DiGraph(nx.path_graph(num_nodes))
nx.draw(G, with_labels = True)

# Compute PageRank

pr = nx.pagerank(G, alpha = 0.8)
pr

 

 

 

 

※ PageRank : the algorithm underlying Google's search engine, which measures the importance of web pages using the hyperlinks between them. More generally, it can be used to measure the importance of nodes in any graph.
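To show what a PageRank score is, here is a minimal power-iteration sketch in plain Python. The function `pagerank_sketch` is a hypothetical helper written for this post, not NetworkX's implementation. It assumes the standard damped formulation, where `alpha` is the probability of following a link and nodes without out-links spread their rank uniformly. The input is the 4-node path graph from above, where every undirected edge becomes two directed edges.

```python
def pagerank_sketch(out_links, alpha=0.8, tol=1e-10, max_iter=1000):
    """Power iteration for PageRank on a dict: node -> list of out-neighbors."""
    nodes = list(out_links)
    n = len(nodes)
    rank = {v: 1.0 / n for v in nodes}  # start from the uniform distribution

    for _ in range(max_iter):
        # Teleportation: every node receives (1 - alpha) / n.
        new_rank = {v: (1.0 - alpha) / n for v in nodes}
        for u in nodes:
            targets = out_links[u]
            if targets:
                # Split u's rank evenly among the nodes it links to.
                share = alpha * rank[u] / len(targets)
                for v in targets:
                    new_rank[v] += share
            else:
                # Dangling node: spread its rank over all nodes.
                for v in nodes:
                    new_rank[v] += alpha * rank[u] / n
        delta = sum(abs(new_rank[v] - rank[v]) for v in nodes)
        rank = new_rank
        if delta < tol:
            break
    return rank


# Path 0 - 1 - 2 - 3 with each undirected edge stored in both directions.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
pr = pagerank_sketch(path, alpha=0.8)
print({node: round(score, 4) for node, score in pr.items()})
```

The two inner nodes receive rank from two neighbors each, so they end up with higher scores than the two end nodes, and the scores sum to 1.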

 

 

 

 

 

 

 

 

2️⃣  PyTorch Geometric Tutorial


 

🔹 PyTorch Geometric

↪  A simple graph-structured example : Zachary's karate club network

↪  A social network of the relationships among the 34 members of a karate club

↪ Problem definition : the focus is on detecting the communities that arise from the members' interactions (a falling-out between A and B split the club in two)

 

 

🔹  Loading the dataset

# Load the dataset

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset : {dataset} :')
print('------------------')
print(f'Number of graphs : {len(dataset)}')
print(f'Number of features : {dataset.num_features}')
print(f'Number of classes : {dataset.num_classes}')

 

 

↪ Looking at the properties of the karate club dataset: it consists of a single graph, and each node, which represents one of the 34 club members, is described by a 34-dimensional feature vector. The graph also has 4 classes, which represent the community each node belongs to.

 

 

🔹  Inspecting the graph's properties

# Inspect the graph

data = dataset[0]  # get the first graph object

print(data)
print('--------------------------------------------------')

print(f'Number of nodes : {data.num_nodes}')
print(f'Number of edges : {data.num_edges}')
# edge_index stores each undirected edge in both directions,
# so num_edges / num_nodes is already the average node degree
print(f'Average node degree : {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes : {data.train_mask.sum()}')
print(f'Training node label rate : {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes : {data.has_isolated_nodes()}')
print(f'Contains self-loops : {data.has_self_loops()}')
print(f'Is undirected : {data.is_undirected()}')

 

 

•  Edge information shown as an edge list

 

data.edge_index.T
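The relationship between an undirected edge list and the `[2, num_edges]` layout of `edge_index` can be sketched in plain Python. The toy graph below is hypothetical (it is not the karate club data), and plain lists stand in for tensors. The point is that each undirected edge appears once per direction, so the number of stored edges divided by the number of nodes is the average degree.

```python
# Hypothetical toy graph with 4 nodes and 4 undirected edges.
undirected_edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
num_nodes = 4

# Store every undirected edge in both directions.
sources, targets = [], []
for u, v in undirected_edges:
    sources += [u, v]
    targets += [v, u]

# Row 0 holds source nodes, row 1 holds destination nodes: shape [2, num_directed_edges].
edge_index = [sources, targets]

# Transposing gives one (source, destination) pair per row, like edge_index.T.
edge_list = list(zip(*edge_index))
print(edge_list[:4])

num_directed_edges = len(sources)
avg_degree = num_directed_edges / num_nodes
print(avg_degree)  # prints: 2.0
```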

 

 

 

•  Data structure

# data

print(data)

↪  This shows each attribute together with its shape.

1. edge_index : graph connectivity, i.e. which nodes are connected to which.

2. x : node features. The shape (34, 34) shows that each of the 34 nodes has a 34-dimensional feature vector.

3. y : node labels. The shape is 34, so each node is assigned to exactly one class.

4. train_mask : a mask marking the nodes whose community label is already known. Here we know the ground-truth label of only 4 nodes (see the earlier output); the labels of the remaining nodes have to be predicted through inference.
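How a boolean mask selects the labeled nodes can be sketched with plain lists. The labels and mask below are hypothetical values for an 8-node toy graph; in the actual dataset these are tensors and 4 of the 34 nodes are marked.

```python
# Hypothetical labels for 8 nodes, 4 classes.
labels = [0, 1, 2, 3, 0, 1, 2, 3]

# True marks a node whose label may be used during training.
train_mask = [True, True, True, True, False, False, False, False]

# Keep only the labels of the training nodes.
train_labels = [y for y, keep in zip(labels, train_mask) if keep]

num_train = sum(train_mask)            # True counts as 1
label_rate = num_train / len(labels)   # fraction of nodes with a known label

print(train_labels, num_train, label_rate)  # prints: [0, 1, 2, 3] 4 0.5
```

With 4 labeled nodes out of 34, the same computation gives a label rate of about 0.12 for the karate club graph.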

 

 

 

 

🔹  Edge Index , Viz

•  Let's print edge_index

↪  Each edge is stored as a [ source node , destination node ] pair; this is also called COO format (coordinate format).

↪ For visualization, the graph is converted to the networkx format.

from torch_geometric.utils import to_networkx

G = to_networkx(data, to_undirected = True)
visualize(G, color = data.y)  # visualize() is the helper function defined later in this post

 

 

 

 

 

 

🔹  Building a graph neural network : Graph convolutional network

•  GCN layer : GCNConv

•  What does the output of a GNN look like?

↪  Each node is represented by a feature vector Xi. Feeding these vectors into the network produces an embedding for each node, i.e. a node vector shaped to suit the downstream task. Once every node's features are mapped to a learned embedding, those embeddings can be used for regression or classification at the node, edge, or graph level.

👉 Here we tackle a node classification problem: assigning each node to its community class. (learn embedding)

 

 

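import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module) :

  def __init__(self) :
    super().__init__()
    torch.manual_seed(1234)  # fixed seed so the initial weights are reproducible

    # Stack three graph convolutional layers: num_features -> 4 -> 4 -> 2
    self.conv1 = GCNConv(dataset.num_features, 4)
    self.conv2 = GCNConv(4, 4)
    self.conv3 = GCNConv(4, 2)

    # Classifier on top of the 2-dimensional embedding
    self.classifier = Linear(2, dataset.num_classes)

  def forward(self, x, edge_index) :
    h = self.conv1(x, edge_index)
    h = h.tanh()
    h = self.conv2(h, edge_index)
    h = h.tanh()
    h = self.conv3(h, edge_index)
    h = h.tanh()  # final 2-dimensional node embedding

    out = self.classifier(h)  # per-class scores for each node

    return out, h

model = GCN()
print(model)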

 

•  In a GCN, each layer aggregates information from the 1-hop neighborhood, i.e. the directly connected neighbor nodes, to compute the embeddings. Stacking 3 layers therefore lets each node gather information from its 3-hop neighborhood.

•  The GCNConv layers reduce the node feature dimension from 34 to 2 (34 → 4 → 4 → 2). Each layer is followed by a hyperbolic tangent (tanh) activation function, which adds non-linearity.

•  The final linear layer (classifier) applies a linear transformation that maps each node's 2-dimensional embedding to scores for the 4 classes, so that every node can be assigned to one of them.
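The claim that 3 layers reach the 3-hop neighborhood can be checked with a small plain-Python sketch. The function `gcn_propagate` is a hypothetical helper. It applies only the symmetric-normalized neighbor aggregation with self-loops that GCN layers are built on, and it deliberately leaves out the learned weight matrix, bias, and tanh. The graph is a 4-node path, and only node 0 starts with a non-zero feature.

```python
import math


def gcn_propagate(features, edges, num_nodes):
    """One weight-free GCN-style step: h_i = sum_j x_j / sqrt(d_i * d_j),
    where j runs over the neighbors of i and i itself (self-loop)."""
    neighbors = {i: {i} for i in range(num_nodes)}  # start with self-loops
    for u, v in edges:
        neighbors[u].add(v)
        neighbors[v].add(u)
    deg = {i: len(nbrs) for i, nbrs in neighbors.items()}  # degree incl. self-loop
    return [
        sum(features[j] / math.sqrt(deg[i] * deg[j]) for j in neighbors[i])
        for i in range(num_nodes)
    ]


path_edges = [(0, 1), (1, 2), (2, 3)]  # path 0 - 1 - 2 - 3
h = [1.0, 0.0, 0.0, 0.0]               # only node 0 carries a signal

history = []
for layer in range(3):
    h = gcn_propagate(h, path_edges, num_nodes=4)
    history.append(h)
    print(f'after layer {layer + 1}:', [round(value, 3) for value in h])
```

Node 3 is three hops away from node 0, so its value stays exactly zero after one and two steps and only becomes non-zero after the third.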

 

 

🔹  Visualization before the weights are trained

%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt

# Visualization function for NX graph or PyTorch tensor
def visualize(h, color, epoch=None, loss=None, accuracy=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])

    if torch.is_tensor(h):
        # Tensor input: scatter-plot the 2-dimensional node embeddings
        h = h.detach().cpu().numpy()
        plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
        # accuracy is expected to be a dict with 'train' and 'val' keys
        if (epoch is not None and loss is not None and accuracy is not None
                and accuracy['train'] is not None and accuracy['val'] is not None):
            plt.xlabel((f'Epoch: {epoch}, Loss: {loss.item():.4f} \n'
                       f'Training Accuracy: {accuracy["train"]*100:.2f}% \n'
                       f' Validation Accuracy: {accuracy["val"]*100:.2f}%'),
                       fontsize=16)
    else:
        # 📌 Graph input: draw the networkx graph passed in as h
        nx.draw_networkx(h, pos=nx.spring_layout(h, seed=42), with_labels=False,
                         node_color=color, cmap="Set2")
    plt.show()

 

 

↪ nx.spring_layout : positions the nodes using the Fruchterman-Reingold force-directed algorithm

👀  Nodes in a network have no inherent coordinates, so it is not obvious where to draw them. The more nodes there are, the more the layout matters for showing the network's characteristics, which is why many algorithms have been proposed for placing the nodes of a complex network.

 

 

 

 

model = GCN()

_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')

visualize(h, color=data.y)

 

 

👀   Even before training (with only the initialized weights), nodes of the same color (community) already sit close together in the embedding space. This shows that GNNs introduce a strong inductive bias: nodes that are close to each other in the input graph get similar embeddings.

↪ Inductive bias : the additional assumptions a learning algorithm makes in order to predict outputs for inputs it has not been given, with the goal of improving generalization. It can be viewed as the set of assumptions that lets the algorithm reason inductively about unseen data. A CNN, for example, has a relational inductive bias of locality and translation invariance (an object in an image can still be recognized when its position changes): it assumes that entities close to each other are related.

 

 

✔ Reference : https://re-code-cord.tistory.com/entry/Inductive-Bias%EB%9E%80-%EB%AC%B4%EC%97%87%EC%9D%BC%EA%B9%8C

"What is Inductive Bias?" (re-code-cord.tistory.com)

 

 

 

🔹  Training

•  The loss decreased from 1.43 to 0.02.

model = GCN()
criterion = torch.nn.CrossEntropyLoss()  # loss function
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01) # optimizer

def train(data) :
  optimizer.zero_grad() # reset the gradients
  out, h = model(data.x, data.edge_index) # forward pass
  loss = criterion(out[data.train_mask], data.y[data.train_mask]) # loss on the training nodes only
  loss.backward() # backpropagation
  optimizer.step() # update the parameters
  return loss, h

for epoch in range(401) :
  loss, h = train(data)
  if epoch % 10 == 0 :
    print(loss)

 

※ predicted label : out (per-class scores for each node; the predicted class is out.argmax(dim=1))

※ ground-truth label : data.y
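How class scores turn into a loss and a predicted label can be sketched in plain Python. The helpers `cross_entropy` and `predict` and the score values are hypothetical illustrations; the tutorial itself relies on `torch.nn.CrossEntropyLoss`. One useful reference point: a 4-class model that scores all classes equally has a loss of ln 4 ≈ 1.386, which is in the same range as the initial loss of 1.43 reported above.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one node: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def predict(logits):
    """Predicted class: the index of the largest score (argmax)."""
    return max(range(len(logits)), key=lambda c: logits[c])


# Equal scores for 4 classes: the loss equals ln(4) regardless of the target.
uniform_scores = [0.0, 0.0, 0.0, 0.0]
print(round(cross_entropy(uniform_scores, 2), 3))  # prints: 1.386

# Hypothetical scores for three nodes and their ground-truth labels.
scores = [
    [2.0, 0.1, -1.0, 0.3],
    [0.2, 0.1, 1.5, 0.0],
    [0.0, 3.0, 0.5, 0.1],
]
labels = [0, 2, 3]

predictions = [predict(row) for row in scores]
accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
print(predictions, round(accuracy, 2))  # prints: [0, 2, 1] 0.67
```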

 

 

•  Visualizing how the node embeddings evolve over the course of training

 

 

 

 

 

👉 We can see the 3-layer GCN model adjusting the embeddings as training progresses, so that nodes of the same class end up close together.

 

 
