DeepLearning

RNN, LSTM 간단한 설명

jiheek 2022. 4. 11. 16:30

RNN(Recurrent Neural Network)

 

  • Hidden state: Representation of previous input. Loop! 이전 input의 상태를 기억하고 있는다.

Sequential memory인 RNN으로 sequential pattern 더 잘 이해 가능하다.

챗봇의 경우, RNN을 사용해서 input sequence of text를 인코드한다.

임베딩된 RNN 출력을 fc layer에 입력해서 classify 또는 다른 task를 수행한다.

 

what time is it - RNN 코드 예시

 

But, nature of back propagation algorithm인 short term memory & vanishing gradient 때문에 앞의 단어들(what, time)은 점점 비율이 사라진다. 첫부분의 weight들은 매우 작은 gradient로 인해 거의 조정되지 않게 된다.

 

RNN에서는 each time step을 neural network에서의 layer로 생각할 수 있다. 따라서 back propagation through time으로 이해 가능하다.

 

이 문제 해결하기 위해서 LSTM이 고안되었다.

RNN과 비슷하지만, 내부의 gate들로 long term dependency도 확보할 수 있게 되었다. Gate들은 hidden state에서 어떤 정보를 더하고 뺄 지를 배운다. LSTM으로 더 긴 sequence를 다룰 수 있게 되었다.

 

 

 


출처: https://www.youtube.com/watch?v=LHXXI4-IEns