๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
1๏ธโƒฃ AI•DS/๐Ÿ“˜ GNN

[cs224w] Theory of Graph Neural Networks

by isdawell 2023. 1. 6.

 

1๏ธโƒฃ  9๊ฐ• ๋ณต์Šต 


 

๐Ÿ”น Main Topic : GNN ์˜ ํ‘œํ˜„๋Šฅ๋ ฅ๊ณผ ๋ฒ”์œ„ 

 

• Expressive power : how well a GNN distinguishes different graph structures (how it tells nodes and graph structures apart)

 

• Maximally expressive GNN model : where can expressive power be maximized?

 

 

 

๐Ÿ”น GNN model  

 

โ‘  GCN : mean pool 

โ‘ก GraphSAGE : max pool 

 

 

 

Local Neighborhood Structure : how to distinguish different nodes in a graph where every node carries the same feature (same color is treated as same feature)

 

โ†ช ๊ธฐ์ค€1 : different node degree

 

โ†ช ๊ธฐ์ค€2 : different neighbors' node degrees 

 

 

โœ” ๋…ธ๋“œ1 ๊ณผ ๋…ธ๋“œ2 ๋Š” ๊ตฌ์กฐ๊ฐ€ ์™„์ „ํžˆ ๋™์ผ symmetric → same embedding  → GNN ์€ ๋…ธ๋“œ1๊ณผ 2๋ฅผ ๊ตฌ๋ถ„ํ•ด๋‚ด์ง€ ๋ชปํ•จ 

 

GNN ์€ ID ๋ฅผ ๋ณด๋Š”๊ฒŒ ์•„๋‹ˆ๋ผ ์˜ค์ง node feature ๋งŒ ๊ณ ๋ คํ•จ

 

 

 

๐Ÿ”น Computational Graph 

 

• A node's computational graph is determined by its rooted subtree structure.

 

 

 

 

• ํ‘œํ˜„๋ ฅ์ด ์ข‹์€ GNN ์€ subtree ๋ฅผ ์ž„๋ฒ ๋”ฉ์œผ๋กœ ์ผ๋Œ€์ผ ๋Œ€์‘ Injective ์‹œํ‚จ๋‹ค. 

 

 

โ†ช Aggregation ์ด ์ด์›ƒ๋…ธ๋“œ ์ •๋ณด๋ฅผ ์ถฉ๋ถ„ํžˆ ๋ฐ˜์˜ํ•œ๋‹ค๋ฉด, ๋…ธ๋“œ ์ž„๋ฒ ๋”ฉ์€ ์„œ๋กœ๋‹ค๋ฅธ subtree ๋ฅผ ์ž˜ ๊ตฌ๋ถ„ํ•ด๋‚ผ ์ˆ˜ ์žˆ๋‹ค. 

โ†ช GNN ์˜ ํ‘œํ˜„๋Šฅ๋ ฅ์€ ์–ด๋–ค ์ด์›ƒ๋…ธ๋“œ aggregation ํ•จ์ˆ˜๋ฅผ ์“ฐ๋Š๋ƒ์— ๋‹ฌ๋ ค์žˆ์œผ๋ฉฐ, ์ผ๋Œ€์ผํ•จ์ˆ˜ ์„ฑ์งˆ์„ ๊ฐ€์ง„ ๊ฒƒ์ด ํ‘œํ˜„๋ ฅ์ด ๊ฐ€์žฅ ์ข‹๋‹ค๊ณ  ๋ณผ ์ˆ˜ ์žˆ๋‹ค. 

 

 

 

 

 

๐Ÿ”น Most powerful GNN 

 

• GCN (mean pool) → not injective

 

 

• GraphSAGE (max pool) → not injective
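Why mean and max fail is easy to see when node features are identical: they collapse neighbor multisets of different sizes, while sum keeps them apart. A minimal sketch (the feature values are illustrative):

import torch

x = torch.tensor([1.0])
a = torch.stack([x])        # one neighbor with feature 1
b = torch.stack([x, x])     # two neighbors, both with feature 1

print(a.mean(0), b.mean(0))              # tensor([1.]) tensor([1.]) -> collapsed
print(a.max(0).values, b.max(0).values)  # tensor([1.]) tensor([1.]) -> collapsed
print(a.sum(0), b.sum(0))                # tensor([1.]) tensor([2.]) -> distinguished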

 

 

 

 

Graph Isomorphism Network (GIN) 👉 most expressive

 

 

 

โ†ช GIN ์€ MLP - sum - MLP ๊ตฌ์กฐ๋กœ ์ผ๋Œ€์ผ ๋Œ€์‘์„ ๊ฐ€๋Šฅ์ผ€ ํ•œ๋‹ค. 

 

 

 

 

 

 

2๏ธโƒฃ  GIN ๋ฆฌ๋ทฐ 


 

https://towardsdatascience.com/how-to-design-the-most-powerful-graph-neural-network-3d18b07a6e66

 

GIN: How to Design the Most Powerful Graph Neural Network
Graph classification with Graph Isomorphism Networks (towardsdatascience.com)

 

 

 

๐Ÿ”น  Topic 

 

•  To build a graph embedding, how should the node embeddings obtained from a GNN be combined? → Global Pooling and GIN

 

 

 

๐Ÿ”น  PROTEINS Dataset 

 

•  PROTEINS ๋ฐ์ดํ„ฐ์…‹์€ ๋‹จ๋ฐฑ์งˆ์„ ํ‘œํ˜„ํ•˜๋Š” 1113 ๊ฐœ์˜ ๊ทธ๋ž˜ํ”„๋ฅผ ํฌํ•จํ•˜๊ณ  ์žˆ๋‹ค. ์•„๋ฏธ๋…ธ์‚ฐ์ด ๋…ธ๋“œ์— ํ•ด๋‹นํ•˜๋ฉฐ , ๋‘ ๋…ธ๋“œ๊ฐ€ 0.6 ๋‚˜๋…ธ๋ฏธํ„ฐ ๋ฏธ๋งŒ์˜ ๊ฐ€๊นŒ์šด ๊ฑฐ๋ฆฌ์— ์žˆ์œผ๋ฉด ์—ฃ์ง€๋กœ ์—ฐ๊ฒฐ๋˜์–ด ์žˆ๋‹ค. 

•  ์ด ๋ฐ์ดํ„ฐ์…‹์„ ํ†ตํ•ด ์šฐ๋ฆฌ๋Š” ๊ฐ ๋‹จ๋ฐฑ์งˆ์„ ํšจ์†Œ์ธ์ง€ ์•„๋‹Œ์ง€๋กœ ๊ตฌ๋ณ„ํ•ด๋‚ด๋Š” ๊ณผ์ œ๋ฅผ ์ˆ˜ํ–‰ํ•œ๋‹ค. 

 

 

•  Even when embeddings from a standard GNN such as GraphSAGE fall short, mini-batching can help performance; and since PROTEINS is not a large dataset, mini-batches mainly serve to speed up training. A dataset split and loader setup might look like the sketch below.
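A minimal setup sketch, assuming the dataset is loaded with TUDataset and split 80/10/10 as in the referenced article (the split ratios are an assumption, not stated in this post):

import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader  # also used by the loaders below

# Load PROTEINS and shuffle it before splitting (assumed setup)
dataset = TUDataset(root='.', name='PROTEINS').shuffle()
train_dataset = dataset[:int(len(dataset) * 0.8)]
val_dataset   = dataset[int(len(dataset) * 0.8):int(len(dataset) * 0.9)]
test_dataset  = dataset[int(len(dataset) * 0.9):]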

 

 

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

 

 

 

 

 

๐Ÿ”น Graph Isomorphism Network 

 

 

•  GIN ์€ GNN ์˜ ํ‘œํ˜„์„ฑ์„ ๊ทน๋Œ€ํ™” ํ•˜๊ธฐ ์œ„ํ•ด ๊ณ ์•ˆ๋œ ๋ชจ๋ธ์ด๋‹ค. 

 

 

โ‘  Weisfeiler-Lehman test 

 

GNN ์˜ ํ‘œํ˜„๋Šฅ๋ ฅ์„ ์ •์˜ํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ๋˜๋Š” ๋ฐฉ๋ฒ•์—๋Š” WL graph isomorphism test ๊ฐ€ ์žˆ๋‹ค. 

 

Isomorphic graphs are graphs with exactly the same structure.

 

 

์œ„์˜ ๊ทธ๋ฆผ๊ณผ ๊ฐ™์ด, ๋˜‘๊ฐ™์€ ์—ฐ๊ฒฐ๊ตฌ์กฐ๋ฅผ ์ง€๋…”์ง€๋งŒ ๋…ธ๋“œ์˜ ์œ„์น˜์ˆœ์„œ๋งŒ ๋ฐ”๋€ ๋‘ ๊ทธ๋ž˜ํ”„์˜ ๊ด€๊ณ„๋ฅผ Isomorphic ํ•œ๋‹ค๊ณ  ๋งํ•œ๋‹ค. WL test ๋Š” ๋งŒ์•ฝ ๋‘ ๊ทธ๋ž˜ํ”„๊ฐ€ non-isomorphic ํ•˜๋‹ค๋ฉด non-isomorphic ํ•˜๋‹ค๊ณ  ๋งํ•ด์ค„ ์ˆ˜ ์žˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ isomorphic ํ•˜๋‹ค๋Š” ๊ฒƒ์€ ๋งํ•ด์ฃผ์ง€ ์•Š๋Š”๋‹ค. ์˜ค์ง non-isomorphic ์˜ ๊ฐ€๋Šฅํ•˜๋‹ค ์—ฌ๋ถ€๋งŒ ์•Œ๋ ค์ค„ ์ˆ˜ ์žˆ๋‹ค. 

 

WL test ๋Š” GNN ์ด ํ•™์Šตํ•˜๋Š” ๋ฐฉ๋ฒ•๊ณผ ๋งค์šฐ ์œ ์‚ฌํ•˜๋‹ค. WL test ์—์„œ๋Š” 

 

1. ๋ชจ๋“  ๋…ธ๋“œ๊ฐ€ ๋™์ผํ•œ label ์„ ๊ฐ€์ง€๊ณ  ์‹œ์ž‘ํ•˜๋ฉฐ

2. ์ด์›ƒ ๋…ธ๋“œ๋กœ๋ถ€ํ„ฐ ์–ป์€ label ์ •๋ณด๋ฅผ ๋ณ‘ํ•ฉํ•˜๊ณ  ํ•ด์‹œํ•จ์ˆ˜๋ฅผ ์ ์šฉํ•ด ์ƒˆ๋กœ์šด label ์„ ํ˜•์„ฑํ•˜๋Š” ๋‹จ๊ณ„๋ฅผ

3. this repeats until the labels no longer change.

 

์˜ ๊ณผ์ •์„ ํ†ตํ•ด ํ•™์Šต์„ ์ง„ํ–‰ํ•˜๋Š”๋ฐ, ์ด๋Š” GNN ์—์„œ feature vector ๋ฅผ aggregate ํ•˜๋Š” ๋ฐฉ๋ฒ•๊ณผ ๋งค์šฐ ์œ ์‚ฌํ•  ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ, ๊ทธ๋ž˜ํ”„ ์œ ์‚ฌ์„ฑ์„ ํŒ๋‹จํ•  ์ˆ˜ ์žˆ๋Š” ๋Šฅ๋ ฅ์€ GCN ์ด๋‚˜ GraphSAGE ๊ฐ™์€ ๋‹ค๋ฅธ ๋ชจ๋ธ๋“ค ๋ณด๋‹ค ๋” ๊ฐ•๋ ฅํ•œ ๊ตฌ์กฐ๋ฅผ ๊ฐ–๋Š”๋‹ค๊ณ  ๋ณผ ์ˆ˜ ์žˆ๋‹ค. 

 

 

 

 

 

โ‘ก One aggregator to rule them all 

 

WL test ์˜ ์žฅ์ ์— ๊ณ ์•ˆํ•˜์—ฌ, ์ƒˆ๋กœ์šด aggregator ํ•จ์ˆ˜๋ฅผ ๋งŒ๋“ค๊ฒŒ ๋˜๋Š”๋ฐ, ์ด ํ•จ์ˆ˜๋Š” ๋งŒ์•ฝ non-isomorphic ํ•œ ๊ทธ๋ž˜ํ”„๋“ค ์ด๋ผ๋ฉด ์„œ๋กœ ๋‹ค๋ฅธ ๋…ธ๋“œ ์ž„๋ฒ ๋”ฉ์„ ์ƒ์‚ฐํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•œ๋‹ค. (์„œ๋กœ ๋‹ค๋ฅธ ๊ทธ๋ž˜ํ”„์— ๋Œ€ํ•ด ์„œ๋กœ ๋‹ค๋ฅธ ๋…ธ๋“œ ์ž„๋ฒ ๋”ฉ์„ ์ถœ๋ ฅ) 

 

๋…ผ๋ฌธ ์—ฐ๊ตฌ์—์„œ๋Š” 2๊ฐœ์˜ ์ผ๋Œ€์ผ ๋Œ€์‘ ํ•จ์ˆ˜๋ฅผ ์ œ์•ˆํ•œ๋‹ค. 

 

1. A GAT network learns the best weighting factor for the given task.

2. A GIN network learns an approximation of the two injective functions (by the Universal Approximation Theorem).

 

โ†ช Universal Approximation Thm : ์ถฉ๋ถ„ํžˆ ํฐ Hidden Dimension ๊ณผ Non-Linear ํ•จ์ˆ˜๊ฐ€ ์žˆ๋‹ค๋ฉด 1-Hidden-Layer NeuralNet ์œผ๋กœ ์–ด๋–ค ์—ฐ์†ํ•จ์ˆ˜๋“  ๊ทผ์‚ฌํ•  ์ˆ˜ ์žˆ๋‹ค. 

 

 

GIN ์—์„œ ํŠน์ •ํ•œ ๋…ธ๋“œ i ์˜ hidden vector (hi) ๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ์ˆ˜์‹์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค. 

 

 

ํ•ด๋‹น ์ˆ˜์‹์—์„œ  ษ› ๋Š” ์ด์›ƒ๋…ธ๋“œ์™€ ๋น„๊ตํ–ˆ์„ ๋•Œ, target ๋…ธ๋“œ์˜ ์ค‘์š”์„ฑ์„ ๊ฒฐ์ •ํ•˜๋Š” ์ธ์ž๋กœ, ๋งŒ์•ฝ ์ด์›ƒ๋…ธ๋“œ์™€ ๋™์ผํ•œ ์ค‘์š”๋„๋ฅผ ๊ฐ€์ง„๋‹ค ํ•˜๋ฉด ๊ฐ’์ด 0์ด ๋œ๋‹ค.  ษ› ๋Š” ๊ณ ์ •๋œ ์ƒ์ˆ˜๊ฐ’์ด๊ฑฐ๋‚˜ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ์ผ ์ˆ˜ ์žˆ๋‹ค. 

 

 

 

 

โ‘ข Global Pooling 

 

๊ทธ๋ž˜ํ”„ ์ˆ˜์ค€์˜ ํ•ด์„ ํ˜น์€ Global pooling ์€, GNN ์„ ํ†ตํ•ด ๊ณ„์‚ฐ๋œ ๋…ธ๋“œ ์ž„๋ฒ ๋”ฉ์„ ๊ฐ€์ง€๊ณ  graph embedding ์„ ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค. graph embedding ์„ ์–ป๋Š” ๊ฐ€์žฅ ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•์€ mean, sum, max ์—ฐ์‚ฐ์„ ๋ชจ๋“  ๋…ธ๋“œ ์ž„๋ฒ ๋”ฉ hi ์— ๋Œ€ํ•ด ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์ด๋‹ค. 

 

 

๋…ผ๋ฌธ ์ €์ž๋Š” graph ์ˆ˜์ค€์˜ ํ•ด์„์— ๋Œ€ํ•ด 2๊ฐ€์ง€ ์ค‘์š”ํ•œ ์‚ฌ์‹ค์„ ๊ฐ•์กฐํ•œ๋‹ค. 

 

1. ๋ชจ๋“  ๊ตฌ์กฐ์ ์ธ ์ •๋ณด๋ฅผ ๊ณ ๋ คํ•˜๋ ค๋ฉด, ์ด์ „ ๋ ˆ์ด์–ด์˜ ์ž„๋ฒ ๋”ฉ์„ ์œ ์ง€ (keep) ํ•˜๋Š” ๊ฒƒ์ด ํ•„์ˆ˜์ ์ด๋‹ค. 

2. The sum operator is more expressive than mean or max.

 

์ด๋Ÿฌํ•œ ์‚ฌ์‹ค๋“ค์„ ๋ฐ”ํƒ•์œผ๋กœ ๋‹ค์Œ๊ณผ ๊ฐ™์€ Global pooling ํ•จ์ˆ˜๋ฅผ ์ œ์•ˆํ•˜๊ฒŒ ๋œ๋‹ค. 

 

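Restated from the GIN paper, with sum as the per-layer readout (as used in this article):

$$ h_G = \operatorname{CONCAT}\Big( \sum_{v \in G} h_v^{(k)} \;\Big|\; k = 0, 1, \ldots, K \Big) $$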

 

 

๊ฐ ๋ ˆ์ด์–ด์— ๋Œ€ํ•ด์„œ ๋…ธ๋“œ ์ž„๋ฒ ๋”ฉ ๊ฐ’์€ ๋ชจ๋‘ ๋”ํ•ด์ง€๊ณ , ๋”ํ•œ ๊ฒฐ๊ณผ๋ฅผ ๋ชจ๋‘ concat ์‹œํ‚จ๋‹ค. 

 

 

 

 

๐Ÿ”น GIN in pytorch 

 

 

•  GINConv 

 

โ†ช nn : MLP (๋‘ ์ผ๋Œ€์ผ ๋Œ€์‘ ํ•จ์ˆ˜๋ฅผ ๊ทผ์ ‘์‹œํ‚ค๋Š”๋ฐ ์‚ฌ์šฉ)

โ†ช eps :  ษ› ์˜ ์ดˆ๊ธฐ๊ฐ’ (๋ณดํ†ต 0์œผ๋กœ ์„ค์ •)

โ†ช train_eps :  ษ› ๊ฐ€ ํ›ˆ๋ จ ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ์ธ์ง€ T/F ๋กœ ํ‘œ๊ธฐ (๋ณดํ†ต F ๋กœ ์„ค์ •) 

 

GIN layer ์—์„œ ์‚ฌ์šฉํ•˜๋Š” MLP ๊ตฌ์กฐ

 

 

GIN ์ „์ฒด ๊ตฌ์กฐ

 

 

 

 

[์ฐธ๊ณ ]  GINEConv 

โ†ช Second GIN layer , ์ด์›ƒ๋…ธ๋“œ์˜ feature ์— ReLU function ์„ ์ ์šฉ์‹œํ‚จ๋‹ค. 

 

 

 

 

 

 

3๏ธโƒฃ  ์ฝ”๋“œ๋ฆฌ๋ทฐ 


https://colab.research.google.com/drive/1ycQVANLkgqyk_iezLYGbSUp98JsY7bEQ?usp=sharing 

 

9๊ฐ• ๋ณต์Šต๊ณผ์ œ.ipynb

Colaboratory notebook

colab.research.google.com

 

 

 

• GCN ๊ณผ ์„ฑ๋Šฅ์„ ๋น„๊ตํ•ด๋ณผ ์˜ˆ์ • 

 

 

๐Ÿ”น GIN architecture code 

 

• Library 

 

import torch
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv #✨
from torch_geometric.nn import global_mean_pool, global_add_pool #✨

 

 

 

• GCN 

 

class GCN(torch.nn.Module):

  def __init__(self, dim_h):
    super(GCN, self).__init__()
    self.conv1 = GCNConv(dataset.num_node_features, dim_h)
    self.conv2 = GCNConv(dim_h, dim_h)
    self.conv3 = GCNConv(dim_h, dim_h)
    self.lin = Linear(dim_h, dataset.num_classes)

  def forward(self, x, edge_index, batch):
    # node embeddings
    h = self.conv1(x, edge_index)
    h = h.relu()
    h = self.conv2(h, edge_index)
    h = h.relu()
    h = self.conv3(h, edge_index)

    # graph-level readout (graph embedding)
    hG = global_mean_pool(h, batch)

    # classifier
    h = F.dropout(hG, p=0.5, training=self.training)
    h = self.lin(h)

    return hG, F.log_softmax(h, dim=1) # classify enzyme vs. non-enzyme

 

 

• GIN (3 layers)

 

class GIN(torch.nn.Module):

  def __init__(self, dim_h):
    super(GIN, self).__init__()
    self.conv1 = GINConv(
        Sequential(Linear(dataset.num_node_features, dim_h),
                   BatchNorm1d(dim_h), ReLU(),
                   Linear(dim_h, dim_h), ReLU())
    )
    self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU())
    )
    self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU())
    )

    self.lin1 = Linear(dim_h*3, dim_h*3)
    self.lin2 = Linear(dim_h*3, dataset.num_classes)

  def forward(self, x, edge_index, batch):

    # node embeddings
    h1 = self.conv1(x, edge_index)
    h2 = self.conv2(h1, edge_index)
    h3 = self.conv3(h2, edge_index)

    # graph-level readout (one sum per layer)
    h1 = global_add_pool(h1, batch)
    h2 = global_add_pool(h2, batch)
    h3 = global_add_pool(h3, batch)

    # concatenate the per-layer graph embeddings
    h = torch.cat((h1, h2, h3), dim=1)

    # classifier
    h = self.lin1(h)
    h = h.relu()
    h = F.dropout(h, p=0.5, training=self.training)
    h = self.lin2(h)

    return h, F.log_softmax(h, dim=1)


gcn = GCN(dim_h=32)
gin = GIN(dim_h=32)

 

 

• Train and Test

 

def train(model, loader):

  criterion = torch.nn.CrossEntropyLoss() # 💡 loss function
  optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.01) # 💡 optimizer
  epochs = 100

  for epoch in range(epochs + 1):
    model.train() # re-enable training mode (test() switches the model to eval)
    total_loss = 0
    acc = 0

    # 💡 train batch by batch
    for data in loader:
      optimizer.zero_grad()

      _, out = model(data.x, data.edge_index, data.batch)

      loss = criterion(out, data.y)

      total_loss += loss / len(loader)
      acc += accuracy(out.argmax(dim=1), data.y) / len(loader)
      loss.backward()
      optimizer.step()

    # 💡 validation once per epoch, not once per batch
    val_loss, val_acc = test(model, val_loader)

    # print metrics every 10 epochs
    if epoch % 10 == 0:
         print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} '
              f'| Train Acc: {acc*100:>5.2f}% '
              f'| Val Loss: {val_loss:.2f} '
              f'| Val Acc: {val_acc*100:.2f}%')

  # evaluate on the test set once, after training finishes
  test_loss, test_acc = test(model, test_loader)
  print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc*100:.2f}%')

  return model


@torch.no_grad()
def test(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0

    for data in loader:
      _, out = model(data.x, data.edge_index, data.batch)
      loss += criterion(out, data.y) / len(loader)
      acc += accuracy(out.argmax(dim=1), data.y) / len(loader)

    return loss, acc

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()
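Training both models might then look like this (a hypothetical driver, assuming the loaders defined earlier):

gcn = train(gcn, train_loader)
gin = train(gin, train_loader)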

 

 

 

 

๐Ÿ”น test accuracy result 

 

GIN ์˜ ์„ฑ๋Šฅ์ด ํ›จ์”ฌ ๋›ฐ์–ด๋‚จ

 

 

 

• ์„ฑ๋Šฅ์ฐจ์ด๊ฐ€ ๋ฐœ์ƒํ•˜๋Š” ์ด์œ  

 

โ†ช  GCN ๊ณผ ๋‹ฌ๋ฆฌ, GIN ์˜ aggregator ๋Š” ํŠนํžˆ graph ๋ฅผ ํŒ๋ณ„ํ•˜๋Š”๋ฐ ๊ณ ์•ˆ๋œ ๋ชจ๋ธ๊ตฌ์กฐ์ด๋‹ค. 

โ†ช  GIN ์€ ๋งจ ๋งˆ์ง€๋ง‰ hidden vector ๋งŒ์„ ๊ณ ๋ คํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ, ๋ชจ๋“  ๋ ˆ์ด์–ด์˜ hidden vector ๋ฅผ concat ํ•จ์œผ๋กœ์จ ๊ทธ๋ž˜ํ”„ ์ „์ฒด ๊ตฌ์กฐ๋ฅผ ๊ณ ๋ คํ•œ๋‹ค. 

โ†ช  sum ์—ฐ์‚ฐ์ด mean ์—ฐ์‚ฐ๋ณด๋‹ค ๋” ์„ฑ๋Šฅ์ด ๋›ฐ์–ด๋‚˜๋‹ค (injective ๋ฉด์—์„œ) 

 

 

 

 

• Graph classification results

We can inspect the graphs that the two models classify differently (enzyme vs. non-enzyme).

 

 

• GCN ๊ฒฐ๊ณผ์™€ GIN ๊ฒฐ๊ณผ๋ฅผ ensemble ํ•˜์—ฌ ์„ฑ๋Šฅ์„ ์กฐ๊ธˆ ๋” ๋†’์ผ ์ˆ˜ ์žˆ๋‹ค. normalized ๋œ output vector ์— ํ‰๊ท ์„ ์ทจํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ๋‘ ๋„คํŠธ์›Œํฌ์˜ ๊ทธ๋ž˜ํ”„ ์ž„๋ฒ ๋”ฉ ๊ฒฐ๊ณผ๋ฅผ ๊ฒฐํ•ฉํ•œ๋‹ค. 

 

gcn.eval()
gin.eval()
acc_gcn = 0
acc_gin = 0
acc = 0

for data in test_loader:

  # classify
  _, out_gcn = gcn(data.x, data.edge_index, data.batch)
  _, out_gin = gin(data.x, data.edge_index, data.batch)
  out = (out_gcn + out_gin)/2 # ✨ average the normalized (log-softmax) outputs

  # compute accuracy scores
  acc_gcn += accuracy(out_gcn.argmax(dim=1), data.y) / len(test_loader)
  acc_gin += accuracy(out_gin.argmax(dim=1), data.y) / len(test_loader)
  acc += accuracy(out.argmax(dim=1), data.y) / len(test_loader)

 

 

 

 

๐Ÿ”น Conclusion 

 

 

GIN ์€ GNN ์„ ์ดํ•ดํ•˜๋Š” ๋ฐ ์žˆ์–ด์„œ ๋งค์šฐ ์ค‘์š”ํ•œ ๋„คํŠธ์›Œํฌ์ด๋‹ค. ์ •ํ™•๋„๋ฅผ ํ–ฅ์ƒ์‹œํ‚ฌ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ, ์—ฌ๋Ÿฌ๊ฐ€์ง€ GNN ๊ตฌ์กฐ ์ค‘์—์„œ ์–ด๋– ํ•œ ๋ชจ๋ธ์ด ๊ฐ€์žฅ ํ‘œํ˜„๋ ฅ์ด ์ข‹์€์ง€ ์ด๋ก ์ ์œผ๋กœ ์„ค๋ช…๊ฐ€๋Šฅํ•  ์ˆ˜ ์žˆ๋„๋ก ๋งŒ๋“  ๋„คํŠธ์›Œํฌ๋ผ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. 

 

ํ•ด๋‹น ์•„ํ‹ฐํด์—์„œ๋Š” graph classification task ์— ๋Œ€ํ•ด global pooling ์„ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์œ„ํ•˜์—ฌ, WL test ์•„์ด๋””์–ด์—์„œ ๊ณ ์•ˆ๋œ GIN ๋„คํŠธ์›Œํฌ๋ฅผ ์ ์šฉํ•ด๋ณด๋Š” ๋‚ด์šฉ์„ ๋‹ด๊ณ  ์žˆ๋‹ค. 

 

๋น„๋ก PROTEINS ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด์„œ๋Š” GIN ์ด ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์˜€์œผ๋‚˜, social graph ์™€ ๊ฐ™์ด ์‹ค์ œ ์„ธ๊ณ„ ๋ฐ์ดํ„ฐ์—๋Š” ์ž˜ ์ ์šฉ๋˜์ง€ ์•Š๋Š” ๊ฒฝ์šฐ๋„ ์žˆ๋‹ค. 

 

 

 

 

 

 

 

 

 

 

 

 


'1๏ธโƒฃ AIโ€ขDS > ๐Ÿ“˜ GNN' ์นดํ…Œ๊ณ ๋ฆฌ์˜ ๋‹ค๋ฅธ ๊ธ€

[cs224w] Frequent Subgraph Mining with GNNs  (0) 2023.01.27
[CS224W] Graph Neural Network  (0) 2022.11.24
[CS224W] Message Passing and Node classification  (0) 2022.11.17
[CS224W] PageRank  (0) 2022.11.02
[CS224W] 1๊ฐ• Machine Learning With Graphs  (0) 2022.10.11

๋Œ“๊ธ€