๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
1๏ธโƒฃ AI•DS/๐Ÿ“™ Model

Tabnet

by isdawell 2022. 3. 31.
728x90

 

0๏ธโƒฃ Tabnet 

 

Tree ๊ธฐ๋ฐ˜ ๋ชจ๋ธ์˜ ๋ณ€์ˆ˜ ์„ ํƒ ํŠน์ง•์„ ๋„คํŠธ์›Œํฌ ๊ตฌ์กฐ์— ๋ฐ˜์˜ํ•œ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ

 

1๏ธโƒฃ ๋ฐฐ๊ฒฝ 

 

โœ” ๊ธฐ์กด ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์€ ์ด๋ฏธ์ง€, ์Œ์„ฑ, ์–ธ์–ด์™€ ๊ฐ™์€ ๋น„์ •ํ˜• ๋ฐ์ดํ„ฐ์—๋งŒ ์ ์šฉ๋˜์—ˆ์Œ 

โœ” ์ •ํ˜• ๋ฐ์ดํ„ฐ Tabular Data ๋Š” ์ตœ๊ทผ๊นŒ์ง€๋„ kaggle ๊ฐ™์€ ์—ฌ๋Ÿฌ ๋Œ€ํšŒ์—์„œ XGBoost, LightGBM, CatBoost์™€๊ฐ™์€ Tree๊ธฐ๋ฐ˜์˜ ์•™์ƒ๋ธ” ๋ชจ๋ธ์„ ์ฃผ๋กœ ์‚ฌ์šฉํ–ˆ์Œ 

 

๐Ÿ‘€ ๋”ฅ๋Ÿฌ๋‹์˜ ์ ์ง„์  ํ•™์Šต ํŠน์„ฑ + ์‚ฌ์ „ํ•™์Šต ๊ฐ€๋Šฅ์„ฑ์€ ์ƒˆ๋กœ์šด ๋ถ„์„ ๊ธฐํšŒ๋ฅผ ๋„์ถœ 

๐Ÿ‘€ ํŠธ๋ฆฌ๊ธฐ๋ฐ˜ ๋ชจ๋ธ + ์‹ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์กฐ ์˜ ์žฅ์ ์„ ๋ชจ๋‘ ๊ฐ–๋Š” Tabnet ์„ ์ œ์•ˆ ๐Ÿ‘‰ feature selection & engineering + ๋ชจ๋ธ ํ•ด์„๋ ฅ์„ ๊ฐ–์ถ˜ ์‹ ๊ฒฝ๋ง ๋ชจ๋ธ  

 

 

2๏ธโƒฃ Tabnet ๋…ผ๋ฌธ ๋ฆฌ๋ทฐ 

 

๐Ÿง ์•™์ƒ๋ธ” ๋ชจ๋ธ์ด ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ๋ณด๋‹ค ์šฐ์ˆ˜ํ•œ ์ด์œ  

 

 (1) ์ •ํ˜•๋ฐ์ดํ„ฐ๋Š” ๋Œ€๋žต์ ์ธ ์ดˆํ‰๋ฉด ๊ฒฝ๊ณ„๋ฅผ ๊ฐ€์ง€๋Š” manifold ์ด๊ณ , ๋ถ€์ŠคํŒ… ๋ชจ๋ธ๋“ค์€ ์ด๋Ÿฌํ•œ manifold ์—์„œ ๊ฒฐ์ •์„ ํ•  ๋•Œ ๋” ํšจ์œจ์ ์œผ๋กœ ์ž‘๋™ํ•œ๋‹ค. ์ด๋ฏธ์ง€์™€ ์–ธ์–ด๊ฐ™์€ ๋น„์ •ํ˜• ๋ฐ์ดํ„ฐ๋Š” ์ •ํ˜• ๋ฐ์ดํ„ฐ์— ๋น„ํ•ด ์ƒ๋Œ€์ ์œผ๋กœ ๊ฐ™์€ ์›์ฒœ์—์„œ ๋ฐœ์ƒ๋œ ๋ฐ์ดํ„ฐ์ด๋ฏ€๋กœ ๋Œ€๋žต์ ์ธ ์ดˆํ‰๋ฉด ๊ฒฝ๊ณ„๊ฐ€ ๋šœ๋ ทํ•˜์ง€ ์•Š๋‹ค. 

 

 (2) ํŠธ๋ฆฌ ๊ธฐ๋ฐ˜์˜ ๋ชจ๋ธ๋“ค์ด ํ•™์Šต์ด ๋น ๋ฅด๊ณ  ์‰ฝ๊ฒŒ ๊ฐœ๋ฐœ์ด ๊ฐ€๋Šฅํ•˜๋‹ค.

 

 (3) ํŠธ๋ฆฌ๊ธฐ๋ฐ˜์˜ ๋ชจ๋ธ๋“ค์€ ๋†’์€ ํ•ด์„๋ ฅ์„ ๊ฐ€์ง„๋‹ค. ํŠธ๋ฆฌ๊ธฐ๋ฐ˜ ๋ชจ๋ธ์˜ ํŠน์„ฑ ์ƒ ๋ณ€์ˆ˜ ์ค‘์š”๋„๋ฅผ ๊ตฌํ•  ์ˆ˜ ์žˆ์–ด ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์— ๋น„ํ•ด ์ƒ๋Œ€์ ์œผ๋กœ ํ•ด์„์— ์šฉ์ดํ•˜๋‹ค. 

 

 (4) CNN, MLP ๊ฐ™์€ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์€ ์ง€๋‚˜์น˜๊ฒŒ Overparametrized ๋˜์–ด์„œ ์ •ํ˜• ๋ฐ์ดํ„ฐ ๋‚ด ๋งค๋‹ˆํด๋“œ์—์„œ ์ผ๋ฐ˜ํ™”๋œ ํ•ด๊ฒฐ์ฑ…์„ ์ฐพ๋Š”๋ฐ ์–ด๋ ค์›€์„ ๋ฐœ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ๋‹ค. 

 

 

