Implementation/Text

[RNN] 파라미터 개수 카운팅

Eric_Park 2021. 9. 21. 13:55

Vanilla RNN 기본구조

 

 

RNN 내부 파라미터 구조

 

# Simple RNN 파라미터 참고 
# => output_dim, ( time-length or setence-length, input_dim) 
# => Dh, (t, d) 

from keras.models import Sequential
from keras.layers import SimpleRNN

model = Sequential()
model.add(SimpleRNN(3, input_shape=(2,10)))
# model.add(SimpleRNN(3, input_length=2, input_dim=10))와 동일함.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn_1 (SimpleRNN)     (None, 3)                 42        
=================================================================
Total params: 42
Trainable params: 42
Non-trainable params: 0
_________________________________________________________________

- 계산

위의 예제의 경우 

Dh = 3

t = 2 ( RNN 의 특성상 모든 시점에 히든 스테이트를 공유하므로, time 은 변수의 개수에 관계없다) 

d = 10 

이므로, 아래 계산과정으로 파라미터의 수를 카운팅할 수 있다. 

 

# of params = (Dh * Dh) + (Dh * d) + (Dh)

                 = (3 * 3) + (3 * 10) + (3) 

                 = 42 

 

 

 

 

 

#ref : https://wikidocs.net/22886