if axes is None:
// behaves like tf.batch_matmul as default
axes = [(x.ndim - 1,), (y.ndim - 2,)]
return T.batched_tensordot(x, y, axes=axes)
def transpose(x):
    """Transpose ``x`` by delegating to ``T.transpose`` with its default axes.

    Args:
        x: the tensor to transpose (presumably a Theano tensor variable;
            confirm against callers).

    Returns:
        The result of ``T.transpose(x)``, unchanged.
    """
    transposed = T.transpose(x)
    return transposed
After Change
Tensor with ndim >= 2
"""
if type(axes) == int:
axes = (axes, axes)
if axes is None:
// behaves like tf.batch_matmul as default
axes = [x.ndim - 1, y.ndim - 2]
out = T.batched_tensordot(x, y, axes=axes)
if ndim(out) == 1:
out = expand_dims(out, 1)
return out
def transpose(x):
    """Return ``T.transpose(x)``; a thin backend wrapper with no extra logic.

    Args:
        x: the tensor to transpose (assumed to be a Theano tensor variable;
            TODO confirm against callers).

    Returns:
        Exactly what ``T.transpose`` produces for ``x`` with no axes given.
    """
    flipped = T.transpose(x)
    return flipped