๐Ÿค” ์ •ํ˜• ๋ฐ์ดํ„ฐ์— ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ์ ์šฉํ•˜๋Š” ๊ฒƒ๋„ ๋‚˜์˜์ง€ ์•Š์•„  

 

 (1) ๋งค์šฐ ๋งŽ์€ ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ ์…‹์— ๋Œ€ํ•ด ์„ฑ๋Šฅ์„ ๋†’์ผ ์ˆ˜ ์žˆ๋‹ค. 

 

 (2) ์ •ํ˜• ๋ฐ์ดํ„ฐ์™€ ์ด๋ฏธ์ง€(ํ…์ŠคํŠธ) ๋“ฑ ๋‹ค๋ฅธ ๋ฐ์ดํ„ฐ ํƒ€์ž…์„ ํ•™์Šต์— ํ•จ๊ป˜ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•˜๋‹ค. (multi - modal Learning) 

 

 (3) ํŠธ๋ฆฌ ๊ธฐ๋ฐ˜ ์•Œ๊ณ ๋ฆฌ์ฆ˜์—์„œ ํ•„์ˆ˜์ ์ธ Feature engineering ๊ฐ™์€ ๋‹จ๊ณ„๋ฅผ ํฌ๊ฒŒ ์š”๊ตฌํ•˜์ง€ ์•Š๋Š”๋‹ค. 

 

 (4) ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์€ Streaming ๋ฐ์ดํ„ฐ๋กœ๋ถ€ํ„ฐ ํ•™์Šต์ด ์šฉ์ดํ•˜๋‹ค. (์ง€์†์ ์ธ ํ•™์Šต) 

 

 

๐Ÿ˜Ž Tabnet ์€ ๋ง์ด์•ผ 

 

 (1) Feature ์˜ ์ „์ฒ˜๋ฆฌ ์—†์ด ์› ๋ฐ์ดํ„ฐ๋ฅผ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ณ , ๊ฒฝ์‚ฌ ํ•˜๊ฐ•๋ฒ• ๊ธฐ๋ฐ˜ ์ตœ์ ํ™” ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•ด End-to-End learning ์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•œ๋‹ค. 

 

 (2) ์„ฑ๋Šฅ๊ณผ ํ•ด์„๋ ฅ ํ–ฅ์ƒ์„ ์œ„ํ•ด Sequential attention mechanism ์„ ํ†ตํ•ด ์‚ฌ์šฉํ•  feature ๋ฅผ ์„ ํƒํ•œ๋‹ค. 

 

 (3) ๊ธฐ์กด ์ •ํ˜• ๋ถ„๋ฅ˜, ํšŒ๊ท€ ๋ชจ๋ธ๋ณด๋‹ค ์„ฑ๋Šฅ์˜ ์šฐ์ˆ˜์„ฑ์„ ๊ฐ€์ง€๋ฉฐ ํ•ด์„๋ ฅ์—์„œ ์ž…๋ ฅ ํ”ผ์ฒ˜์˜ ์ค‘์š”๋„๋ฅผ ํŒŒ์•…ํ•  ์ˆ˜ ์žˆ๊ณ , ํ”ผ์ฒ˜์˜ ๊ฒฐํ•ฉ์„ ์‹œ๊ฐํ™”ํ•˜์—ฌ ํ™•์ธํ•ด๋ณผ ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์ž…๋ ฅ ํ”ผ์ฒ˜๋“ค์ด ์–ผ๋งˆ๋‚˜ ์ž์ฃผ ๊ฒฐํ•ฉ๋˜๋Š”์ง€์— ๋Œ€ํ•œ ํ•ด์„๋ ฅ์„ ์ œ์‹œํ•œ๋‹ค. 

 

 

 

