self.append_input("A", A)
self.append_input("B", B)
return self.exec()
def exec(self):
A = self.inputs["A"]
B = self.inputs["B"]
After Change
self.parameters["axes"] = _normalize_axes(axes)
def __call__(self, A: Variable, B: Variable):
    """Validate the inputs, build the output variable and wire up the operator.

    The checks below are driven by ``self.axes``: element ``0`` lists the
    axes of ``A`` to reduce and element ``1`` lists the axes of ``B`` to
    reduce.

    Args:
        A: first input variable; must contain every axis in ``self.axes[0]``.
        B: second input variable; must contain every axis in ``self.axes[1]``.

    Returns:
        A 1-tuple ``(C,)`` holding the newly created output variable. Its
        axes are the non-reduced axes of ``A`` followed by the non-reduced
        axes of ``B`` (assuming ``AxisKeyDict`` iterates in insertion
        order - TODO confirm).

    Raises:
        AssertionError: if a reduced axis is missing from its input, if a
            non-reduced axis of one input also appears in the other input
            without being reduced there, or if the reduction sizes of ``A``
            and ``B`` differ.
    """
    # Read the property once; every check below refers to the same pair.
    op_axes = self.axes
    reduced_a = op_axes[0]
    reduced_b = op_axes[1]

    # Every axis reduced on the "A" side must exist in A.
    for axis in reduced_a:
        assert axis in A.order.axes, (
            f'\n[Tensordot] Input variable "A" must have axes "{axis}":'
            f"\n    (op) = {self}"
            f"\n    (op.axes[0]) = {reduced_a}"
            f"\n    (A) = {A}"
        )

    # An axis of A that survives the reduction must not also be present in B,
    # unless B reduces it away.
    for axis in A.order.axes:
        if axis not in reduced_a:
            assert axis in reduced_b or axis not in B.order.axes, (
                '\n[Tensordot] Axes of "A" which are not reduced must not be contained in "B":'
                f"\n    (op) = {self}"
                f"\n    (A.order.axes) = {A.order.axes}"
                f"\n    (B.order.axes) = {B.order.axes}"
                f"\n    (op.axes) = {op_axes}"
            )

    # Every axis reduced on the "B" side must exist in B.
    for axis in reduced_b:
        assert axis in B.order.axes, (
            f'\n[Tensordot] Input variable "B" must have axes "{axis}":'
            f"\n    (op) = {self}"
            f"\n    (op.axes[1]) = {reduced_b}"
            f"\n    (B) = {B}"
        )

    # Mirror of the check above: surviving axes of B must not clash with A.
    for axis in B.order.axes:
        if axis not in reduced_b:
            assert axis in reduced_a or axis not in A.order.axes, (
                '\n[Tensordot] Axes of "B" which are not reduced must not be contained in "A":'
                f"\n    (op) = {self}"
                f"\n    (A.order.axes) = {A.order.axes}"
                f"\n    (B.order.axes) = {B.order.axes}"
                f"\n    (op.axes) = {op_axes}"
            )

    # NOTE(review): `mul` is presumably the product of the given sizes - confirm.
    reduction_size_a = mul(A.shape_dict[a] for a in reduced_a)
    reduction_size_b = mul(B.shape_dict[a] for a in reduced_b)
    assert reduction_size_a == reduction_size_b, (
        '\n[Tensordot] Reduction size of "A" and "B" must be same:'
        f"\n    (A) = {A}"
        f"\n    (B) = {B}"
        f"\n    (axes) = {op_axes}"
        f"\n    (reduction size of A) = {reduction_size_a}"
        f"\n    (reduction size of B) = {reduction_size_b}"
    )

    # Collect the surviving axes and their sizes: A's first, then B's.
    c_shape_dict = AxisKeyDict()
    for axis in A.order.axes:
        if axis not in reduced_a:
            c_shape_dict[axis] = A.shape_dict[axis]

    for axis in B.order.axes:
        if axis not in reduced_b:
            c_shape_dict[axis] = B.shape_dict[axis]

    C = Variable(list(c_shape_dict.values()), Order(list(c_shape_dict.keys())))

    # Register a Tensorwise attribute for each axis of the output.
    for axis in C.order.axes:
        self.attributes.add(Tensorwise(axis))

    self.append_input("A", A)
    self.append_input("B", B)
    self.append_output("C", C)

    # Trailing comma is deliberate: the result is a 1-tuple.
    return C,
@property
def axes(self) -> Tuple[Tuple[Axis, ...], Tuple[Axis, ...]]:
    """Return the value stored under the ``"axes"`` key of ``self.parameters``.

    Per the return annotation this is a pair of axis tuples; ``__call__``
    treats element ``0`` as the axes checked against input "A" and element
    ``1`` as the axes checked against input "B".
    """
    stored_axes = self.parameters["axes"]
    return stored_axes