Home [DSLAB] LSTM, GRU
Post
Cancel

[DSLAB] LSTM, GRU

앞서 설명한 Vanishing/Exploding Gradient Problem을 일으키는 Long-Term-DependencyRNN의 고질적인 문제이다. 이를 해결하기 위해 나온 모델이 바로 LSTM이다. LSTM장단기 메모리(Long Short-Term Memory)의 약자로 단기 기억으로 저장하여 상황에 맞는 정보를 꺼내 사용함으로써 장기 기억을 할 수 있도록 만든 모델이다.

LSTM


이미지

기본적으로 LSTMRNN과 다르게 Cell state인 Ct가 존재한다. 이 Cell state는 중요한 핵심 정보들을 담아두고 hidden state와 비슷하게 이전 시점의 Cell state가 다음 시점의 Cell state를 구하기 위한 입력으로 사용된다. 이 Cell statehidden state를 구하기 위해서는 위 그림에 나온것 과 같이 여러개의 게이트를 사용한다. 이러한 게이트들은 전 time step에서 넘어온 Cell state벡터 Ct-1를 적절하게 변환하는데 사용이 된다.

  • Forget Gate

이미지

Forget gate는 직전 hidden state인 ht-1과 xt를 선형결합을 한 뒤 sigmoid함수를 통과해서 나온 벡터를 나타낸다. sigmoid함수는 0~1까지의 값을 가지게 되므로 f벡터(Forget gate)와 Ct-1과 곱하게 되면 기존에 가지고 있던 정보들을 몇퍼센트의 비율로 버릴지 결정하게 되는것이다.

  • Input Gate, Gate Gate

이미지
이미지

이 단계는 Input gate와 Gate gate가 함께 쓰이는 단계로 새로 들어온 정보가 Cell state에 저장될지를 결정하는 단계이다. Input gate는 forget gate와 비슷하게 ht-1과 xt를 선형결합하여 sigmoid함수를 거쳐 나온 값으로 어떤값을 업데이트 할지 결정을 하게 되고 Gate gate는 tanh함수를 거쳐 Cell state에 더해질 수 있는 새로운 후보값을 만들어내는 역할을 한다. 이렇게 만들어진 gate를 서로 곱하여 Ct-1에 더해지게 된다.

  • Cell state

이미지

앞서 만들어진 forget gate와 input gate, gate gate를 이용해 새로운 Cell state로 업데이트가 이루어진다. forget gate에서 만들어진 벡터와 Ct-1이 곱해지고 input gate와 gate gate를 곱한 벡터를 Ct-1에 더해준다. 이는 ft를 통해 버리기로한 데이터를 버리고 it와 gt를 통해 새로운 후보값을 기존의 Ct-1의 값에 영향을 주게 된다.

  • Output gate

이미지
이미지

마지막으로 Output gate는ht-1과 xt를 곱해 sigmoid함수를 거친 벡터를 통해 Cell state에 있는 정보를 얼마나 hidden state에 사용해야할지를 결정한다. 최종적으로 출력값을 결정할 때 tanh 함수를 거친 Ct를 output gate와 곱해 원하는 결과값만 ht와 yt에 반영할 수 있도록 만든다.


이처럼 LSTM은 RNN과 다르게 각 time step마다 필요한 정보를 단기적으로 hidden state에 저장해 관리되도록 학습한다. 또한 역전파를 진행할 때 가중치 W값을 RNN처럼 계속 곱하는게 아니라 만들어진 새로운 정보를 Cell state에 더하게 되어 Vanishing/Exploding Gradient문제를 방지하게 된다.

GRU


이미지

GRU(Gated Recurrent Unit)은 LSTM과 비슷하지만 LSTM의 모델구조를 경량화 해서 보다 적은 메모리 사용와 빠른 계산을 가능하도록 만든 모델이다. LSTM에서 존재하는 Cell state벡터와 hidden state벡터를 일원화시키고 Forget gate와 Input gate를 하나의 업데이트 게이트로 통일시켜 기능은 비슷하지만 hidden state로만 계산을 할 수 있도록 한 것이 GRU의 큰 특징이다.

이미지

Reset gate인 rt는 현재 상태에서 얼마나 이전의 정보를 유지할지 결정하는 게이트이고 Input gate인 zt는 이전 상태 정보와 새로운 정보를 가져오는 것 사이의 균형을 결정하는 역할을 하는 게이트이다. 또한 새로운 hidden state ht를 결정할 때 LSTM과 다르게 input gate에 해당하는 zt를 ht-1과 ht 두 정보간의 가중 평균을 내는 형태로 계산된다.

This post is licensed under CC BY 4.0 by the author.