3๏ธโƒฃ Tabnet ์•Œ๊ณ ๋ฆฌ์ฆ˜ ๊ตฌ์กฐ 

 

๐Ÿ“Œ ๊ฐœ์š” 

  • ์ˆœ์ฐจ์ ์ธ ์–ดํ…์…˜์„ ์‚ฌ์šฉํ•ด ๊ฐ ์˜์‚ฌ๊ฒฐ์ • ๋‹จ๊ณ„์—์„œ ์ถ”๋ก ํ•  ํ”ผ์ฒ˜๋ฅผ ์„ ํƒํ•ด๊ฐ€๋ฉด์„œ ํ”ผ๋“œ๋ฐฑ์„ ์ฃผ๋ฉฐ ํ•™์Šตํ•ด๋‚˜์•„๊ฐ€๋Š” ๊ตฌ์กฐ์ด๋‹ค ๐Ÿ‘‰ ๋” ๋‚˜์€ ํ•ด์„ ๋Šฅ๋ ฅ๊ณผ ํ•™์Šต์ด ๊ฐ€๋Šฅ + ์ˆจ๊ฒจ์ง„ ํŠน์ง•์„ ์˜ˆ์ธกํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์ „ ๋น„์ง€๋„ ํ•™์Šต (Self-supervised Learning) ์„ ์‚ฌ์šฉ ๊ฐ€๋Šฅ 
  • tabnet ์˜ feature selection ์€ ํŠน์ • ํ”ผ์ฒ˜๋งŒ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ, ๊ฐ ํ”ผ์ฒ˜์— ๊ฐ€์ค‘์น˜๋ฅผ ๋ถ€์—ฌํ•˜๋Š” ๊ฒƒ์ด๋‹ค. Sparse Feature selection
  • Tabnet ์˜ ๊ตฌ์กฐ๋Š” Encoder - Decoder ๋ฅผ ๊ฑฐ์ณ ๊ฒฐ์ธก๊ฐ’์„ ์˜ˆ์ธกํ•  ์ˆ˜ ์žˆ๋Š” Autoencoder ๊ตฌ์กฐ์ด๊ธฐ ๋•Œ๋ฌธ์— ๋ฐ์ดํ„ฐ์…‹์— ๊ฒฐ์ธก๊ฐ’์ด ์žˆ์–ด๋„ ๋ณ„๋„์˜ ์ „์ฒ˜๋ฆฌ ์—†์ด ๊ฐ’๋“ค์„ ์ฑ„์šธ ์ˆ˜ ์žˆ๋‹ค. 

 

 

๐Ÿ‘€ Encoder

  • input ์„ ์‹œ์ž‘์œผ๋กœ ๊ฐ ์˜์‚ฌ๊ฒฐ์ • ๋‹จ๊ณ„ Step ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ๊ณ , ๋‹จ๊ณ„๋งˆ๋‹ค Feature transformer, Attentive transformer, Feature masking ์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ๋‹ค. 
  • feature transformer ์™€ attentive transformer ๋ธ”๋ก์„ ํ†ต๊ณผํ•˜์—ฌ ์ตœ์ ์˜ mask ๋ฅผ ํ•™์Šตํ•œ๋‹ค. 
  • ๋‹ค์Œ decision step ์œผ๋กœ ์ด์ „์˜ decision ์— ๋Œ€ํ•œ ์ •๋ณด๋“ค์ด ์ „๋‹ฌ๋˜๋Š” ๊ณผ์ •์ด ํŠธ๋ฆฌ๊ธฐ๋ฐ˜ ๋ถ€์ŠคํŒ… ๋ชจ๋ธ์˜ ์ž”์ฐจ๋ฅผ ์ค„์—ฌ๋‚˜๊ฐ€๋Š” ๋ถ€๋ถ„๊ณผ ์œ ์‚ฌํ•˜๋‹ค. 

 

 

  • feature masking ์€ local ํ•ด์„์— ์‚ฌ์šฉ๋˜๋ฉฐ ์ „์ฒด๋ฅผ ์ทจํ•ฉํ•˜์—ฌ global ํ•œ ํ•ด์„์„ ํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋œ๋‹ค. 

 

 

 

