Vanishing Gradients and Fancy RNNs
๐ก ์ฃผ์ : Vanishing Gradients and Fancy RNNs
๐ ํต์ฌ
- Task : ๋ฌธ์ฅ์ด ์ฃผ์ด์ง ๋ ์ง๊ธ๊น์ง ๋์จ ๋จ์ด๋ค ์ดํ์ ๋์ฌ ๋จ์ด๋ฅผ ์์ธก
- Sequential data : ์์๊ฐ ์๋ฏธ ์์ผ๋ฉฐ ์์๊ฐ ๋ฌ๋ผ์ง ๊ฒฝ์ฐ ์๋ฏธ๊ฐ ์์๋๋ ๋ฐ์ดํฐ๋ก ์ํ์ ๊ฒฝ๋ง์ ์ฌ์ฉํ๋ ์ด์ ๋ ์ ๋ ฅ์ ์์ฐจ๋ฐ์ดํฐ๋ก ๋ฐ๊ฑฐ๋, ์ถ๋ ฅ์ ์์ฐจ ๋ฐ์ดํฐ๋ก ๋ด๊ธฐ ์ํด์๋ค.
- RNN : ๋ค์์ ์ฌ ๋จ์ด๋ฅผ ์์ธกํ๋ ๊ณผ์ ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์ํํ๊ธฐ ์ํด ๋์ ํ NN ์ ์ผ์ข ๐ ๋ฌธ์ ์ : ๊ธฐ์ธ๊ธฐ์์ค/ํญ์ฆ, ์ฅ๊ธฐ์์กด์ฑ
- LSTM : RNN ์ ์ฅ๊ธฐ์์กด์ฑ์ ๋ฌธ์ ์ ์ ๋ณด์ํด ๋ฑ์ฅํ ๋ชจ๋ธ ๐ cell state , 3 ๊ฐ์ gate ๊ฐ๋ ์ ๋์
1๏ธโฃ Language model, RNN ๋ณต์ต
1. Langauage model
โ ์ ์
- ํ๋ฅ ๋ถํฌ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ฃผ์ด์ง ๋ฌธ๋งฅ ์ดํ์ ์์นํ ๋จ์ด๋ฅผ ์์ธกํ๋ค.
- ํ๋ฅ ์ ๊ณฑ์ ๋ฒ์น์ ์ํด ์ํ์ค ๊ฒฐํฉํ๋ฅ ์ ๋์ถํ๋ค.
2. N-gram, NNLM
โ N-gram
- unigram, bigram, trigram ๋ฑ ์ด์ ์ ๋ฑ์ฅํ n ๊ฐ์ ๋จ์ด๋ฅผ ํตํด count ๋น๋๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์กฐ๊ฑด๋ถ ํ๋ฅ ์ ๊ณ์ฐํ๋ค
- ๐จ Sparsity ๋ฌธ์ : n-gram ๋ฒ์ ๋ด์ ๋จ์ด๊ฐ ๋ฌธ์ ๋ด์์ ๋ฑ์ฅํ์ง ์์ ๊ฒฝ์ฐ, ์กฐ๊ฑด๋ถํ๋ฅ ์ ๋ถ์๋ ๋ถ๋ชจ ๊ฐ์ด 0 ์ด ๋๋ ๋ฌธ์
โ NNLM
- window size ์ด์ ๋จ์ด๋ฅผ ๋ฐํ์ผ๋ก ์์ธก์ ์ํํ๋ค.
- n-gram ๊ณผ ์ ์ฌํ ๋ฐฉ์์ด๋ ํ๋ฅ ๊ฐ์ ์ฐ์ฐํ๋ ๊ณผ์ ์์ ์ฐจ์ด๊ฐ ์กด์ฌํ๋ค.
- word embedding : distributed representation (์๋ฒ ๋ฉ ๊ณต๊ฐ ์์์ ๋น์ทํ ์๋ฏธ๋ฅผ ๊ฐ์ง ๋จ์ด๋ค์ ๋น์ทํ ๊ณต๊ฐ์ ์์นํ ๊ฒ์ด๋ค)
- ๐จ window size ๊ฐ ์ปค์ง๋ฉด ํ๋ผ๋ฏธํฐ ์๊ฐ ์ฆ๊ฐํ์ฌ ์ฐ์ฐ๋์ด ๋์ด๋๊ณ ๊ณผ์ ํฉ๋ ๊ฐ๋ฅ์ฑ์ด ์กด์ฌ
โ ๋ ๋ฐฉ๋ฒ์ ๊ณตํต์ ์ธ ํ๊ณ์ ๐ฅ : ์ ๋ ฅ๊ธธ์ด (N, windowsize) ๊ฐ ๊ณ ์ ๋์ด ํ์ ๋ ๊ธธ์ด์ ๋จ์ด๋ง ์ดํด๋ณด๊ธฐ ๋๋ฌธ์, ๋ชจ๋ ๋ฌธ๋งฅ์ ๋ํ ๊ณ ๋ ค๋ ๋ฐฐ์ ํ๋ค.
3. Vanilla RNN (๊ธฐ๋ณธ RNN)
๐ฅ ngram ๊ณผ NNLM ์ ๋จ์ ์ ๊ทน๋ณตํ์ฌ, ์ ๋ ฅ ๊ธธ์ด์ ์ ํ์ ๋์ง ์๋ ๋ชจ๋ธ์ด ๋ฑ์ฅ
โ RNN Architexture
โ input word sequence ์ํซ๋ฒกํฐ ๐ ๋จ์ด์ฌ์ ๊ฐ์๋งํผ์ ํฌ๊ธฐ๋ฅผ ๊ฐ๋ ์ํซ๋ฒกํฐ๊ฐ ์ ๋ ฅ๋๋ค.
โก Word embedding e(t) = E*x(t) ๐ embedding matrix E ์ ์ ๋ ฅ ์ํซ๋ฒกํฐ x๊ฐ ๊ณฑํด์ ธ embedding ํํ์์ ๋ฒกํฐ e ๋ฅผ ์์ฑํด๋ธ๋ค.
โข Hidden state vector h(t) = sigmoid(Wh*h(t-1) + We*e(t) + b1) ๐ ์ด์ ๋จ๊ณ์ hidden state ์ Wh ๋ฒกํฐ๋ฅผ ๊ณฑํ ๊ฐ๊ณผ, ํ์ฌ ๋จ๊ณ์ (We ๊ฐ์ค์น ํ๋ ฌ๊ณผ ์๋ฒ ๋ฉ ๋จ์ด๋ฒกํฐ e ๋ฅผ ๊ณฑํ) hidden state vector ๋ฅผ ๋ํ์ฌ ์๊ทธ๋ชจ์ด๋ ์ฐ์ฐ ์ํ
โฃ output distribution y(t) = softmax( U*h(t) + b2 ) ๐ weight matrix U ์ hidden state ๋ฅผ ๊ณฑํ์ฌ ์ํํธ๋งฅ์ค ์ฐ์ฐ์ ํตํด ์ต์ข output vector ๋ฅผ ๋ฐํํ๋ค. ์ด๋ output vector ๊ฐ์ด ๋ํ๋ด๋ ๊ฒ์ ํด๋น ๋จ์ด๊ฐ ๋ค์ ๋จ์ด๋ก ๋ฑ์ฅํ ํ๋ฅ ๊ฐ์ด๋ค.
โญ ์์ฐจ์ ์ผ๋ก ์ ๋ ฅ๊ฐ x(1), x(2), ... ์ด ๋ค์ด๊ฐ์ ๋ฐ๋ผ ๋จ๊ณ๋ณ๋ก ๋ค์์ ์์นํ ๋จ์ด๋ค์ ์ฐจ๋ก๋ก ๊ฒฐ์ ํ๊ฒ ๋๋ค.
โญ RNN ์ ์ค์ํ ๊ตฌ์กฐ์ ํน์ง : ๊ฐ ๋จ๊ณ๋ณ๋ก ๋์ผํ weight matrix E, We, Wh, U ๋ฅผ ์ ์ฉํ๋ค.
โ ์์คํจ์
๐ many-to-many RNN
- x(t) : ์ ๋ ฅ๊ฐ
- y(t) : ๋ค์์ ์์นํ ๋จ์ด์ ์์ธก๊ฐ
- U : ์ต์ข ์ ์ธ ์ถ๋ ฅ๋ฒกํฐ๋ฅผ ๋ฐํํ๊ธฐ ์ํด, hidden state ์ ๋ง๋๋ ๊ฐ์ค์นํ๋ ฌ
- Wh : ํ์ฌ ์์ ์ hidden state ๋ฅผ ๋์ถํ๊ธฐ ์ํด, ์ด์ ์์ ์ hidden state ์ ๋ง๋ ์ฐ์ฐ์ด ์ด๋ฃจ์ด์ง๋ ๊ฐ์ค์น ํ๋ ฌ
- We : ํ์ฌ ์์ ์ hidden state ๋ฅผ ๋์ถํ๊ธฐ ์ํด, ์๋ฒ ๋ฉ๋ ๋ฒกํฐ์ ๋ง๋ ์ฐ์ฐ์ด ์ด๋ฃจ์ด์ง๋ ๊ฐ์ค์น ํ๋ ฌ
๐ loss function
- ๋จ์ด๋ฅผ ์์ธกํ๋ ๋ถ๋ฅ ๋ฌธ์ ์ ํด๋นํ๋ฏ๋ก ์์คํจ์๋ cross-entropy ๋ก ์ ์๋๋ค.
- many-to-many RNN ๊ตฌ์กฐ์์ ๊ฐ time step ๋ณ๋ก loss function ๊ฐ์ด ๋์ถ๋๊ณ
- ๋ชจ๋ ๋จ๊ณ์ loss function ์ ํ๊ท ๋ด๋ฆฐ ๊ฐ์ด ์ต์ข RNN ๋ชจ๋ธ์ loss ์ ํด๋นํ๋ค.
๐ BPTT : back propagation through time
- RNN ์ ์ญ์ ํ๋ BPTT ๋ฐฉ์์ผ๋ก ์ด๋ฃจ์ด์ง๋ค. ์ผ๋ฐ์ ์ธ MLP ์ญ์ ํ ๊ณผ์ ๊ณผ ์ ์ฌํ์ง๋ง, ๋งค time step ๋ณ ์ ์ถ๋ ฅ์ด ์กด์ฌํ๋ค๋ ์ค์ํ ํน์ง์ ๊ณ ๋ คํ์ฌ ๊ณ์ฐํ๋ค๋ ์ฐจ์ด์ ์ด ์กด์ฌํ๋ค. (โก ๋ฒ ๊ณผ์ ์ฐธ๊ณ - time step ์ ๋ชจ๋ ๊ณ ๋ คํ์ฌ summation ์ฐ์ฐ)
- chain rule โญ
2๏ธโฃ LSTM
1. Problem of RNN
โ ์ฒซ๋ฒ์งธ ๋ฌธ์ : Vanishing/Exploding Gradients
๐ ๊ธฐ์ธ๊ธฐ ์์ค
- BPTT ์ฐ์ฐ์์ loss function ์ Wh ์ ๋ํ ๋ฏธ๋ถ์ฐ์ฐ์, hidden state ์ฌ์ด์ ํธ๋ฏธ๋ถ ์ฐ์ฐ ๊ณผ์ ์์ ๊ธฐ์ธ๊ธฐ ์์ค/ํญ์ฆ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ค.
- BPTT : ๊ฐ ๋ ์ด์ด๋ง๋ค ๋์ผํ ์์น์ weight ์ ํด๋นํ๋ ๋ชจ๋ error ๋ฏธ๋ถ๊ฐ์ ๋ํ๋ค์ ์ญ์ ํํ์ฌ ํ๋ฒ์ ์ ๋ฐ์ดํธํ๋ ๋ฐฉ๋ฒ
- ๋์ผํ ๊ฐ์ค์น Wh ๋ฅผ ๊ณต์ ํ๊ธฐ ๋๋ฌธ์ Wh ๊ฐ 1๋ณด๋ค ์์ ๊ฒฝ์ฐ ๋ฐ๋ณต์ ์ผ๋ก ๊ณฑํด์ง๋ ๊ฐ์ด 0์ ๊ฐ๊น์์ ธ ๊ธฐ์ธ๊ธฐ ์์ค๋ฌธ์ ๊ฐ ๋ฐ์ํ๊ณ , Wh ๊ฐ 1๋ณด๋ค ํฐ๊ฒฝ์ฐ๋ ๊ฐ์ด ๊ธฐํ๊ธ์์ ์ผ๋ก ์ปค์ ธ ๊ธฐ์ธ๊ธฐ ํญ์ฆ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ค ๐ ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ๊ฐ ์ด๋ ต๊ฑฐ๋ ๋ถ๊ฐ๋ฅํด์ง๋ค.
๐ ๊ธฐ์ธ๊ธฐ ์์ค์ด ๋ฌธ์ ๊ฐ ๋๋ ์ด์ ๋ ๐ Context ๋ฐ์์ด ์ด๋ ค์
โ ๊ฐ๊น์ด Gradient ์ ํจ๊ณผ๋ง ๋ฐ์ํ๋ค. ์ฆ, Near-effects ๋ง ๋ฐ์๋๊ณ Long-term effects ๋ ๋ฌด์๋๋ค.
โก Gradient ๋ ๋ฏธ๋ ์์ ์ ๋ํด ๊ณผ๊ฑฐ๊ฐ ์ผ๋ง๋ ์ํฅ์ ๋ฏธ์น๋์ง์ ๋ํ ์ฒ๋์ธ๋ฐ, gradient ๊ฐ ์์ค๋๋ฉด Step t ์ Step (t+n) ์ฌ์ด์ ์์กด์ฑ dependency ๊ฐ ์์ด์ 0์ด ๋ ๊ฒ์ธ์ง ํ๋ผ๋ฏธํฐ๋ฅผ ์๋ชป ๊ตฌํด 0์ด๋ ๊ฒ์ธ์ง ๊ตฌ๋ถํ ์ ์๊ฒ ๋๋ค.
๐ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๊ฐ Language Model์์ ๋ฐ์์ํค๋ ๋ฌธ์ : LM Task
๐ ๋ฉ๋ฆฌ์๋ ๋จ์ด๋ค ์ฌ์ด์ dependency ๋ฅผ ํ์ตํ์ง ๋ชปํ๋ ๋ฌธ์ ๊ฐ ๋ฐ์
- ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ก ์์ธกํ ๊ณณ๊ณผ ๊ฑฐ๋ฆฌ๊ฐ ๋จผ ์ ๋ณด๋ฅผ ์ ์ ๋ฌ๋ฐ์ง ๋ชปํด, ๋ณธ๋ ์ ๋ต์ ticket ์ด์ง๋ง, ๊ฐ์ฅ ์ต๊ทผ์ ๋์จ ๋ฌธ์ฅ์ผ๋ก ์๋ชป๋ ์ ๋ต์ธ printer ๋ก ์์ธกํจ
- ์๋ ์ ๋ต์ writer ๋ฅผ ๋ฐ๋ผ is ๊ฐ ๋์ด์ผ ํ์ง๋ง, ๋ฉ๋ฆฌ ์์นํ ๋จ์ด์ ๋ํ ์ ๋ณด๊ฐ ์ค์ด๋ค์ด ๊ฐ์ฅ ๊ฐ๊น์ด books ๋ฅผ ๋ณด๊ณ are ์ด๋ผ๊ณ ์๋ชป ์์ธกํ๋ค.
๐ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ค์ํ ๋ฐฉ๋ฒ๋ค
โ ๋ค๋ฅธ ํ์ฑํ ํจ์์ ์ฌ์ฉ
- ๋์ผํ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํจ์ผ๋ก์จ ๋ฐ์ํ๋ ๊ธฐ์ธ๊ธฐ ๋ฌธ์ ๋ฅผ ์ํํ๊ธฐ ์ํด RNN ์ hidden state ๋ฅผ ๋์ถํ ๋ ์ฌ์ฉํ๋ ํ์ฑํํจ์๋ก ์๊ทธ๋ชจ์ด๋๊ฐ ์๋ ํ์ดํผ๋ณผ๋ฆญ ํ์ ์ธ ํจ์๋ฅผ ์ฑํํ๋ค.
- ์๊ทธ๋ชจ์ด๋ ๋ฏธ๋ถ๊ฐ์ 0์์ 0.25 ์ฌ์ด์ ๊ฐ์ ๊ฐ์ง๋ฏ๋ก 0์ ๊ฐ๊น๊ธฐ ๋๋ฌธ์ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์์
- ๋ฐ๋ฉด, Tanh ๋ฏธ๋ถ๊ฐ์ 0๊ณผ 1์ฌ์ด์ ๊ฐ์ ๊ฐ์ง๋ฏ๋ก sigmoid ๋ณด๋จ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ์ ๋ ๊ฐํจ
- ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ๋ง๊ธฐ ์ํ ๊ฐ์ฅ ์ข์ ํ์ฑํ ํจ์๋ ReLU ์ด๋ค.
โก Gradient Clipping
- ๊ธฐ์ธ๊ธฐ ํญ์ฆ ๋ฌธ์ ์ ํด๊ฒฐ์ฑ (๋งค๋ฒ ํ์ต๋ฅ ์ ์กฐ์ ํ๋ ๊ฒ์ ๋งค์ฐ ์ด๋ ค์)
- gradient ๊ฐ ๋๋ฌด ์ปค์ง๋ฉด SGD update step ๋ ๋๋ฌด ์ปค์ง๋ค.
- ๋๋ฌด ํฐ learning rate ์ ์ฌ์ฉํ๋ฉด ํ ๋ฒ ์ ๋ฐ์ดํธ step ์ ํฌ๊ธฐ๊ฐ ๋๋ฌด ์ปค์ ธ ์๋ชป๋ ๋ฐฉํฅ์ผ๋ก ํ์ต ๋ฐ ๋ฐ์ฐํ๋ค. ๊ทธ๋ฌ๋ learning rate ๋ฅผ ์๊ฒํ๋ฉด ํ์ต ์๋๊ฐ ๋๋ฌด ๋๋ ค์ง๋ค.
- ๐จ ํด๊ฒฐ์ฑ
: gradient ๊ฐ ์ผ์ threshold ๋ฅผ ๋์ด๊ฐ๋ฉด gradient L2 norm ์ผ๋ก ๋๋์ด์ฃผ๋ ๋ฐฉ์
- ์ผ์ ๊ฐ์ด ๋์ด๊ฐ๋ฉด gradient ๊ฐ์ ์ปคํ ํ๋ ๊ฒ!
- ์์คํจ์๋ฅผ ์ต์ํํ๊ธฐ ์ํ ๊ธฐ์ธ๊ธฐ ๋ฐฉํฅ์ ์ ์งํ์ฑ๋ก ํฌ๊ธฐ๋ง ์กฐ์ ํ๋ค.
โข In CNN, FN : ๊ธฐ์ธ๊ธฐ ์์ค์ ๋ฐฉ์งํ๊ธฐ ์ํด ๋ค์ํ Connection ๋ฐฉ๋ฒ ๊ณ ์
- ResNet ๐ Residual connection ๊ฐ๋ ์ ๋์ ํด input x ์ convolution layer ๋ฅผ ์ง๋๊ณ ๋์จ ๊ฒฐ๊ณผ๋ฅผ ๋ํด ๊ณผ๊ฑฐ์ ๋ด์ฉ์ ๊ธฐ์ตํ ์ ์๋๋ก ํ๋ค : ๊ธฐ์กด์ ํ์ตํ ์ ๋ณด ๋ณด์กด + ์ถ๊ฐ์ ์ผ๋ก ํ์ตํ ์ ๋ณด
- DenseNet ๐ Dense connection ๊ฐ๋ ์ ๋์ ํด ์ด์ layer ๋ค์ feature map ์ ๊ณ์ํด์ ๋ค์ layer ๋ค์ ์ ๋ ฅ๊ณผ ์ฐ๊ฒฐํ๋ ๋ฐฉ์ : concatenation
- HighwayNet ๐ Highway connection ๋ด๋ ์ ๋์ ํด output ์ด input์ ๋ํด ์ผ๋ง๋ ๋ณํ๋๊ณ ์ฎ๊ฒจ์ก๋์ง ํํ
๐ค ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ ๋พฐ์กฑํ ํด๊ฒฐ์ฑ ์ด โฝโฝโฝ
โ ๋๋ฒ์งธ ๋ฌธ์ : Long Term Dependency ์ฅ๊ธฐ์์กด์ฑ
- ์ํ์ค์ ๊ธธ์ด๊ฐ ๊ธด ๊ฒฝ์ฐ, ์ํ์ค ์ด๋ฐ์ ์ ๋ณด๊ฐ hidden state ๋ฅผ ๋์ถํ๋๋ฐ ๊น์ง ์ ๋ฌ๋๊ธฐ๊ฐ ์ด๋ ค์ long term dependency ๋ฅผ ์ ๋ฐ์ํ์ง ๋ชปํ๊ฒ ๋๋ค.
2. LSTM โญโญ
๐คจ gradient ๊ฐ 0์ ๊ฐ๊น์์ง๋๊ฒ ๋ฌธ์ ๋ฉด
์ ๋ณด๋ฅผ ์ ์ฅํ๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ฐ๋ก ๋์ด์ gradient ๊ฐ์ด ์์ค๋์ง ์๋๋ก ํ๋ฉด ๋์ง ์์๊น?
- RNN์ ๋ฌธ์ ๋ ๋ฉ๋ฆฌ ์๋ ์ ๋ณด๋ฅผ ์ ๋๋ก ํ์ตํ์ฌ ๊ฒฐ๊ณผ์ ๋ฐ์ํด๋ผ ์ ์๋ค๋ ์ ๐ ๋ง์ฝ ๋ฉ๋ฆฌ ์๋ ์ ๋ณด๋ฅผ ์ ์ฅํ ์ ์๋ ๋ณ๋์ ๋ฉ๋ชจ๋ฆฌ๊ฐ ์๋ค๋ฉด ๐ฒ
โ Long Short - Term Memory RNN
- RNN ์์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ถ๋ฆฌํด ๋ฐ๋ก ์ ๋ณด๋ฅผ ์ ์ฅํจ์ผ๋ก์จ ์๋จ์ ๋ฐ์ดํฐ๋ ํจ๊ป ๊ณ ๋ คํ์ฌ ์ถ๋ ฅ์ ๋ง๋ค๊ณ ์ ํ ๋ชจ๋ธ๋ก ์ด์ ๋จ๊ณ์ ์ ๋ณด๋ฅผ memory cell ์ ์ ์ฅํ์ฌ ํ๋ ค๋ณด๋ธ๋ค.
- simple RNN ์ ๊ฒฝ์ฐ ์ค์ง hidden state ํ๋๋ง์ ๊ฐ์ง๊ณ ๋ชจ๋ time step ์ ์ ๋ณด๋ฅผ ์ ์ฅํ๊ณ ์ ํ๋ค๋ฉด, LSTM ์ ๊ฒฝ์ฐ์ ํฌ๊ฒ 2๊ฐ์ง ๊ตฌ์กฐ๋ฅผ ํตํ์ฌ RNN ๊ณผ์ ์ฐจ๋ณ์ ์ ๋๊ณ ์๋ค.
- ์ด์ ์ ๋ณด๋ฅผ ๊ณ์ํด์ ์ ๋ฌํ๋ Cell state ์ ๋ถํ์ํ ์ ๋ณด๋ฅผ ๊ฑธ๋ฌ์ฃผ๋ Gate ๊ฐ ์กด์ฌ → ํ์ฌ์์ ์ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ๊ณผ๊ฑฐ์ ๋ด์ฉ์ ์ผ๋ง๋ ์์์ง ๊ณฑํด์ฃผ๊ณ ๊ทธ ๊ฒฐ๊ณผ์ ํ์ฌ ์ ๋ณด๋ฅผ ๋ํด์ ๋ค์ ์์ ์ผ๋ก ์ ๋ณด๋ฅผ ์ ๋ฌํ๋ค.
โ hidden state ๋ฅผ ํตํด ๋จ๊ธฐ๊ธฐ์ต์ ์กฐ์ ํ๊ณ cell state ๋ฅผ ํตํด ์ฅ๊ธฐ๊ธฐ์ต์ ์กฐ์ ํ๋ค.
โก forget, input, output 3๊ฐ์ gate ๋ฅผ ํตํด ๋งค time step ์ cell state ์ hidden state, input ์์ ์ทจํ ์ ๋ณด์ ์์ ๊ฒฐ์ ํ๋ค.
โ ๊ฐ๋ ์ ๋ฆฌ
Cell state
- time step ์์ hidden state h ์ cell state c ๋ฅผ ๊ฐ์ง๋ค ๐จ ๊ธธ์ด๊ฐ n ์ธ ๋ฒกํฐ
- Cell ์ ์ฅ๊ธฐ๊ธฐ์ต์ ์ ์ฅํ๋ค.
- LSTM ์ cell ์ ์ ๋ณด๋ฅผ ์ง์ฐ๊ฑฐ๋, ์ฐ๊ฑฐ๋, ์ฝ์ด์ฌ ์ ์๋ค ๐จ 3๊ฐ์ gate ๋ก ๊ด๋ฆฌ
- ๊ฐ timestep ์์ gate ์ ๊ฐ ์์๋ค์ ์๊ทธ๋ชจ์ด๋ ํจ์๋ฅผ ํต๊ณผํจ์ผ๋ก์จ open(1), closed(0) ํน์ ๊ทธ ์ค๊ฐ ๊ฐ์ ๊ฐ์ง ์ ์๋ค.
Gate ๐จ ๊ธธ์ด๊ฐ n ์ธ ๋ฒกํฐ
- input gate : ๋ฌด์์ ์ธ ๊ฒ์ธ๊ฐ
- output gate : ๋ฌด์์ ์ฝ์ ๊ฒ์ธ๊ฐ
- forget gate : ๋ฌด์์ ์์ ๊ฒ์ธ๊ฐ
โ Architecture
- LSTM ์ ์ด 3๊ฐ์ง gate ๋ฅผ ๋์ด์ ์ ๋ ฅ์ ๋ณด์ ํ๋ฆ(flow) ์ ๊ฒฐ์ ํ๊ณ ์ด๋ฅผ ํตํด ์ฅ๊ธฐ๊ธฐ์ต์ ๋ณด์กดํ๋ cell state ์ ๋จ๊ธฐ๊ธฐ์ต์ ๋ณด์กดํ๋ hidden state ์ด 2๊ฐ์ง state ๋ค์ ๋ฐํํ๊ฒ ๋๋ค.
โ forget gate layer
โฝ ์ด๋ค ์ ๋ณด๋ฅผ ๋ฐ์ํ ์ง ๊ฒฐ์ ํ๋ gate
- ์ด์ ๋จ๊ณ์ cell state ์์ ์ด๋ค ์ ๋ณด๋ฅผ ์๊ณ , ์ด๋ค ์ ๋ณด๋ฅผ Cell state ์ ํ๋ ค๋ฃ์์ง ๊ฒฐ์ ํ๋ค.
- ์ ๋ ฅ ์ ๋ณด (์๋ก์ด ์ ๋ ฅ ์ํ์ค ๋จ์ด์ ์ด์ ์์ ์ hidden state) ์ ์๊ทธ๋ชจ์ด๋๋ฅผ ์ทจํ์ฌ 0~1์ฌ์ด์ ๊ฐ์ ๋ฐํํ๋ค.
- 0 : ์ ์ฅ๋ ์ด์ ์ ๋ณด๋ฅผ ์ง์ด๋ค.
- 1 : ์ ์ฅ๋ ์ด์ ์ ๋ณด๋ฅผ ๋ณด์กดํ๋ค.
๐จ x ๊ฐ๊ณผ ์ด์ ์์ ์ hidden state ๋ฅผ ์ ๋ ฅ๋ฐ์ sigmoid ํจ์๋ฅผ ํตํด 0๊ณผ 1์ฌ์ด์ ๊ฐ์ ์ถ๋ ฅํ๋ค. 0์ ๊ฐ๊น์ธ์๋ก ์์ ํ ์๊ณ , 1์ ๊ฐ๊น์ธ์๋ก ์ ๋ณด๋ฅผ ๋ง์ด ํ๋ ค๋ฃ๊ฒ ๋๋ค.
โก input gate, new cell content
๐ input gate : ์ ๋ ฅ์ ๋ณด์ ์๊ทธ๋ชจ์ด๋๋ฅผ ์ทจํ๋ค.
- ์๋ก์ด ์ ๋ณด๊ฐ Cell state ์ ์ ์ฅ๋ ์ง ๊ฒฐ์ ํ๋ Gate
- input ์ ์ค์ํ ์ ๋ณด๊ฐ ์๋ค๋ฉด ์ด์ ๋จ๊ณ์ cell state ์ ๋ฃ๋ ์ญํ ์ ์ํํ๋ค.
- i(t) ๊ฐ์ด 1์ ๊ฐ๊น๊ฒ ๋์ค๋ฉด ์๋ก์ด ์ ๋ณด์ ๋ํด ๋ฐ์์ ๋ง์ด ํ๊ณ , 0์ ๊ฐ๊น์ธ์๋ก ์๋๋ค.
๐ new cell content : ์ ๋ ฅ์ ๋ณด์ tanh ๋ฅผ ์ทจํ๋ค.
๐จ sigmoid layer (input gate) : ์ด๋ค ๊ฐ์ ์ ๋ฐ์ดํธํ ์ง ๊ฒฐ์ ํ๋ค.
๐จ Tanh layer (update gate) : cell state ์ ๋ํด์ง ๋ฒกํฐ๋ฅผ ๋ง๋ ๋ค.
โข Update Cell state
- t ๋ฒ์งธ timestep ๊น์ง์ ์ฅ๊ธฐ๊ธฐ์ต์ ์ ๋ฌํ๋ state
๐ forget gate : ์ด์ ์์ ์ cell state ์์ ์ด๋์ ๋์ ์ ๋ณด๋ฅผ ๊ฐ์ ธ๊ฐ๊ฑด์ง ๊ฒฐ์ = ๊ณผ๊ฑฐ์ ์ ๋ณด๋ฅผ ๋ฐ์ํ ์ง์ ์ ๋ฌด
๐ input gate : ์ ๋ ฅ ์ ๋ณด์์ ์ฅ๊ธฐ ๊ธฐ์ต์ผ๋ก ๊ฐ์ ธ๊ฐ ์ ๋ณด์ ์์ ๊ฒฐ์ = ์๋กญ๊ฒ ๋ฐ์ํ ํ์ฌ์ ์ ๋ณด๋ฅผ ๋ฐ์ํ ์ง์ ์ ๋ฌด
- f(t) * C(t-1) : ์ด์ ์์ ์ cell state ์ ๋ง๋์ ์ ์์ ์ผ๋ก๋ถํฐ ์ ๋ฌ๋์ด์จ ์ฅ๊ธฐ๊ธฐ์ต์ผ๋ก๋ถํฐ ์ด๋์ ๋์ ์ ๋ณด๋์ ๊ฐ์ ธ๊ฐ ๊ฒ์ธ์ง ๊ฒฐ์ ํ๋ค.
- i(t) * ~C(t) : ์๋กญ๊ฒ ์ ๋ ฅ๋ ๊ฐ์์ ์ฅ๊ธฐ๊ธฐ์ต์ผ๋ก ๊ฐ์ ธ๊ฐ ์ ๋ณด์ ์์ ๊ฒฐ์ ํ๋ค.
⇒ ์ด ๋ ๊ฐ์ด ๋ํด์ ธ ๋ค์ cell state ์ ์ ๋ ฅ์ผ๋ก ๋ค์ด๊ฐ๋ค (update)
๐ง forget gate ๊ฐ 1์ด๊ณ input gate ๊ฐ 0์ผ ๋ Cell ์ ๋ณด๊ฐ ์์ ํ ๋ณด์กด๋์ด ์ฅ๊ธฐ์์กด์ฑ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์์ง๋ง ์ฌ์ ํ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ์ ์์ ํ ํด๊ฒฐ์ ๋ณด์ฅํ์ง ์๋๋ค.
โฃ Output gate, hidden state
๐ output gate : ์ ๋ ฅ ์ ๋ณด์ ์๊ทธ๋ชจ์ด๋๋ฅผ ์ทจํ๋ค.
๐ hidden state: ํ์ฌ ์ ๋ ฅ๊ณผ ๋๋นํ์ฌ ์ฅ๊ธฐ๊ธฐ์ต Ct ์์ ์ด๋์ ๋์ ์ ๋ณด๋ฅผ ๋จ๊ธฐ๊ธฐ์ต์ผ๋ก ์ฌ์ฉํ ์ง ๊ฒฐ์ → ๋ค์ state ์ input ์ผ๋ก ๋ค์ด๊ฐ๋ค.
โ LSTM ์ด ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋๊ฑด๊ฐ
- ์ฌ๋ฌ timestep ์ด์ ์ ๊ฐ๋ ์ ์ฅํ๊ธฐ ์ฝ๊ฒ ํด์ฃผ๋ ๊ตฌ์กฐ ๐ forget gate ๊ฐ ๋ชจ๋ timestep ๊ฐ์ ์ ์ฅํ๋๋ก ์ค์ ๋๋ค๋ฉด cell ๋ด๋ถ์ ์ ๋ณด๋ ๊ณ์ ์ ์ฅ๋๋ค.
- LSTM ๋ ๊ธฐ์ธ๊ธฐ ์์ค/ํญ์ฆ ๋ฌธ์ ๊ฐ ์๋ค๊ณ ๋ณด์ฅํ ์๋ ์์ผ๋, ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํผํ๊ธฐ ์ฌ์ด ๋ชจ๋ธ์ธ ๊ฒ์ ๋ง๋ค.
3๏ธโฃ GRU
1. Gated Recurrent Units
โ idea
- LSTM ์ ๊ตฌ์กฐ๋ฅผ ์กฐ๊ธ ๋ ๋จ์ํ์ํค์
- LSTM ์ ๋ณํ๋ ๋ชจ๋ธ๋ก Reset Gate (rt) ์ Update Gate (zt) 2๊ฐ์ gate ๋ก ์ด๋ฃจ์ด์ ธ ์๋ค.
- GRU ์ ๊ณ์ฐ ์๋๊ฐ LSTM ๋ณด๋ค ๋น ๋ฅด๊ณ ํ๋ผ๋ฏธํฐ ์๊ฐ ๋ ์ ๋ค. LSTM ์ ์ฅ๊ธฐ ์์กด์ฑ ๋ฌธ์ ์ ๋ํ ํด๊ฒฐ์ฑ ์ ์ ์งํ๋ฉด์ ์๋์ํ๋ฅผ ์ ๋ฐ์ดํธ ํ๋ ๊ณ์ฐ๋์ ์ค์ธ ๊ฒ์ด๋ค.
โ Architecture
๐ Reset gate : ์ด์ ์ ๋ณด๋ฅผ ์ด๋ ์ ๋ ๋ฐ์ํ ๊ฒ์ธ์ง ๊ฒฐ์ rt
๐ Update gate : ๊ณผ๊ฑฐ์ ํ์ฌ์ ์ ๋ณด ๋ฐ์ ๋น์ค์ ๊ฒฐ์ Zt
- Zt : LSTM ์ forget gate ์ input gate ๋ฅผ ํฉ์น ๋ถ๋ถ
๐จ LSTM ์ forget ์ ์ญํ ์ด rt ์ zt ๋ ๊ฐ๋ก ๋๋์ด์ง ๊ตฌ์กฐ
๐จ ์ถ๋ ฅ๊ฐ (ht) ๋ฅผ ๊ณ์ฐํ ๋ ์ถ๊ฐ์ ์ธ ๋น์ ํ ํจ์๋ฅผ ์ ์ฉํ์ง ์๋๋ค.
2. LSTM vs GRU
- ์ฅ๊ธฐ๊ธฐ์ต์ ์ข๋ค.
- GRU ๋ ํ๋ผ๋ฏธํฐ๊ฐ LSTM ๋ณด๋ค ์ ์ด ํ์ต ์๋๊ฐ ๋น ๋ฅด๊ณ , LSTM ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ ํํ๊ธฐ ์ข๋ค. ๊ธฐ๋ณธ ๋ชจ๋ธ๋ก LSTM ์ ์ฌ์ฉํด๋ณด๊ณ ํจ์จ์ฑ์ ์ํ๋ฉด GRU ๋ฅผ ์๋ํด๋ณด๋ฉด ์ข์๋ฏ!
4๏ธโฃ More Fancy RNNs
1. Bidirectional RNNs
โ ์ ์
- ๊ณผ๊ฑฐ ์ํ ๋ฟ ์๋๋ผ ๋ฏธ๋์ ์ํ๊น์ง ๊ณ ๋ คํ๋ ๋ชจ๋ธ
- ์๋ฐฉํฅ ์ ๋ณด๋ฅผ ๋ชจ๋ ์ด์ฉํ๊ธฐ ์ํ RNN ๊ตฌ์กฐ
โ Architecture
๐จ ์๋ฐฉํฅ RNN ์ ํ๋์ ์ถ๋ ฅ์ ์์ธกํ๋ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ์ ๋ ๊ฐ๋ฅผ ์ฌ์ฉํ๋ค.
- Forward RNN : ์ ๋ฐฉํฅ์ผ๋ก ์ ๋ ฅ๋ฐ์ hidden state ๋ฅผ ์์ฑ
- Backward RNN : ์ญ๋ฐฉํฅ์ผ๋ก ์ ๋ ฅ๋ฐ์ hidden state ๋ฅผ ์์ฑ
- ๋ ๊ฐ์ hidden state ๊ฐ ์๋ก ์ฐ๊ฒฐ๋์ด ์์ง๋ ์๊ณ , ์ ๋ ฅ๊ฐ์ ๊ฐ hidden layer ์ ์ ๋ฌํ๊ณ output layer ๋ ์ด ๋ hidden layer ๋ก๋ถํฐ ๊ฐ์ ๋ฐ์์ ์ต์ข Output ์ ๊ณ์ฐํ๋ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์๋ค.
โ Note
- ์๋ฐฉํฅ RNN ์ ์ ์ฒด entire input sequence ๊ฐ ์ฃผ์ด์ก์ ๋๋ง ์ ์ฉํ ์ ์๋ค. Language modeling ์ ๊ฒฝ์ฐ์๋ left context ๋ง ์ ๊ทผ ๊ฐ๋ฅํ๋ฏ๋ก ์ ์ฉํ ์ ์๋ค.
- entire input seqence ๊ฐ ์๋ค๋ฉด bidirectionality ๋ ๋ ๊ฐ๋ ฅํ๊ฒ ์์ฉํ๋ค.
- ๐ BERT ๐ Bidirectional Encoder Representations from Transformers : powerful pretrained contextual representation system built on bidirectionality
2. Multi-layer RNNs
โ ์ ์
- RNN ์ ์ฌ๋ฌ ์ธต์ผ๋ก ์ฌ์ฉํ ๋ชจ๋ธ
- ๋ ๋ณต์กํ ํํ๊ณผ ํน์ฑ์ ํ์ตํ ์ ์๋๋ก ํ๋ค. ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋๋ฌธ์ ์ผ๋ฐ์ ์ผ๋ก 2~4๊ฐ์ ์ธต์ ์์ ์ฌ์ฉํ๋ค. (๋ ๋ง์ ์ธต์ ์๋ ๊ฒ์ ๊ถ์ฅํ์ง ์์)
- BERT ์ ๊ฐ์ด Transfoermer-based network ์ ๊ฒฝ์ฐ์ 24๊ฐ์ layer ๊น์ง ์์ ์ ์๋ค.
๐ ์ค์ต ์๋ฃ
โ LSTM
-
โก GRU
-
โข Fancy RNNs - bidirectional, multi layer
- https://wikidocs.net/22886 ๐ model.add(Bidirectional(SimpleRNN(hidden_units, return_sequences=True), input_shape=(timesteps, input_dim)))
'1๏ธโฃ AIโขDS > ๐ NLP' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[cs224n] 9๊ฐ ๋ด์ฉ ์ ๋ฆฌ (0) | 2022.05.09 |
---|---|
[cs224n] 8๊ฐ ๋ด์ฉ ์ ๋ฆฌ (0) | 2022.05.09 |
[cs224n] 6๊ฐ ๋ด์ฉ ์ ๋ฆฌ (0) | 2022.03.24 |
[cs224n] 5๊ฐ ๋ด์ฉ ์ ๋ฆฌ (0) | 2022.03.22 |
[cs224n] 4๊ฐ ๋ด์ฉ ์ ๋ฆฌ (0) | 2022.03.18 |
๋๊ธ