1๏ธโฃ 9๊ฐ ๋ณต์ต
๐น Main Topic : GNN ์ ํํ๋ฅ๋ ฅ๊ณผ ๋ฒ์
• Expressive power : ์ด๋ป๊ฒ ์๋ก๋ค๋ฅธ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ๊ตฌ๋ณํ๋๊ฐ (node ์ graph structure ๋ฅผ ์ด๋ป๊ฒ ๊ตฌ๋ถํ๋๊ฐ)
• Maximally expressive GNN model : ํํ๋ ฅ์ ์ด๋์ ๊ทน๋ํ ์ํฌ ์ ์์๊น
๐น GNN model
โ GCN : mean pool
โก GraphSAGE : max pool
• Local Neighborhood Structure : ๋ชจ๋ ๋ ธ๋๊ฐ ๊ฐ์ feature ๋ฅผ ๊ฐ์ง๊ณ ์๋ ๊ทธ๋ํ์์ ์๋ก๋ค๋ฅธ ๋ ธ๋๋ฅผ ๊ตฌ๋ณํ๋ ๋ฐฉ๋ฒ (same color - same feature ๋ก ๊ฐ์ฃผ)
โช ๊ธฐ์ค1 : different node degree
โช ๊ธฐ์ค2 : different neighbors' node degrees
โ ๋ ธ๋1 ๊ณผ ๋ ธ๋2 ๋ ๊ตฌ์กฐ๊ฐ ์์ ํ ๋์ผ symmetric → same embedding → GNN ์ ๋ ธ๋1๊ณผ 2๋ฅผ ๊ตฌ๋ถํด๋ด์ง ๋ชปํจ
๐น Computational Graph
• computational graph ๋ ๋ ธ๋์ Rooted Subtree Structure ์ ๋ฐ๋ผ ๊ฒฐ์ ๋๋ค.
• ํํ๋ ฅ์ด ์ข์ GNN ์ subtree ๋ฅผ ์๋ฒ ๋ฉ์ผ๋ก ์ผ๋์ผ ๋์ Injective ์ํจ๋ค.
โช Aggregation ์ด ์ด์๋ ธ๋ ์ ๋ณด๋ฅผ ์ถฉ๋ถํ ๋ฐ์ํ๋ค๋ฉด, ๋ ธ๋ ์๋ฒ ๋ฉ์ ์๋ก๋ค๋ฅธ subtree ๋ฅผ ์ ๊ตฌ๋ถํด๋ผ ์ ์๋ค.
โช GNN ์ ํํ๋ฅ๋ ฅ์ ์ด๋ค ์ด์๋ ธ๋ aggregation ํจ์๋ฅผ ์ฐ๋๋์ ๋ฌ๋ ค์์ผ๋ฉฐ, ์ผ๋์ผํจ์ ์ฑ์ง์ ๊ฐ์ง ๊ฒ์ด ํํ๋ ฅ์ด ๊ฐ์ฅ ์ข๋ค๊ณ ๋ณผ ์ ์๋ค.
๐น Most powerful GNN
• GCN → Not injective
• GraphSAGE → Not injective
• Graph Isomorphism Network GIN ๐ Most Expressive
โช GIN ์ MLP - sum - MLP ๊ตฌ์กฐ๋ก ์ผ๋์ผ ๋์์ ๊ฐ๋ฅ์ผ ํ๋ค.
2๏ธโฃ GIN ๋ฆฌ๋ทฐ
https://towardsdatascience.com/how-to-design-the-most-powerful-graph-neural-network-3d18b07a6e66
๐น Topic
• graph embedding ์ ๊ฐ๋ฅํ๊ฒ ํ๊ธฐ ์ํด, GNN ์ผ๋ก๋ถํฐ ์ป์ node embedding ์ ์ด๋ป๊ฒ ๊ฒฐํฉํ ๊น → Global Pooling and GIN
๐น PROTEINS Dataset
• PROTEINS ๋ฐ์ดํฐ์ ์ ๋จ๋ฐฑ์ง์ ํํํ๋ 1113 ๊ฐ์ ๊ทธ๋ํ๋ฅผ ํฌํจํ๊ณ ์๋ค. ์๋ฏธ๋ ธ์ฐ์ด ๋ ธ๋์ ํด๋นํ๋ฉฐ , ๋ ๋ ธ๋๊ฐ 0.6 ๋๋ ธ๋ฏธํฐ ๋ฏธ๋ง์ ๊ฐ๊น์ด ๊ฑฐ๋ฆฌ์ ์์ผ๋ฉด ์ฃ์ง๋ก ์ฐ๊ฒฐ๋์ด ์๋ค.
• ์ด ๋ฐ์ดํฐ์ ์ ํตํด ์ฐ๋ฆฌ๋ ๊ฐ ๋จ๋ฐฑ์ง์ ํจ์์ธ์ง ์๋์ง๋ก ๊ตฌ๋ณํด๋ด๋ ๊ณผ์ ๋ฅผ ์ํํ๋ค.
• GraphSAGE ๋ฑ์ GNN ์ ์ผ๋ฐ์ ์ธ ๋ชจ๋ธ๋ก ์๋ฒ ๋ฉํ ๊ฒฐ๊ณผ๊ฐ ์ข์ง ์๋ค๋ฉด, mini-batch ๋ฅผ ํตํด์ ์ฑ๋ฅ์ ๋์ผ ์ ์๋ค. PROTEINS ๋ฐ์ดํฐ์ ์ ํฌ์ง ์๊ธฐ ๋๋ฌธ์ ๋ฏธ๋๋ฐฐ์น๋ฅผ ํตํด ํ๋ จ ์๋๋ฅผ ๋์ผ ์ ์๋ค.
train_loader = DataLoader(train_dataset, batch_size= 64, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
๐น Graph Isomorphism Network
• GIN ์ GNN ์ ํํ์ฑ์ ๊ทน๋ํ ํ๊ธฐ ์ํด ๊ณ ์๋ ๋ชจ๋ธ์ด๋ค.
โ Weisfeiler-Lehman test
GNN ์ ํํ๋ฅ๋ ฅ์ ์ ์ํ๊ธฐ ์ํด ์ฌ์ฉ๋๋ ๋ฐฉ๋ฒ์๋ WL graph isomorphism test ๊ฐ ์๋ค.
Isomorphic graph ๋ ๋๊ฐ์ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง ๊ทธ๋ํ๋ฅผ ์๋ฏธํ๋ค.
์์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด, ๋๊ฐ์ ์ฐ๊ฒฐ๊ตฌ์กฐ๋ฅผ ์ง๋ ์ง๋ง ๋ ธ๋์ ์์น์์๋ง ๋ฐ๋ ๋ ๊ทธ๋ํ์ ๊ด๊ณ๋ฅผ Isomorphic ํ๋ค๊ณ ๋งํ๋ค. WL test ๋ ๋ง์ฝ ๋ ๊ทธ๋ํ๊ฐ non-isomorphic ํ๋ค๋ฉด non-isomorphic ํ๋ค๊ณ ๋งํด์ค ์ ์๋ค. ๊ทธ๋ฌ๋ isomorphic ํ๋ค๋ ๊ฒ์ ๋งํด์ฃผ์ง ์๋๋ค. ์ค์ง non-isomorphic ์ ๊ฐ๋ฅํ๋ค ์ฌ๋ถ๋ง ์๋ ค์ค ์ ์๋ค.
WL test ๋ GNN ์ด ํ์ตํ๋ ๋ฐฉ๋ฒ๊ณผ ๋งค์ฐ ์ ์ฌํ๋ค. WL test ์์๋
1. ๋ชจ๋ ๋ ธ๋๊ฐ ๋์ผํ label ์ ๊ฐ์ง๊ณ ์์ํ๋ฉฐ
2. ์ด์ ๋ ธ๋๋ก๋ถํฐ ์ป์ label ์ ๋ณด๋ฅผ ๋ณํฉํ๊ณ ํด์ํจ์๋ฅผ ์ ์ฉํด ์๋ก์ด label ์ ํ์ฑํ๋ ๋จ๊ณ๋ฅผ
3. label ์ด ๋ ์ด์ ๋ฐ๋์ง ์์ ๋๊น์ง ๋ฐ๋ณตํ๋ค.
์ ๊ณผ์ ์ ํตํด ํ์ต์ ์งํํ๋๋ฐ, ์ด๋ GNN ์์ feature vector ๋ฅผ aggregate ํ๋ ๋ฐฉ๋ฒ๊ณผ ๋งค์ฐ ์ ์ฌํ ๋ฟ๋ง ์๋๋ผ, ๊ทธ๋ํ ์ ์ฌ์ฑ์ ํ๋จํ ์ ์๋ ๋ฅ๋ ฅ์ GCN ์ด๋ GraphSAGE ๊ฐ์ ๋ค๋ฅธ ๋ชจ๋ธ๋ค ๋ณด๋ค ๋ ๊ฐ๋ ฅํ ๊ตฌ์กฐ๋ฅผ ๊ฐ๋๋ค๊ณ ๋ณผ ์ ์๋ค.
โก One aggregator to rule them all
WL test ์ ์ฅ์ ์ ๊ณ ์ํ์ฌ, ์๋ก์ด aggregator ํจ์๋ฅผ ๋ง๋ค๊ฒ ๋๋๋ฐ, ์ด ํจ์๋ ๋ง์ฝ non-isomorphic ํ ๊ทธ๋ํ๋ค ์ด๋ผ๋ฉด ์๋ก ๋ค๋ฅธ ๋ ธ๋ ์๋ฒ ๋ฉ์ ์์ฐํ ์ ์๋๋ก ํ๋ค. (์๋ก ๋ค๋ฅธ ๊ทธ๋ํ์ ๋ํด ์๋ก ๋ค๋ฅธ ๋ ธ๋ ์๋ฒ ๋ฉ์ ์ถ๋ ฅ)
๋ ผ๋ฌธ ์ฐ๊ตฌ์์๋ 2๊ฐ์ ์ผ๋์ผ ๋์ ํจ์๋ฅผ ์ ์ํ๋ค.
1. GAT ๋คํธ์ํฌ๋ฅผ ํตํด, ์ฃผ์ด์ง task ์ ๋ํด best weighting factor (์ต์ ์ ๊ฐ์ค์น ์ธ์) ๋ฅผ ํ์ตํ ์ ์๋๋ก ํ๋ค.
2. GIN ๋คํธ์ํฌ๋ฅผ ํตํด, ๋ ์ผ๋์ผ ํจ์์ ๊ทผ์ ์ฑ์ ํ์ตํ๋ค. (by Universal Approximation Thm)
โช Universal Approximation Thm : ์ถฉ๋ถํ ํฐ Hidden Dimension ๊ณผ Non-Linear ํจ์๊ฐ ์๋ค๋ฉด 1-Hidden-Layer NeuralNet ์ผ๋ก ์ด๋ค ์ฐ์ํจ์๋ ๊ทผ์ฌํ ์ ์๋ค.
GIN ์์ ํน์ ํ ๋ ธ๋ i ์ hidden vector (hi) ๋ฅผ ๊ณ์ฐํ๋ ์์์ ๋ค์๊ณผ ๊ฐ๋ค.
ํด๋น ์์์์ ษ ๋ ์ด์๋ ธ๋์ ๋น๊ตํ์ ๋, target ๋ ธ๋์ ์ค์์ฑ์ ๊ฒฐ์ ํ๋ ์ธ์๋ก, ๋ง์ฝ ์ด์๋ ธ๋์ ๋์ผํ ์ค์๋๋ฅผ ๊ฐ์ง๋ค ํ๋ฉด ๊ฐ์ด 0์ด ๋๋ค. ษ ๋ ๊ณ ์ ๋ ์์๊ฐ์ด๊ฑฐ๋ ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ์ผ ์ ์๋ค.
โข Global Pooling
๊ทธ๋ํ ์์ค์ ํด์ ํน์ Global pooling ์, GNN ์ ํตํด ๊ณ์ฐ๋ ๋ ธ๋ ์๋ฒ ๋ฉ์ ๊ฐ์ง๊ณ graph embedding ์ ๋ง๋ค ์ ์๋ค. graph embedding ์ ์ป๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ mean, sum, max ์ฐ์ฐ์ ๋ชจ๋ ๋ ธ๋ ์๋ฒ ๋ฉ hi ์ ๋ํด ์ํํ๋ ๊ฒ์ด๋ค.
๋ ผ๋ฌธ ์ ์๋ graph ์์ค์ ํด์์ ๋ํด 2๊ฐ์ง ์ค์ํ ์ฌ์ค์ ๊ฐ์กฐํ๋ค.
1. ๋ชจ๋ ๊ตฌ์กฐ์ ์ธ ์ ๋ณด๋ฅผ ๊ณ ๋ คํ๋ ค๋ฉด, ์ด์ ๋ ์ด์ด์ ์๋ฒ ๋ฉ์ ์ ์ง (keep) ํ๋ ๊ฒ์ด ํ์์ ์ด๋ค.
2. mean ์ด๋ max ์ฐ์ฐ๋ณด๋ค sum ์ฐ์ฐ์๊ฐ ๊ฐ์ฅ ํํ๋ ฅ์ด ์ข๋ค.
์ด๋ฌํ ์ฌ์ค๋ค์ ๋ฐํ์ผ๋ก ๋ค์๊ณผ ๊ฐ์ Global pooling ํจ์๋ฅผ ์ ์ํ๊ฒ ๋๋ค.
๊ฐ ๋ ์ด์ด์ ๋ํด์ ๋ ธ๋ ์๋ฒ ๋ฉ ๊ฐ์ ๋ชจ๋ ๋ํด์ง๊ณ , ๋ํ ๊ฒฐ๊ณผ๋ฅผ ๋ชจ๋ concat ์ํจ๋ค.
๐น GIN in pytorch
• GINConv
โช nn : MLP (๋ ์ผ๋์ผ ๋์ ํจ์๋ฅผ ๊ทผ์ ์ํค๋๋ฐ ์ฌ์ฉ)
โช eps : ษ ์ ์ด๊ธฐ๊ฐ (๋ณดํต 0์ผ๋ก ์ค์ )
โช train_eps : ษ ๊ฐ ํ๋ จ ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ์ธ์ง T/F ๋ก ํ๊ธฐ (๋ณดํต F ๋ก ์ค์ )
[์ฐธ๊ณ ] GINEConv
โช Second GIN layer , ์ด์๋ ธ๋์ feature ์ ReLU function ์ ์ ์ฉ์ํจ๋ค.
3๏ธโฃ ์ฝ๋๋ฆฌ๋ทฐ
https://colab.research.google.com/drive/1ycQVANLkgqyk_iezLYGbSUp98JsY7bEQ?usp=sharing
• GCN ๊ณผ ์ฑ๋ฅ์ ๋น๊ตํด๋ณผ ์์
๐น GIN architecture code
• Library
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv #โจ
from torch_geometric.nn import global_mean_pool, global_add_pool #โจ
• GCN
class GCN(torch.nn.Module) :
def __init__(self, dim_h) :
super(GCN, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, dim_h)
self.conv2 = GCNConv(dim_h, dim_h)
self.conv3 = GCNConv(dim_h, dim_h)
self.lin = Linear(dim_h, dataset.num_classes)
def forward(self, x, edge_index, batch) :
# ๋
ธ๋ ์๋ฒ ๋ฉ
h = self.conv1(x, edge_index)
h = h.relu()
h = self.conv2(h, edge_index)
h = h.relu()
h = self.conv3(h, edge_index)
# ๊ทธ๋ํ level ํด์ (graph embedding)
hG = global_mean_pool(h, batch)
# ๋ถ๋ฅ๊ธฐ
h = F.dropout(hG, p=0.5, training = self.training)
h = self.lin(h)
return hG, F.log_softmax(h, dim=1) # ํจ์์ธ์ง ์๋์ง ๋ถ๋ฅ
• GIN (3 layer)
class GIN(torch.nn.Module) :
def __init__(self, dim_h) :
super(GIN, self).__init__()
self.conv1 = GINConv(
Sequential(Linear(dataset.num_node_features, dim_h),
BatchNorm1d(dim_h), ReLU(),
Linear(dim_h, dim_h), ReLU())
)
self.conv2 = GINConv(
Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
Linear(dim_h, dim_h), ReLU())
)
self.conv3 = GINConv(
Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
Linear(dim_h, dim_h), ReLU())
)
self.lin1 = Linear(dim_h*3, dim_h*3)
self.lin2 = Linear(dim_h*3, dataset.num_classes)
def forward(self, x, edge_index, batch) :
# ๋
ธ๋์๋ฒ ๋ฉ
h1 = self.conv1(x, edge_index)
h2 = self.conv2(h1, edge_index)
h3 = self.conv3(h2, edge_index)
# ๊ทธ๋ํ level ํด์ (graph embedding)
h1 = global_add_pool(h1, batch)
h2 = global_add_pool(h2, batch)
h3 = global_add_pool(h3, batch)
# concate ๊ทธ๋ํ ์๋ฒ ๋ฉ
h = torch.cat((h1,h2,h3), dim = 1)
# ๋ถ๋ฅ๊ธฐ
h = self.lin1(h)
h = h.relu()
h = F.dropout(h, p = 0.5, training = self.training)
h = self.lin2(h)
return h, F.log_softmax(h, dim=1)
gcn = GCN(dim_h=32)
gin = GIN(dim_h=32)
• Train and Test
def train(model, loader) :
criterion = torch.nn.CrossEntropyLoss() # ๐ก ํ๊ฐ์งํ
optimizer = torch.optim.Adam(model.parameters() , lr = 0.01, weight_decay = 0.01) # ๐ก ์ต์ ํ
epochs = 100
model.train()
for epoch in range(epochs +1) :
total_loss = 0
acc = 0
val_loss = 0
val_acc = 0
# ๐ก ๋ฐฐ์น ๋จ์๋ก ํ์ต
for data in loader :
optimizer.zero_grad()
_, out = model(data.x, data.edge_index, data.batch)
loss = criterion(out, data.y)
total_loss += loss / len(loader)
acc += accuracy(out.argmax(dim=1), data.y) / len(loader)
loss.backward()
optimizer.step()
# ๐ก validation
val_loss, val_acc = test(model, val_loader)
# 10 epoch ๋ง๋ค ์งํ๊ฐ ์ถ๋ ฅ
if (epoch%10 == 0) :
print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} '
f'| Train Acc: {acc*100:>5.2f}% '
f'| Val Loss: {val_loss:.2f} '
f'| Val Acc: {val_acc*100:.2f}%')
test_loss, test_acc = test(model, test_loader)
print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}%')
return model
def test(model, loader):
criterion = torch.nn.CrossEntropyLoss()
model.eval()
loss = 0
acc = 0
for data in loader:
_, out = model(data.x, data.edge_index, data.batch)
loss += criterion(out, data.y) / len(loader)
acc += accuracy(out.argmax(dim=1), data.y) / len(loader)
return loss, acc
def accuracy(pred_y, y):
"""Calculate accuracy."""
return ((pred_y == y).sum() / len(y)).item()
๐น test accuracy result
• ์ฑ๋ฅ์ฐจ์ด๊ฐ ๋ฐ์ํ๋ ์ด์
โช GCN ๊ณผ ๋ฌ๋ฆฌ, GIN ์ aggregator ๋ ํนํ graph ๋ฅผ ํ๋ณํ๋๋ฐ ๊ณ ์๋ ๋ชจ๋ธ๊ตฌ์กฐ์ด๋ค.
โช GIN ์ ๋งจ ๋ง์ง๋ง hidden vector ๋ง์ ๊ณ ๋ คํ๋ ๊ฒ์ด ์๋๋ผ, ๋ชจ๋ ๋ ์ด์ด์ hidden vector ๋ฅผ concat ํจ์ผ๋ก์จ ๊ทธ๋ํ ์ ์ฒด ๊ตฌ์กฐ๋ฅผ ๊ณ ๋ คํ๋ค.
โช sum ์ฐ์ฐ์ด mean ์ฐ์ฐ๋ณด๋ค ๋ ์ฑ๋ฅ์ด ๋ฐ์ด๋๋ค (injective ๋ฉด์์)
• Graph classification ๊ฒฐ๊ณผ
• GCN ๊ฒฐ๊ณผ์ GIN ๊ฒฐ๊ณผ๋ฅผ ensemble ํ์ฌ ์ฑ๋ฅ์ ์กฐ๊ธ ๋ ๋์ผ ์ ์๋ค. normalized ๋ output vector ์ ํ๊ท ์ ์ทจํ๋ ๋ฐฉ์์ผ๋ก ๋ ๋คํธ์ํฌ์ ๊ทธ๋ํ ์๋ฒ ๋ฉ ๊ฒฐ๊ณผ๋ฅผ ๊ฒฐํฉํ๋ค.
gcn.eval()
gin.eval()
acc_gcn = 0
acc_gin = 0
acc = 0
for data in test_loader :
# ๋ถ๋ฅ
_, out_gcn = gcn(data.x, data.edge_index, data.batch)
_, out_gin = gin(data.x, data.edge_index, data.batch)
out = (out_gcn+ out_gin)/2 # โจ
# accuracy score ๊ณ์ฐ
acc_gcn += accuracy(out_gcn.argmax(dim=1), data.y) / len(test_loader)
acc_gin += accuracy(out_gin.argmax(dim=1), data.y) / len(test_loader)
acc += accuracy(out.argmax(dim=1), data.y) / len(test_loader)
๐น Conclusion
GIN ์ GNN ์ ์ดํดํ๋ ๋ฐ ์์ด์ ๋งค์ฐ ์ค์ํ ๋คํธ์ํฌ์ด๋ค. ์ ํ๋๋ฅผ ํฅ์์ํฌ ๋ฟ๋ง ์๋๋ผ, ์ฌ๋ฌ๊ฐ์ง GNN ๊ตฌ์กฐ ์ค์์ ์ด๋ ํ ๋ชจ๋ธ์ด ๊ฐ์ฅ ํํ๋ ฅ์ด ์ข์์ง ์ด๋ก ์ ์ผ๋ก ์ค๋ช ๊ฐ๋ฅํ ์ ์๋๋ก ๋ง๋ ๋คํธ์ํฌ๋ผ ๋ณผ ์ ์๋ค.
ํด๋น ์ํฐํด์์๋ graph classification task ์ ๋ํด global pooling ์ ์ํํ๊ธฐ ์ํ์ฌ, WL test ์์ด๋์ด์์ ๊ณ ์๋ GIN ๋คํธ์ํฌ๋ฅผ ์ ์ฉํด๋ณด๋ ๋ด์ฉ์ ๋ด๊ณ ์๋ค.
๋น๋ก PROTEINS ๋ฐ์ดํฐ์ ๋ํด์๋ GIN ์ด ์ข์ ์ฑ๋ฅ์ ๋ณด์์ผ๋, social graph ์ ๊ฐ์ด ์ค์ ์ธ๊ณ ๋ฐ์ดํฐ์๋ ์ ์ ์ฉ๋์ง ์๋ ๊ฒฝ์ฐ๋ ์๋ค.
'1๏ธโฃ AIโขDS > ๐ GNN' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[cs224w] Frequent Subgraph Mining with GNNs (0) | 2023.01.27 |
---|---|
[CS224W] Graph Neural Network (0) | 2022.11.24 |
[CS224W] Message Passing and Node classification (0) | 2022.11.17 |
[CS224W] PageRank (0) | 2022.11.02 |
[CS224W] 1๊ฐ Machine Learning With Graphs (0) | 2022.10.11 |
๋๊ธ