๐Ÿ‘€ Decoder

  • ๊ฐ step ์—์„œ feature transformer ๋ธ”๋ก์œผ๋กœ ๊ตฌ์„ฑ๋œ๋‹ค. 
  • ์ผ๋ฐ˜์ ์ธ ํ•™์Šต์—์„  Decoder ๋ฅผ ์‚ฌ์šฉํ•˜์ง„ ์•Š์ง€๋งŒ Self-Supervised (Semi-supervised) ํ•™์Šต ์ง„ํ–‰์‹œ ์ธ์ฝ”๋” ๋‹ค์Œ์— ๋ถ™์—ฌ์ ธ ๊ธฐ์กด ๊ฒฐ์ธก๊ฐ’ ๋ณด์™„ ๋ฐ ํ‘œํ˜„ ํ•™์Šต์„ ์ง„ํ–‰ํ•œ๋‹ค. 

 


 

๐Ÿ“Œ ์„ธ๋ถ€ ๊ตฌ์กฐ  

 

๐Ÿคจ Tabnet ์•„ํ‚คํ…์ณ๋ฅผ ๋”ฐ๋ผ๊ฐ€๋ฉฐ ์œ„์˜ (a) ๊ทธ๋ฆผ์— ์ œ์‹œ๋œ ๊ฐ ๋ฐ•์Šค ๋ถ€๋ถ„์— ๋Œ€ํ•ด ์„ค๋ช…ํ•˜๊ณ ์ž ํ•œ๋‹ค. 

 

 

๐Ÿ“• Feature transformer ๐Ÿ’จ ์ž„๋ฒ ๋”ฉ์„ ์ˆ˜ํ–‰ 

 

  • ์„ ํƒ๋œ ํ”ผ์ฒ˜๋กœ ์ •ํ™•ํžˆ ์˜ˆ์ธกํ•˜๊ธฐ ์œ„ํ•œ ์ž„๋ฒ ๋”ฉ ๊ธฐ๋Šฅ 

 

 

  • ์ž…๋ ฅ Feature : numerical ํ”ผ์ฒ˜๋Š” ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉํ•˜๊ณ , categorical ํ”ผ์ฒ˜๋Š” ์ž„๋ฒ ๋”ฉ ๋ ˆ์ด์–ด๋ฅผ ๊ตฌ์„ฑํ•ด์ค€๋‹ค ๐Ÿ‘‰ ๋ชจ๋ธ ์ƒ์„ฑ์‹œ cat_idxs, cat_dims, cat_emb_dim ์ธ์ž์™€ ๊ด€๋ จ๋จ 
  • BatchNorm (BN) : ์ •ํ˜• ๋ฐ์ดํ„ฐ๋ฅผ ๋ถ„์„ํ•  ๋•Œ ๋ณดํ†ต Min-Max scaler, Standard Scaler ๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š”๋ฐ, ์ด๋Ÿฌํ•œ ์ •๊ทœํ™” ๊ณผ์ •์„ BatchNorm ๋ ˆ์ด์–ด๋กœ ๋Œ€์ฒดํ•˜์—ฌ ์‚ฌ์šฉํ–ˆ๋‹ค. 
    • batch ๋ฅผ ๋ถ„ํ• ํ•œ nano batch ์‚ฌ์šฉ์œผ๋กœ ์žก์Œ์„ ์ถ”๊ฐ€ํ•ด ์ง€์—ญ ์ตœ์ ํ™”๋ฅผ ์˜ˆ๋ฐฉํ•œ๋‹ค. 

 

  • Feature transformer : FC-BN-GLU ๋ฅผ 4๋ฒˆ ๋ฐ˜๋ณตํ•œ ๊ตฌ์กฐ 
    • FC : fully connected layer ์ „๊ฒฐํ•ฉ์ธต 
    • GLU : Gated Linear unit, ์„ ํ˜• ๋งคํ•‘์„ ํ†ตํ•ด ๋‚˜์˜จ ๊ฒฐ๊ณผ๋ฌผ์„ ๋ฐ˜์œผ๋กœ ๋‚˜๋ˆ„์–ด Residual connection, sigmoidfunction ์„ ๊ฑฐ์นœ ํ›„ element-wise ๋กœ ๊ณ„์‚ฐํ•˜๋Š” ๊ตฌ์กฐ. ๊ฐ ์ •๋ณด ๋ณ„ ์ •๋ณด์˜ ์–‘์„ ์–ผ๋งˆ๋‚˜ ํ˜๋ ค๋ณด๋‚ผ์ง€ ๊ฒฐ์ •ํ•˜๊ธฐ ์œ„ํ•ด ๋น„์„ ํ˜• ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. 

 

