import tensorflow as tf
# Demonstrates tf.expand_dims on a rank-2 tensor. Each result has one more
# dimension than `image`, and the new size-1 axis sits at index `axis`.
#
# NOTE: this file was originally a pasted interactive session. The bare `i`
# lines and the `<tf.Tensor ...>` echoes after them are not valid Python, so
# the file could not be run. Here the recorded output is kept as comments.
# Each tensor is printed explicitly with repr(), which is what the REPL echo
# showed. Each resulting shape is also checked.

# 2x3 tensor. The recorded output shows TensorFlow inferred int32 from the
# Python ints.
image = tf.constant([[2, 3, 4], [5, 6, 7]])

# axis=0: insert a new leading axis, (2, 3) -> (1, 2, 3).
i = tf.expand_dims(image, axis=0)
print(repr(i))
# Recorded output (whitespace normalized):
# <tf.Tensor: shape=(1, 2, 3), dtype=int32, numpy=
# array([[[2, 3, 4],
#         [5, 6, 7]]])>
assert i.shape.as_list() == [1, 2, 3]

# axis=1: insert the new axis between the existing two, (2, 3) -> (2, 1, 3).
i = tf.expand_dims(image, axis=1)
print(repr(i))
# Recorded output (whitespace normalized):
# <tf.Tensor: shape=(2, 1, 3), dtype=int32, numpy=
# array([[[2, 3, 4]],
#        [[5, 6, 7]]])>
assert i.shape.as_list() == [2, 1, 3]

# axis=2: append the new axis at the end, (2, 3) -> (2, 3, 1).
i = tf.expand_dims(image, axis=2)
print(repr(i))
# Recorded output (whitespace normalized):
# <tf.Tensor: shape=(2, 3, 1), dtype=int32, numpy=
# array([[[2],
#         [3],
#         [4]],
#        [[5],
#         [6],
#         [7]]])>
assert i.shape.as_list() == [2, 3, 1]