1๏ธโฃ NetworkX Tutorial
๐น NetworkX
• Software for complex networks
• ์ ํ, ๋น์ ํ ๋ฐ์ดํฐ๋ฅผ ๋คํธ์ํฌ์ ์ ๋ ฅํ ์ ์๋ค.
• ๋ค์ํ ์ ํ์ ๋คํธ์ํฌ๋ฅผ ์์ฑํ ์ ์๋ค.
• analyze network structure, build network models, design new network algorithms, draw networks, and much more
๐ library import
import networkx as nx
๐น NetworkX graph types
• 4๊ฐ์ง ๊ทธ๋ํ ์ ํ
Class (๋ถ๋ฌ์ฌ๋) | Type | Self-loop allowed | Parallel edges (Multiple edges) allowed |
Graph | undirected ์๋ฐฉํฅ | Yes | No |
DiGraph | directed ๋จ๋ฐฉํฅ | Yes | No |
MultiGraph | undirected | Yes | Yes |
MultiDiGraph | directed | Yes | Yes |
๐น Graph
• networkX ๋ ์๋ฐฉํฅ, ๋จ๋ฐฉํฅ ๊ทธ๋ํ์ ๊ฐ์ ๋ค์ํ ์ ํ์ ๊ทธ๋ํ๋ฅผ ์ ์ฅํ๋๋ฐ ๋ค์ํ ํด๋์ค๋ฅผ ์ ๊ณตํ๋ค. Multigraph ๋ฅผ ์์ฑํ๋ ํด๋์ค๋ ์ง์ํ๋ค.
# ์๋ฐฉํฅ undirected ๊ทธ๋ํ ์์ฑ
G = nx.Graph()
print(G.is_directed())
# ๋จ๋ฐฉํฅ directed ๊ทธ๋ํ ์์ฑ
H = nx.DiGraph()
print(H.is_directed())
# ๊ทธ๋ํ์ ์์ฑ ์ถ๊ฐํ๊ธฐ
G.graph['Name'] = 'bar'
print(G.graph)
๐น Node
• NetworkX ์์๋ ๊ทธ๋ํ์ ๋ ธ๋๋ฅผ ์ฝ๊ฒ ์ถ๊ฐ์ํฌ ์ ์๋ค : add_node, add_nodes_from
• add_node : ์์ฑ์ ๊ฐ์ง ํ ๊ฐ์ ๋ ธ๋ ์ถ๊ฐ
# ํ๋์ (์์ฑ์ ๊ฐ์ง) ๋
ธ๋ ์ถ๊ฐ
G.add_node(0, features = 5, label = 0) # ์ธ๋ฑ์ค์์น, ์์ฑ (feature, label)
# node 0 ์ ์์ฑ ์ถ๋ ฅํ๊ธฐ
node_0_attr = G.nodes[0]
print('node 0 has the attributes {}'.format(node_0_attr))
• add_nodes_from : ์์ฑ์ ๊ฐ์ง ์ฌ๋ฌ๊ฐ์ ๋ ธ๋๋ฅผ ์ถ๊ฐ
# โญ ์์ฑ์ ๊ฐ์ง ์ฌ๋ฌ๊ฐ์ ๋
ธ๋๋ฅผ ์ถ๊ฐํ๊ธฐ
G.add_nodes_from([
(1,{'feature' : 1, 'label' : 1}),
(2, {'feature' : 2, 'label' : 2})
]) # node, ์ฌ์ ํํ๋ก ์ ์ฅ
# โญ ๋ชจ๋ ๋
ธ๋+์์ฑ์ ์ถ๋ ฅํด๋ณด์ : .nodes ๋ก ์ ๊ทผ
for node in G.nodes(data=True) :
print(node) # data=True ์ต์
์ ์ฃผ๋ฉด ๋
ธ๋์ ์์ฑ์ ๋ฐํํ๋ค.
# โญ ๋
ธ๋์ ๊ฐ์๋ฅผ ์ถ๋ ฅํด๋ณด์ : .number_of_nodes()
num_nodes = G.number_of_nodes()
print('G has {} nodes'.format(num_nodes))
๐น Edge
• NetworkX ์์๋ ๊ทธ๋ํ์ ์ฃ์ง๋ฅผ ์ฝ๊ฒ ์ถ๊ฐ์ํฌ ์ ์๋ค
• add_edge : ์์ฑ์ ๊ฐ์ง ํ ๊ฐ์ ์ฃ์ง ์ถ๊ฐ
# 0.5์ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ์ฃ์ง ํ๋๋ฅผ ์ถ๊ฐ
G.add_edge(0, 1, weight = 0.5)
# (0,1) ์ฃ์ง, ์ฆ ๋
ธ๋ 0๊ณผ ๋
ธ๋1์ ์ฐ๊ฒฐ์์ผ์ฃผ๋ ์ฃ์ง์ ์์ฑ ์ถ๋ ฅํ๊ธฐ
edge_0_1_attr = G.edges[(0,1)]
print('Edge (0,1) attributes {}'.format(edge_0_1_attr))
• add_edges_from : ์์ฑ์ ๊ฐ์ง ์ฌ๋ฌ๊ฐ์ ์ฃ์ง๋ฅผ ์ถ๊ฐ
# โญ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ์ฃ์ง ์ฌ๋ฌ๊ฐ๋ฅผ ์ถ๊ฐ
G.add_edges_from([
(1,2, {'weight' : 0.3}),
(2,0,{'weight' : 0.1})
])
# โญ ๊ทธ๋ํ์ ์กด์ฌํ๋ ๋ชจ๋ ์ฃ์ง๋ฅผ ์ถ๋ ฅํด๋ณด๊ธฐ
for edge in G.edges() :
print(edge)
# โญ ์ฃ์ง์ ๊ฐ์ ์ถ๋ ฅํด๋ณด๊ธฐ
num_edges = G.number_of_edges()
print('G ๋ {}๊ฐ์ ์ฃ์ง๋ฅผ ๊ฐ์ง๋ค'.format(num_edges))
๐น ์๊ฐํ
• nx.draw( ๊ทธ๋ํ, with_labels = True )
nx.draw(G, with_labels = True) # ๋
ธ๋์ ๋ ์ด๋ธ๋ ํ์ํ๊ธฐ
๐น Degree, neighbor
• Degree : ์ฐจ์, ์ฐ๊ฒฐ๋ ์ฃ์ง์ ๊ฐ์ → G.degree[n]
• neighbor : ์ด์๋ ธ๋ → G.neighbors()
node_id = 1
# node 1 ์ ์ฐจ์(์ฐ๊ฒฐ๋ ์ฃ์ง์ ๊ฐ์) ํ์ธ : G.degree[n]
print('node {} ๋ {} ์ฐจ์ (degree)๋ฅผ ๊ฐ์ง๊ณ ์๋ค'.format(node_id, G.degree[node_id]))
# node 1 ์ ์ด์ ๋
ธ๋ ์ถ๋ ฅํ๊ธฐ : G.neighbors()
for neighbor in G.neighbors(node_id) :
print('node {} ๋ node {} ๋ฅผ ์ด์๋
ธ๋๋ก ๊ฐ์ง๊ณ ์๋ค.'.format(node_id, neighbor))
๐น Other functions
• networkX ๋ ์ ์ฉํ ๋ค์ํ ๋ฉ์๋๋ค์ ์ ๊ณตํ๋ค.
• ์) PageRank
num_nodes = 4
# ์๋ก์ด path ๋ฅผ ๋ง๋ค๊ณ ๋จ๋ฐฉํฅ ๊ทธ๋ํ๋ก ๋ฐ๊พธ๊ธฐ
G = nx.DiGraph(nx.path_graph(num_nodes))
nx.draw(G, with_labels = True)
# PageRank ๊ฐ์ ธ์ค๊ธฐ
pr = nx.pagerank(G, alpha = 0.8)
pr
โป PageRank : ๊ตฌ๊ธ ๊ฒ์ ์์ง์ ๊ธฐ๋ฐ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก, ํ์ดํผ๋งํฌ๋ฅผ ์ด์ฉํด ์นํ์ด์ง ์ค์๋๋ฅผ ์ธก์ ํ๋ ์๊ณ ๋ฆฌ์ฆ์ด๋ค. ๋ชจ๋ ๊ทธ๋ํ์์ ๋ ธ๋์ ์ค์๋ ์ธก์ ์ ์ฌ์ฉํ๋ ์ฉ๋๋ก ์ฐ์ธ๋ค.
2๏ธโฃ Pytorch Geometric Tutorial
๐น Pytorch Geometric
โช simple graph-structured ์์ : Zachary's karate club network
โช ๊ฐ๋ผํ ์ด๋ ๋์๋ฆฌ์ ๊ฐ์ ๋ 34๋ช ๋ฉค๋ฒ๋ค ์ฌ์ด์ social network ๊ด๊ณ๋ง
โช ๋ฌธ์ ์ ์ : ๋ฉค๋ฒ ๊ฐ์ ์ํธ์์ฉ (A์ B๊ฐ ์ฌ์ด๊ฐ ์์ข์์ ธ์ 2๊ฐ๋ก club ์ด ๋๋์ด์ง) ์ ์ผ์ผํค๋ ์ปค๋ฎค๋ํฐ๋ฅผ ๋ฐ๊ฒฌํ๋ ๊ฒ์ ์ด์
๐น ๋ฐ์ดํฐ์ ๊ฐ์ ธ์ค๊ธฐ
# ๋ฐ์ดํฐ์
๊ฐ์ ธ์ค๊ธฐ
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
print(f'Dataset : {dataset} :')
print('------------------')
print(f'๊ทธ๋ํ ๊ฐ์ : {len(dataset)}')
print(f'ํผ์ฒ์ ๊ฐ์ : {dataset.num_features}')
print(f'ํด๋์ค์ ๊ฐ์ : {dataset.num_classes}')
โช ๊ฐ๋ผํ ๋์๋ฆฌ ๋ฐ์ดํฐ์ ์ ์์ฑ์ ์ดํด๋ณด์๋ฉด, ์ด ๋ฐ์ดํฐ์ ์ ๊ฒฝ์ฐ ํ๋์ ๊ทธ๋ํ๋ก ๊ตฌ์์ค๋์ด ์์ผ๋ฉฐ, ๊ฐ ๋ ธ๋๋ ๋์๋ฆฌ์ ๊ฐ ๊ตฌ์ฑ์ (34๋ช ) ์ ๋ฌ์ฌํ๊ธฐ ์ํด์ฌ 34 ์ฐจ์์ feature vector ๋ก ํํ๋๋ค. ๋ํ ๊ทธ๋ํ๋ ๊ฐ ๋ ธ๋๊ฐ ์ํ ์ปค๋ฎค๋ํฐ ๋จ์ฒด๋ฅผ ํํํ๊ธฐ ์ํด 4๊ฐ์ ํด๋์ค๋ฅผ ๊ฐ์ง๊ณ ์๋ค.
๐น ๊ทธ๋ํ ํน์ง๋ค ์ดํด๋ณด๊ธฐ
# ๊ทธ๋ํ ์ดํด๋ณด๊ธฐ
data = dataset[0] # ์ฒซ๋ฒ์งธ ๊ทธ๋ํ ๊ฐ์ฒด ๊ฐ์ ธ์ค๊ธฐ
print(data)
print('--------------------------------------------------')
print(f'๋
ธ๋์ ๊ฐ์ : {data.num_nodes}')
print(f'์ฃ์ง์ ๊ฐ์ : {data.num_edges}')
print(f'ํ๊ท ๋
ธ๋ ์ฐจ์ : {(2*data.num_edges)/data.num_nodes : .2f}')
print(f'ํ๋ จ ๋
ธ๋ ์ : {data.train_mask.sum()}')
print(f'ํ๋ จ ๋
ธ๋ ๋ ์ด๋ธ์ ๋น์จ : {int(data.train_mask.sum())/data.num_nodes : .2f}')
print(f'๊ณ ๋ฆฝ ๋
ธ๋๋ฅผ ํฌํจํ๋์ง ์ฌ๋ถ : {data.has_isolated_nodes()}')
print(f'self loop ๋ฅผ ํฌํจํ๋์ง ์ฌ๋ถ : {data.has_self_loops()}')
print(f'์๋ฐฉํฅ ๊ทธ๋ํ์ธ์ง ์ฌ๋ถ : {data.is_undirected()}')
• Edge list ๋ฐฉ์์ผ๋ก ์ฃ์ง ์ ๋ณด๋ฅผ ํํ
data.edge_index.T
• ๋ฐ์ดํฐ ๊ตฌ์กฐ
# data
print(data)
โช ์์ฑ๊ณผ shape ์ ๋ํ ์ ๋ณด๋ฅผ ํ์ธํด๋ณผ ์ ์๋ค.
1. edge_index : graph connectivity ๊ทธ๋ํ ์ฐ๊ฒฐ์ฑ์ ๋ํ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ค.
2. x : node features ๋ ธ๋ ํผ์ฒ์ ๋ํ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ค. (34,34) ํ๊ธฐ๋ฅผ ํตํด, 34๊ฐ์ ๋ ธ๋ ๊ฐ๊ฐ์ด 34 ์ฐจ์์ feature vector ๋ก ์ด๋ฃจ์ด์ ธ ์์์ ํ์ ํ ์ ์๋ค.
3. y : node labels ์ฌ๊ธฐ์ ๊ฐ์ด 34 ์ด๋ฏ๋ก ๊ฐ ๋ ธ๋๊ฐ ํน์ ํ ํ๋์ ํด๋์ค์ ํ ๋น๋์์์ ์ ์ ์๋ค.
4. train_mask : ๊ฐ ๋ ธ๋๊ฐ ์ด๋ค ์ปค๋ฎค๋ํฐ์ ์ํด์๋์ง ์ค๋ช ํ๋ ์์ฑ์ผ๋ก, ์ฌ๊ธฐ์ ์ฐ๋ฆฌ๊ฐ ์๊ณ ์๋ ground-truth label ์ 4๊ฐ ๋ ธ๋ ๋ฟ์ด๋ฉฐ (์ด์ ์ถ๋ ฅ ์ฝ๋ ์ฐธ๊ณ ) ๋๋จธ์ง๋ inference ๋ฅผ ํตํด ๋ง์ถฐ์ผ ํ๋ค.
๐น Edge Index , Viz
• edge_index ๋ฅผ ์ถ๋ ฅํด๋ณด์
โช [ source node , destination node ] ํ์์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ : COO format (coordinate format) ์ด๋ผ๊ณ ๋ ๋ถ๋ฅธ๋ค.
โช ์๊ฐํ๋ networkx ๋ผ์ด๋ธ๋ฌ๋ฆฌ ํฌ๋งท์ผ๋ก ์งํํ๋ค.
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected = True)
visualize(G, color = data.y)
๐น ๊ทธ๋ํ ์ ๊ฒฝ๋ง ๊ตฌ์ถ Graph convolutional network
• GCN layer : GCNConv
• GNN ์ output ํํ๋ ?
โช ๊ฐ node ๋ feature vector Xi ํํ๋ก ํํ๋๋ฉฐ, ์ด๋ฅผ input ํํ๋ก ๋ฃ์ผ๋ฉด output ์ embedding ์ผ๋ก ๋์ถ๋๋๋ฐ, embedding ์ downstream task ์ ๋ง๊ฒ ํํ๋ ๋ ธ๋ ๋ฒกํฐ๋ฅผ ์๋ฏธํ๋ค. ํ์ต๋ ์๋ฒ ๋ฉ์ผ๋ก ๊ฐ ๋ ธ๋ feature ๊ฐ ๋งตํ๋๋ฉด, ๊ทธ๋ฌํ ์๋ฒ ๋ฉ์ ํตํ์ฌ node ๋จ์, edge ๋จ์, graph ๋จ์์ regression ํน์ classification ๋ฌธ์ ๋ค์ ์ํํ ์ ์๊ฒ ๋๋ค.
๐ ์ฌ๊ธฐ์๋ ๊ฐ ๋ ธ๋๋ฅผ ๊ฐ๋ณ ์ปค๋ฎค๋ํฐ ํด๋์ค์ ๋ถ๋ฅํ๋ Node classification ๋ฌธ์ ๋ฅผ ๋ค๋ค๋ณผ ๊ฒ์ด๋ค. (learn embedding)
class GCN(torch.nn.Module) :
def __init__(self) :
super().__init__()
torch.manual_seed(1234)
self.conv1 = GCNConv(dataset.num_features, 4)
self.conv2 = GCNConv(4,4)
self.conv3 = GCNConv(4,2)
# 3๊ฐ์ graph convolutional layer ๋ฅผ ์์
# ๋ถ๋ฅ๊ธฐ
self.classifier = Linear(2, dataset.num_classes)
def forward(self, x, edge_index) :
h = self.conv1(x, edge_index)
h = h.tanh()
h = self.conv2(h, edge_index)
h = h.tanh()
h = self.conv3(h, edge_index)
h = h.tanh()
out = self.classifier(h)
return out, h
model = GCN()
print(model)
• GCN ์์ ๊ฐ layer ๋ 1-hop neighborbood ์ฆ, ๋ฐ๋ก ์ฐ๊ฒฐ๋ ์ด์ ๋ ธ๋๋ก๋ถํฐ ์ ๋ณด๋ฅผ ์์งํ์ฌ ์๋ฒ ๋ฉ์ ์งํํ๊ณ 3๊ฐ์ layer ๋ฅผ ํฉ์น ๋ ๊ฐ ๋ ธ๋์ 3-hop neighborgood ๋ก๋ถํฐ ์ ๋ณด๋ฅผ ์์งํ๋ค.
• GCNConv layer ๋ ๋ ธ๋ feature ์ ์ฐจ์์ 34 ์์ 2๋ก ๊ฐ์์ํจ๋ค. ๋ํ ๊ฐ layer ๋ ํ์ดํผ๋ณผ๋ฆญ ํ์ ํธ ํ์ฑํ ํจ์ tanh ๋ฅผ ํตํด ๋น์ ํ์ฑ์ ๋์ธ๋ค.
• ๋ง์ง๋ง์ classifier ์ ํด๋นํ๋ linear ์ธต์ ๋ ธ๋๋ฅผ 4๊ฐ์ ํด๋์ค ์ค์ ํ๋๋ก ํ ๋นํ๊ธฐ ์ํ ๋ถ๋ฅ๋ฅผ ์ํํ๊ธฐ ์ํด ์ ํ ๋ณํ์ ์งํํ๋๋ก ํ๋ค.
๐น ๊ฐ์ค์น ํ์ต ์ด์ ๋ชจ์ต ์๊ฐํ
%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt
# Visualization function for NX graph or PyTorch tensor
def visualize(h, color, epoch=None, loss=None, accuracy=None):
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
if torch.is_tensor(h):
h = h.detach().cpu().numpy()
plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
if epoch is not None and loss is not None and accuracy['train'] is not None and accuracy['val'] is not None:
plt.xlabel((f'Epoch: {epoch}, Loss: {loss.item():.4f} \n'
f'Training Accuracy: {accuracy["train"]*100:.2f}% \n'
f' Validation Accuracy: {accuracy["val"]*100:.2f}%'),
fontsize=16)
else:
# ๐
nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
node_color=color, cmap="Set2")
plt.show()
โช nx.spring_layout : Fruchterman-Reingold force-directed ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ์ฌ ๋ ธ๋๋ฅผ ๋ฐฐ์น
๐ ๋คํธ์ํฌ์์๋ ๊ฐ๋ณ ๋ ธ๋์ ์ขํ๊ฐ์ด ์๊ธฐ ๋๋ฌธ์ ์ด๋์ ๊ทธ๋ ค์ผ ํ๋์ง ์ ๋งคํ๋ค. ๋ ธ๋์ ์๊ฐ ๋ง๋ค๋ฉด ๋์ฑ์ด ํด๋น ๋คํธ์ํฌ์ ํน์ง์ ์ ๋ณด์ฌ์ง๋๋ก ์๊ฐํ ํด์ผ ํ๊ธฐ ๋๋ฌธ์, ๋ณต์กํ ๋คํธ์ํฌ์์ ๊ฐ ๋ ธ๋์ ์ขํ๋ฅผ ์ ๋ฐฐ์นํ์ฌ ์๊ฐํ ํ๋ ๋ฐฉ๋ฒ์ด ๋ค์ํ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก ๋์ค๊ณ ์๋ค.
model = GCN()
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')
visualize(h, color=data.y)
๐ ๊ฐ์ค์น training ์ด์ ์ (์ด๊ธฐํ๋ ์ํ์์๋) ๊ฐ์ ์(์ปค๋ฎค๋ํฐ)์ ๋ ธ๋๋ค์ ์ด๋ฏธ ์๋ฒ ๋ฉ ๊ณต๊ฐ์์ ๋ฐ์ ํ๊ฒ ์์นํ๊ณ ์์์ ํ์ ํ ์ ์๋๋ฐ, ์ด๋ GNN ์ด ๋น์ทํ ๋ ธ๋ ์๋ฒ ๋ฉ์ input ๊ทธ๋ํ์์ ๊ฐ๊น์ด ์์น์ ์กด์ฌํ๋ค๋ inductive bias ๋ฅผ ์ด๋ํจ์ ์ ์ ์๋ค.
โช Inductive Bias ๋ : ์ฃผ์ด์ง์ง ์์ ์ ๋ ฅ์ ์ถ๋ ฅ์ ์์ธกํ๋ ๊ฒ์ผ๋ก ์ผ๋ฐํ ์ฑ๋ฅ์ ๋์ด๊ธฐ ์ํด ๋ง์ฝ์ ์ํฉ์ ๋ํ ์ถ๊ฐ์ ์ธ ๊ฐ์ ์ ํด๋นํ๋ค. ๋ณด์ง ๋ชปํ ๋ฐ์ดํฐ์ ๋ํด์๋ ๊ท๋ฉ์ ์ถ๋ก ์ด ๊ฐ๋ฅํด์ง๋๋ก ํ๋ ์๊ณ ๋ฆฌ์ฆ์ด ๊ฐ์ง๊ณ ์๋ ๊ฐ์ ์ ์งํฉ์ด๋ผ ๋ณผ ์ ์๋ค. CNN ์ ๊ฒฝ์ฐ์๋ Locality Translation invariance (์ด๋ค ์ฌ๋ฌผ์ด ๋ค์ด์๋ ์ด๋ฏธ์ง๋ฅผ ์ ๊ณตํด์ค ๋ ์ฌ๋ฌผ์ ์์น๊ฐ ๋ฐ๋์ด๋ ํด๋น ์ฌ๋ฌผ์ ์ธ์ ๊ฐ๋ฅํ๋ค) ์ ๋ํ Relation Inductive Bias ๋ฅผ ๊ฐ๋๋ฐ ๊ฐ entity ๊ฐ์ ๊ด๊ณ๊ฐ ์๋ก ๊ฐ๊น์ด ๋น์ทํ ์์๋ค์ด ๋ชจ์ฌ์๋ค๋ ๋ป์ด๋ค.
โ ์ฐธ๊ณ : https://re-code-cord.tistory.com/entry/Inductive-Bias%EB%9E%80-%EB%AC%B4%EC%97%87%EC%9D%BC%EA%B9%8C
๐น ํ๋ จ
• loss ๊ฐ 1.43 ์์ 0.02 ๊น์ง ๊ฐ์ํ๋ค.
model = GCN()
criterion = torch.nn.CrossEntropyLoss() # loss
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01) # optimizer
def train(data) :
optimizer.zero_grad() # gradients ์ด๊ธฐํ
out, h = model(data.x, data.edge_index) # forward
loss = criterion(out[data.train_mask], data.y[data.train_mask]) # loss ๊ณ์ฐ (ํ๋ จ ๋
ธ๋)
loss.backward() # ์ญ์ ํ
optimizer.step() # ํ๋ผ๋ฏธํฐ ์
๋ฐ์ดํธ
return loss, h
for epoch in range(401) :
loss, h = train(data)
if epoch % 10 == 0 :
print(loss)
โป predicted label : out
โป ground-truth label : data.y
• ๋ ธ๋ ์๋ฒ ๋ฉ์ด ํ๋ จ ์๊ฐ์ ๋ฐ๋ผ evolve ํ๋ ๊ณผ์ ์๊ฐํ
๐ 3-layer GCN ๋ชจ๋ธ์ด ๋น์ทํ ํด๋์ค์ ๋ ธ๋๋ค์ ์ ์ฌํ ์์น์ ๋ฐฐ์นํ๋ฉฐ ์๋ฒ ๋ฉ์ ๋ฐ๊พธ์ด ๊ฐ๋ฉฐ ํ์ต์ ์งํํ๋ ๊ฒ์ ํ์ธํด๋ณผ ์ ์๋ค.
'1๏ธโฃ AIโขDS > ๐ GNN' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[CS224W] Graph Neural Network (0) | 2022.11.24 |
---|---|
[CS224W] Message Passing and Node classification (0) | 2022.11.17 |
[CS224W] PageRank (0) | 2022.11.02 |
[CS224W] 1๊ฐ Machine Learning With Graphs (0) | 2022.10.11 |
Pytorch Geometric Basic code (1) | 2022.09.30 |
๋๊ธ