GLU Notation

 

 

  • ์•ž์˜ 2๊ฐœ ๋„คํŠธ์›Œํฌ ๋ฌถ์Œ์€ ๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๊ณต์œ ํ•˜๋ฉฐ ๊ธ€๋กœ๋ฒŒ ์„ฑํ–ฅ์„ ํ•™์Šตํ•˜๊ณ , ๋’ค์˜ 2๊ฐœ ๋„คํŠธ์›Œํฌ ๋ฌถ์Œ์€ ๊ฐ ์Šคํ…์—์„œ๋งŒ ์ „์šฉ์œผ๋กœ ์‚ฌ์šฉ๋˜๋Š” ๋ธ”๋ก์œผ๋กœ ๊ฐ ๋กœ์ปฌ ์„ฑํ–ฅ์„ ํ•™์Šตํ•œ๋‹ค. 

 

 

 

๐Ÿ‘€ Split block 

 

  • feature transformer ๋กœ๋ถ€ํ„ฐ ๋‚˜์˜จ ๊ฒฐ๊ณผ๋ฅผ ๋‘ ๊ฐœ๋กœ ๋‚˜๋ˆ„์–ด, ํ•˜๋‚˜๋Š” ReLU ๋กœ ๋ณด๋‚ด์–ด ์ตœ์ข… ์•„์›ƒํ’‹ (Decision output) ์œผ๋กœ ๋ณด๋‚ด๊ณ  ๋‹ค๋ฅธ ํ•˜๋‚˜๋Š” ๋‹ค์Œ Attentive transformer ๋กœ ๋„˜๊ฒจ์ค€๋‹ค. 
  • ํ–ฅํ›„ ๊ฐ decision output์˜ ๊ฒฐ๊ณผ๋ฅผ ํ•ฉ์‚ฐํ•ด ์ „์ฒด ์˜์‚ฌ๊ฒฐ์ • ์ž„๋ฒ ๋”ฉ์„ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๊ณ  ์ด ์ž„๋ฒ ๋”ฉ์ด FC layer ๋ฅผ ๊ฑฐ์น˜๋ฉด ์ตœ์ข… classification/regression ์˜ˆ์ธก ๊ฒฐ๊ณผ๊ฐ€ ์‚ฐ์ถœ๋œ๋‹ค. 
  • ReLU layer ์˜ ๊ฒฐ๊ณผ์—์„œ hidden unit ์ฑ„๋„์˜ ๊ฐ’๋“ค์„ ๋ชจ๋‘ ํ•ฉ์‚ฐํ•ด ํ•ด๋‹น step ์˜ ํ”ผ์ฒ˜์ค‘์š”๋„๋ฅผ ์‚ฐ์ถœํ•  ์ˆ˜ ์žˆ๋‹ค. ๊ฐ ๋‹จ๊ณ„์˜ ํ”ผ์ฒ˜์ค‘์š”๋„ ๊ฒฐ๊ณผ๋ฅผ ํ•ฉ์‚ฐํ•˜๋ฉด ์ตœ์ข… ํ”ผ์ฒ˜ ์ค‘์š”๋„๊ฐ€ ๋„์ถœ๋œ๋‹ค. 

 

 

