Implementation/Text

[LSTM] return_sequence = True or False

Eric_Park 2021. 10. 5. 13:42
### dec_hidden 값을 decoder 즉, lstm layer 에 통과시켰을 때, 
# return_sequence = True or False 에 따른 변경. 

dec_hidden.shape
>>> (2,10,4) 

# return_sequence = False 일 때, 
dec_lstm = tf.keras.layers.LSTM(units= 5, return_sequence=False) 
dec_hidden = dec_lstm(dec_hidden, initial_state=[enc_h_state, enc_c_state], mask=dec_mask) 
dec_hidden.shape
>>> (2,5)


# return sequence = True 일 때, 
dec_lstm = tf.keras.layers.LSTM(units= 5, return_sequence=True) 
dec_hidden = dec_lstm(dec_hidden, initial_state=[enc_h_state, enc_c_state], mask=dec_mask) 
dec_hidden.shape
>>> (2,10,5) 

# time step 에 따른 hidden state 가 함께 출력되었음을 확인할 수 있다.