def tensordot(x, y, axes=2):
# Contract tensors x and y over `axes`. Only the argument-preparation prefix of
# this function is visible in this excerpt (one side of a before/after diff).
#   x, y : tensor objects exposing `.shape` (len(x.shape) / len(y.shape) are used).
#   axes : int N -> contract the last N axes of x with the first N axes of y;
#          any other value is presumably already an (axes_x, axes_y) pair --
#          it is passed through unchecked here.
# Active backend: first entry of the module-level _CURRENT_BACKEND container.
# NOTE(review): `backend` is not used in the visible lines; presumably consumed
# further down in the full function -- confirm.
backend = _CURRENT_BACKEND[0]
# Last (top) entry of the module-level _SHARING_STACK, used as the cache dict.
cache = _SHARING_STACK[-1]
# Register both operands in the cache, keyed by object identity.
# NOTE(review): presumably this holds a reference so id(x)/id(y) cannot be
# reused by a new object while cache entries keyed on these ids exist -- confirm.
cache["tensor", id(x)] = x
cache["tensor", id(y)] = y
# Normalize integer axes into (axes_x, axes_y) index lists: the last `axes` dims
# of x and the first `axes` dims of y, e.g. axes=2 with 3-D x and 3-D y gives
# ([1, 2], [0, 1]).
# NOTE(review): numbers.Number also admits floats, which make these slices raise
# TypeError; and axes outside 0..len(x.shape) produces an out-of-range slice start
# that silently selects the wrong dims of x instead of raising -- consider
# validating 0 <= axes <= ndim of both operands.
if isinstance(axes, numbers.Number):
axes = list(range(len(x.shape)))[len(x.shape) - axes:], list(range(len(y.shape)))[:axes]
After Change
def tensordot(x, y, axes=2):
_save_tensors(x, y)
// hash based on the (axes_x,axes_y) form of axes
if isinstance(axes, numbers.Number):
axes = list(range(len(x.shape)))[len(x.shape) - axes:], list(range(len(y.shape)))[:axes]