if len(x.shape) < num_dims:
return x
return tf.reshape(x, [-1] + x.shape.as_list()[num_dims:])
After Change
// Shape can't be inferred statically.
tensor_shape = tf.shape(x)
leading_dim = tf.reduce_prod(tensor_shape[:num_dims], keepdims=True)
other_dims = tensor_shape[num_dims:]
dynamic_shape = tf.concat([leading_dim, other_dims], axis=0)
result = tf.reshape(x, dynamic_shape)
// We lose some static shape information from the above reduce/slice/concat
// dance, so we explicitly pass it in from what we computed earlier.