# Implementation/Python
#
# [Python] tensorflow.expand_dims
#
# Eric_Park 2021. 9. 28. 13:41
import numpy as np
import random
import tensorflow as tf

# The outputs shown below were captured from an interactive session; the
# exact repr can differ slightly between TensorFlow/NumPy versions.

# Random 2x6 integer matrix. randint's upper bound is exclusive, so values
# lie in [-90, 99]; they change on every run, so the printouts below are
# just one sample. dtype is pinned to int32 because randint's default integer
# dtype is platform/NumPy-version dependent (commonly int64), which would
# otherwise make the "dtype=int32" shown in the outputs below wrong.
sp = np.random.randint(-90, 100, (2, 6), dtype=np.int32)
# >>> sp
# array([[  8, -67,  34, -71, -50, -69],
#        [ 56,  30,  58,  76,   1, -83]])

# tf.expand_dims(input, axis) returns a tensor with an extra size-1
# dimension inserted at position `axis`. The results are bound to names
# (a bare expression only echoes its value in an interactive shell) and the
# resulting shapes are asserted so the example checks itself when run.

# axis=0: new leading dimension -> (1, 2, 6)
expanded_0 = tf.expand_dims(sp, axis=0)
assert tuple(expanded_0.shape) == (1, 2, 6)
# <tf.Tensor: shape=(1, 2, 6), dtype=int32, numpy=
# array([[[  8, -67,  34, -71, -50, -69],
#         [ 56,  30,  58,  76,   1, -83]]])>

# axis=1: new middle dimension -> (2, 1, 6)
expanded_1 = tf.expand_dims(sp, axis=1)
assert tuple(expanded_1.shape) == (2, 1, 6)
# <tf.Tensor: shape=(2, 1, 6), dtype=int32, numpy=
# array([[[  8, -67,  34, -71, -50, -69]],
#        [[ 56,  30,  58,  76,   1, -83]]])>

# axis=2: new trailing dimension -> (2, 6, 1)
expanded_2 = tf.expand_dims(sp, axis=2)
assert tuple(expanded_2.shape) == (2, 6, 1)
# <tf.Tensor: shape=(2, 6, 1), dtype=int32, numpy=
# array([[[  8],
#         [-67],
#         [ 34],
#         [-71],
#         [-50],
#         [-69]],
#
#        [[ 56],
#         [ 30],
#         [ 58],
#         [ 76],
#         [  1],
#         [-83]]])>

# Negative axes count from the end, so for this rank-2 input axis=-1 is
# equivalent to axis=2 (valid axes for a rank-D input are -(D+1) .. D,
# i.e. -3 .. 2 here).
assert tuple(tf.expand_dims(sp, axis=-1).shape) == (2, 6, 1)

# Summary for a (row, col) input, by axis:
#   axis=0 -> (expanded, row, col) = (1, 2, 6)
#   axis=1 -> (row, expanded, col) = (2, 1, 6)
#   axis=2 -> (row, col, expanded) = (2, 6, 1)
#
# If the rank-3 result is then treated as (batch, row, col) -- e.g. fed to an
# op that expects a batch of matrices -- axis=0 gives one 2x6 matrix, axis=1
# gives two 1x6 row vectors, and axis=2 gives two 6x1 column vectors.