// TODO: `keras_shape` inference.
pattern = tuple(pattern)
return x.dimshuffle(pattern)
def repeat_elements(x, rep, axis):
Repeat the elements of a tensor along an axis, like np.repeat.
After Change
dimension indices, e.g. [0, 2, 1].
pattern = tuple(pattern)
y = x.dimshuffle(pattern)
if hasattr(x, "_keras_shape"):
y._keras_shape = tuple(np.asarray(x._keras_shape)[list(pattern)])
return y
def repeat_elements(x, rep, axis):
Repeat the elements of a tensor along an axis, like np.repeat.