๐Ÿ“— Attentive transformer ๐Ÿ’จ Mask ๋ฅผ ์ƒ์„ฑ (๋ณ€์ˆ˜์„ ํƒ ๊ธฐ๋Šฅ) 

 

  • FC , BN, Sparsemax ๋ฅผ ์ˆœ์ฐจ์ ์œผ๋กœ ์ˆ˜ํ–‰ํ•˜๋ฉฐ Mask ๋ฅผ ์ƒ์„ฑํ•œ๋‹ค.
  • Mask ์—๋Š” ์–ด๋–ค ํ”ผ์ฒ˜๋ฅผ ์ฃผ๋กœ ์‚ฌ์šฉํ•  ๊ฒƒ์ธ์ง€์— ๋Œ€ํ•œ ์ •๋ณด๊ฐ€ ํฌํ•จ๋˜์–ด ์žˆ๋‹ค. 
  • ์ƒ์„ฑ๋œ Mask ์— ํ”ผ์ฒ˜๋ฅผ ๊ณฑํ•˜์—ฌ ํ”ผ์ฒ˜์„ ํƒ์ด ์ด๋ฃจ์–ด์ง„๋‹ค. ์ด์ „ step ์˜ ํ”ผ์ฒ˜์™€ ๊ณฑํ•˜์—ฌ Masked feature ๋ฅผ ์ƒ์„ฑํ•œ๋‹ค. ์ด๋Š” ๋‹ค์‹œ Feature transformer ๋กœ ์—ฐ๊ฒฐ๋˜๋ฉฐ Step ์ด ๋ฐ˜๋ณต๋œ๋‹ค. 
  • Prior scale ์‚ฌ์ „ ์ •๋ณด๋Ÿ‰ : ์ด์ „ decision step ๋“ค์—์„œ ๊ฐ feature ๊ฐ€ ์–ผ๋งˆ๋‚˜ ๋งŽ์ด ์‚ฌ์šฉ๋˜์—ˆ๋Š”์ง€ ์ง‘๊ณ„ํ•œ ์ •๋ณด๋กœ, ์ด์ „ step ์—์„œ ์‚ฌ์šฉํ•œ Mask ๋ฅผ ์–ผ๋งˆ๋‚˜ ์žฌ์‚ฌ์šฉํ• ์ง€ ์กฐ์ ˆํ•  ์ˆ˜ ์žˆ๋‹ค. ์„ ํƒ๋œ ๋ณ€์ˆ˜์˜ ๋ฐ˜์˜๋ฅ ์„ ์กฐ์ ˆํ•˜๋Š” ์š”์ธ. 
  • masking ์„ ํ†ตํ•ด ํ•™์Šต์— ํฐ ์˜ํ–ฅ์„ ๋ฏธ์น˜์ง€ ์•Š์€ ๋ณ€์ˆ˜๋“ค์˜ ์˜ํ–ฅ๋ ฅ ๊ฐ์†Œ์‹œํ‚ด ๐Ÿ‘‰ mask ๋ฅผ ๊ตฌํ•˜๊ธฐ ์œ„ํ•ด attentive transformer ๋ฅผ ์‚ฌ์šฉ 
  • Sparsemax : Softmax์˜ sparseํ•œ ๋ฒ„์ „์œผ๋กœ sparse ํ•œ ๋ฐ์ดํ„ฐ์…‹์— ์ ์šฉํ–ˆ์„ ๋•Œ ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์ธ ์ •๊ทœํ™”๊ธฐ๋ฒ•์ด๋‹ค. ๊ฐ ๋ณ€์ˆ˜ ๋ณ„ ๊ณ„์ˆ˜ ๊ฐ’๋“ค์˜ ์ผ๋ฐ˜ํ™”๋ฅผ ์œ„ํ•ด ์‚ฌ์šฉํ•œ๋‹ค. ๋ณ€์ˆ˜์˜ ์–‘์ด ๋งŽ์•„์งˆ์ˆ˜๋ก ๊ฐ’์ด 0๊ณผ 1๋กœ ์ˆ˜๋ ด๋˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์•„์ง ๐Ÿ‘‰ ๋” ๊ฐ•๋ ฅํ•œ ํ”ผ์ฒ˜ ์„ ํƒ ์˜์‚ฌ ๊ฒฐ์ • ๊ณผ์ • (๊ฒฐ์ •์˜ ํšจ๊ณผ๋ฅผ ๋†’์ธ๋‹ค) 

 

 

 

 

 

 

 

๐Ÿ“˜ feature Masking

  • feature importance ๋ฅผ ๊ณ„์‚ฐ
  • ์ด์ „ Step ์˜ feature ์— ๊ณฑํ•˜์—ฌ Masked feature ๋ฅผ ์ƒ์„ฑ 
  • ๋‹ค์Œ Step ์˜ Mask ์—์„œ ์ ์šฉํ•  Prior scale term ๊ณ„์‚ฐ 

  • Masked feature ๋Š” ๋‹ค์Œ step ์˜ input ์ด ๋œ๋‹ค. 

 

๐Ÿ‘€ Agg(regate) block 

  • ์–ด๋–ค feature ๊ฐ€ ์ค‘์š”ํ•œ์ง€ ์•Œ ์ˆ˜ ์žˆ๋‹ค. 

 

 

๐Ÿ“™ feature importance mask 

  • ๊ฐ decision step (M1, M2, ..) ๋ณ„๋กœ ์–ด๋–ค ํ”ผ์ฒ˜๋“ค์ด ์ค‘์š”ํ•˜๊ฒŒ ์‚ฌ์šฉ๋˜์—ˆ๋Š”์ง€๋ฅผ ์‹œ๊ฐํ™” ํ•œ ๊ฒƒ์ด๋‹ค. ๊ฐ ๋‹จ๊ณ„์—์„œ ์–ด๋–ค ๋ณ€์ˆ˜๋“ค์ด ์ฃผ์š”ํ•˜๊ฒŒ ์‚ฌ์šฉ๋˜์—ˆ๋Š”์ง€ ํ•ด์„ํ•  ์ˆ˜ ์žˆ๋‹ค. 

 

