shape1 = list(input_shapes[0])
shape2 = list(input_shapes[1])
dot_axes = [a - 1 for a in self.dot_axes]
tensordot_output = np.tensordot(np.zeros(tuple(shape1[1:])),
np.zeros(tuple(shape2[1:])),
axes=dot_axes)
if len(tensordot_output.shape) == 0:
shape = (1,)
else:
shape = tensordot_output.shape
return (shape1[0],) + shape
def compute_mask(self, inputs, mask=None):
if mask is None or all([m is None for m in mask]):
return None