자연어처리-2 RNN언어모델에 대해서 알아보자!
앞에서 NNLM에 대해서 간단하게 살펴보았다.
N-gram 언어모델과 NNLM은 고정된 개수의 단어만 입력으로 받을 수 있다는 단점이 있었다.
지금 살펴볼 RNNLM은 timestep이라는 개념이 도입된 RNN으로 언어 모델을 만든 것이다.
예측과정
RNNLM은 기본적으로 예측 과정에서 이전 시점의 출력을 다음 시점의 입력으로 한다.
이게 무슨말이냐면.. what을 입력받으면 will을 예측하고 예측한 will은 다음시점의 입력이 되어 the를 예측한다. the는 또 그 다음시점의 입력이 되어 fat을 예측하게 된다.
결과적으로 $y_3$의 fat은 what, will, the의 시퀀스로 인해 결정된 단어이고, $y_4$의 cat은 what, will, the, fat의 시퀀스로 인해 결정된 단어이다.
RNN에서는 메모리 셀이라는 곳에 이전 정보를 저장해서 참고하기 때문이다.
쉽게 말하자면 will을 입력받은 메모리 셀에는 what과 will의 정보가 들어있고 그 정보를 바탕으로 the를 예측하는 것이다.
훈련과정
RNNLM의 훈련과정에서는 예측 과정에서 하나하나 넣으면서 진행하지 않는다.
what, will, the, fat 시퀀스를 모델의 입력으로 넣고, will, the, fat, cat을 예측하도록 훈련한다. 여기서 will, the, fat, cat은 각 시점의 레이블이 된다.
이러한 RNN의 훈련 기법을 교사 강요 라고 한다.
교사 강요란 테스트 과정에서 $t$시점의 출력이 $t+1$ 시점의 입력으로 사용되는 RNN 모델을 훈련시킬 때 사용하는 기법이다.
위 그림을 보면 훈련 과정 동한 출력층에서 사용하는 활성화 함수는 softmax함수이다.
모델이 예측한 값과 실제 레이블과의 오차를 계산하기 위해서 cross entropy 손실함수를 사용한다.
각 시점에서 시퀀스가 입력되면 Embedding Layer를 통해 임베딩 벡터로 변환되고, 은닉층에서 이전 시점의 은닉 상태와 tanh 연산을 수행하여 마지막 은닉층으로 전달하여 예측값을 계산하는 방식이다.
Embedding Layer는 저번 NNLM에서 소개했던 linear hidden Layer 부분이다. 그 부분에서 임베딩 벡터를 구하기 때문에 Embedding Layer이고, 이는 keras.layers에 구현되어 있다. 또한 RNN은 활성화 함수로 tanh를 사용한다.
전체 구조
위 그림은 훈련의 전체 구조를 나타낸 것이다.
하나씩 살펴보도록 하자. 우선 input layer에 원-핫 벡터가 주어진다. 그 다음은 Embedding layer를 거져서 임베딩 벡터가 구해진다. 이 임베딩 벡터는 은닉층에서 이전 시점의 은닉상태인 $h_{t-1}$과 함께 연산되어 현재 시점의 은닉 상태 $h_t$를 계산하게 된다. 그림에서 초록색 박스 부분이 RNN layer이다. timestep에 맞게 반복해서 최종 시점인 $h_t$를 계산하는 것이다. 계산된 값을 softmax함수를 통해 나온 벡터와 레이블 벡터의 오차값을 구하는 것이다.
학습 파라미터
그렇다면 모델은 무엇을 학습해야 할까?
우선 임베딩 벡터의 식을 정리해보자.
$e_t = embedding(x_t)$ 이렇게 정의된다. 식이 간단한데, 임베딩 벡터도 가중치가 있다는 것을 알아두자.
이제 은닉층의 식을 정리해보자.
$h_t = tanh(W_xe_t + W_hh_{t-1} + b)$
위의 식을 보면 입력층으로 $e_t$가 들어오고 입력층에 대한 가중치 $W_x$가 곱해진다. 그 다음으로 은닉층에 대한 가중치 $W_h$와 이전 은닉상태인 $h_{t-1}$의 값이 곱해진다. 이 두개를 더하고 편향까지 더해준 다음 tanh에 넣어주면 $h_t$가 계산된다.
다음은 출력층의 식을 정리해보자.
$\hat{y_t} = softmax(W_yh_t + b)$
$\hat{y_t}$은 출력층에 대한 가중치 $W_y$와 은닉 상태의 $h_t$를 곱하고 편향을 더해서 나온 결과에 softmax함수를 넣은 결과이다.
결과적으로 임베딩 벡터을 $E$라고 했을때, 학습되는 가중치는 $E$, $W_x$, $W_h$, $W_y $ 4개가 된다.
참고자료
실습을 원한다면, RNN을 이용한 단어단위 텍스트 생성, RNN을 이용한 문자단위 텍스트 생성을 참고바란다.
댓글남기기