4๏ธโƒฃ ์ฝ”๋“œ ์‹ค์Šต 

 

import torch
import torch.nn as nn
from pytorch_tabnet.tab_model import TabNetClassifier


clf = TabNetClassifier()  #TabNetRegressor()

clf.fit(
  X_train, Y_train,
  eval_set=[(X_valid, y_valid)]
)
preds = clf.predict(X_test)

 

 

5๏ธโƒฃ Plus

 

1. Sparse feature selection = decision blocks 

 

  • ์—ฌ๋Ÿฌ๊ฐœ์˜ ์˜์‚ฌ๊ฒฐ์ • ๋ธ”๋ก์„ ์‚ฌ์šฉ 
  • ๊ทธ๋ฆผ์—์„œ๋Š” ์„ฑ์ธ ์ธ๊ตฌ์กฐ์‚ฌ ๋ฐ์ดํ„ฐ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์†Œ๋“์„ ์˜ˆ์ธกํ•˜๋Š” Tabnet ์˜ ์—ฐ์‚ฐ ๊ณผ์ •์„ ๋ณด์—ฌ์ฃผ๊ณ  ์žˆ๋‹ค. ์†Œ๋“์ˆ˜์ค€์„ ์˜ˆ์ธกํ•˜๊ธฐ ์œ„ํ•ด 2๊ฐœ์˜ ์˜์‚ฌ๊ฒฐ์ • ๋ธ”๋ก์ด ๊ฐ๊ฐ ์ „๋ฌธ์ง ์—ฌ๋ถ€์™€ ํˆฌ์ž์•ก์— ๊ด€๋ จ๋œ ๋ณ€์ˆ˜๊ฐ€ ์„ ํƒํ•œ ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. 

 

 

2. Mask 

  • Mask ๋ž€ ์ž…๋ ฅ๋ณ€์ˆ˜๋“ค ์ค‘ ์„ ํƒ ๋ณ€์ˆ˜ ์™ธ ๋‹ค๋ฅธ ๋ณ€์ˆ˜๋“ค์„ ๊ฐ€๋ฆฌ๋Š” ๋ฐฉ๋ฒ•์ด๋‹ค. 
  • ๋‘ ๋ณ€์ˆ˜ x1, x2๊ฐ€ Sparse matrix Mask ๋ฅผ ํ†ต๊ณผํ•˜๊ฒŒ ๋˜๋ฉด ํŠน์ • ๋ณ€์ˆ˜๋ฅผ ์„ ํƒํ•œ ๊ฒƒ ๊ฐ™์€ ํšจ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ๋‹ค. 
  • Ck : ๊ฐ€์ค‘์น˜. ์ด ๊ฐ’์ด ์ปค์งˆ์ˆ˜๋ก ๋ถ„๋ฅ˜๋ฅผ ์œ„ํ•œ ๊ฒฐ์ •๊ฒฝ๊ณ„๊ฐ€ ๋šœ๋ ทํ•ด์ง„๋‹ค. 
  • ๋ณ€์ˆ˜ ์„ ํƒ ์ดํ›„ ์˜๋ฏธ ์ถ”์ถœ ๊ณผ์ •์—์„œ ๋‹ค๋ฅธ ๋ณ€์ˆ˜๋“ค์ด ๊ฐœ์ž…๋˜์ง€ ์•Š์œผ๋ฏ€๋กœ ReLU ๋ฅผ ํ†ต๊ณผํ•œ ๊ฒฐ๊ณผ๋“ค์€ ์„œ๋กœ ์ƒํ˜ธ ๋…๋ฆฝ์ ์ด๋‹ค. 
  • ReLU ๋ฅผ ํ†ตํ•ด ๊ฒฐ๊ณผ๋กœ ๋‚˜์˜จ output ์„ ํ•ฉ์ณ ์˜์‚ฌ๊ฒฐ์ •์— ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์•™์ƒ๋ธ” ํŠธ๋ฆฌ ๊ตฌ์กฐ์™€ ์œ ์‚ฌํ•˜๋‹ค. 

 

 

3. Self-supervised tabular learning 

  • Tabnet ์—์„œ๋Š” ์ž๊ธฐ์ง€๋„ํ•™์Šต (self - supervised) ์„ ์œ„ํ•ด ๋ฌด์ž‘์œ„๋กœ ๊ฐ€๋ ค์ง„ ๋ณ€์ˆ˜๊ฐ’์„ ์˜ˆ์ธกํ•˜๋Š” autoencoder ๊ตฌ์กฐ์˜ ๋น„์ง€๋„ ํ•™์Šต์„ ์ˆ˜ํ–‰ํ•ด ๋น„์ง€๋„ ํ‘œํ˜„์„ ํ•™์Šตํ•ด encoder ๊ตฌ์กฐ์˜ ์ง€๋„ํ•™์Šต ๋ชจ๋ธ ์„ฑ๋Šฅ์„ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ๋‹ค. 
  • ํŠน์ • ์˜์—ญ์ด masking ๋œ ์ธ์ฝ”๋”ฉ ๋ฐ์ดํ„ฐ๋ฅผ ์›๋ณธ๋Œ€๋กœ ๋ณต์›ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•™์Šตํ•˜๋Š” ์‚ฌ์ „ ํ•™์Šต์„ ํ†ตํ•ด ์˜ˆ์ธก ์„ฑ๋Šฅ์„ ํ–ฅ์ƒ, ํ•™์Šต ์‹œ๊ฐ„ ๋‹จ์ถ• ๋ฐ ๊ฒฐ์ธก์น˜์— ๋Œ€ํ•œ ๋ณด๊ฐ„ ํšจ๊ณผ
  • encoder ์—์„œ ์ •๋ณด๋ฅผ ์••์ถ•ํ•˜๊ณ  decoder ์—์„œ ํ™•์žฅํ•˜์—ฌ ํ•ด์„ํ•˜๋ฉด์„œ ๊ฒฐ์ธก์น˜๋ฅผ ๋ณด์ •ํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋œ๋‹ค. 

 

 

 

4. Attenstion

  • Encoder : ์ •๋ณด๋ฅผ ์••์ถ•
  • Decoder : ์ •๋ณด๋ฅผ ํ™•์žฅํ•ด์„œ ํ•ด์„ 
  • ๋ฒกํ„ฐ ํ•˜๋‚˜์— ๋ชจ๋“  ์‹œํ€€์Šค์˜ ์ •๋ณด๋ฅผ ์˜์กดํ•˜์ง€ ์•Š์œผ๋ฏ€๋กœ ๊ธธ์ด๋‚˜ ์ˆœ์„œ์— ์˜ํ–ฅ์„ ๋œ ๋ฐ›์Œ 

 

 

 

6๏ธโƒฃ Tabnet ์žฅ์  

1. ์ „์ฒ˜๋ฆฌ ๊ณผ์ •์ด ํ•„์š”ํ•˜์ง€ ์•Š๋‹ค. 

2. Decision step ์œผ๋กœ feature selection ์„ ์ง„ํ–‰ํ•œ๋‹ค. 

3. decision step ๋ณ„ ํ˜น์€ ๋ชจ๋ธ ์ „์ฒด์— ๋Œ€ํ•ด feature importance ๋ฅผ ์ˆ˜์น˜ํ™”ํ•  ์ˆ˜ ์žˆ๋‹ค. 

4. ๋ฌด์ž‘์œ„๋กœ ๊ฐ€๋ ค์ง„ feature ๊ฐ’์„ ์˜ˆ์ธกํ•˜๋Š” unsupervised pretrain ๋‹จ๊ณ„๋ฅผ ์ ์šฉํ•˜์—ฌ ์ƒ๋‹นํ•œ ์„ฑ๋Šฅ ํ–ฅ์ƒ์„ ๋ณด์—ฌ์ค€๋‹ค. 

5. ์‹ค์ œ ๋ฐ์ดํ„ฐ๋Š” ๋Š์ž„์—†์ด ์œ ์ž…๋˜๊ณ  ๋ณ€ํ™”ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ•œ๋ฒˆ์˜ ํ•™์Šต์œผ๋กœ ์˜์›ํžˆ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ๋ชจ๋ธ์€ ์—†๋‹ค. ๋•Œ๋ฌธ์— ๋”ฅ๋Ÿฌ๋‹์˜ pretraining, Incremental learning (iterative train) ํŠน์„ฑ์€ ์ง€์† ํ•™์Šต ๊ฐ€๋Šฅํ•œ ์ธก๋ฉด์—์„œ ์ข‹์€ ๋Œ€์•ˆ์ด๋‹ค. 

 

 

๐Ÿพ ์ฐธ๊ณ ์ž๋ฃŒ 

1. https://wsshin.tistory.com/5

2. https://lv99.tistory.com/83 

3. https://housekdk.gitbook.io/ml/ml/tabular/tabnet-overview 

4. https://themore-dont-know.tistory.com/2

5. https://today-1.tistory.com/54 

6. http://dmqm.korea.ac.kr/activity/seminar/327

 

 

 

 

 

 

 

728x90

๋Œ“๊ธ€