if self.cardinality == x.shape[1]:
return x.sum(1)
x_split = list(torch.split(x, self.cardinality, dim=1))
// Check if all splits have the same shape (if the split cannot be made evenly, the last chunk will be smaller)
if x_split[-1].shape != x_split[0].shape:
// How much smaller the last chunk is than the others
diff = x_split[0].shape[1] - x_split[-1].shape[1]
// Pad the last chunk by the difference with zeros (=marginalized nodes)
x_split[-1] = F.pad(
x_split[-1], pad=[0, 0, 0, diff], mode="constant", value=0.0
)
// Stack along new split axis
x_split_stack = torch.stack(x_split, dim=2)
// Sum over feature axis
result = torch.sum(x_split_stack, dim=1)
return result
After Change
result = F.conv2d(x, weight=self._conv_weights, stride=(self.cardinality, 1))
// Remove simulated channel
result = result.squeeze(1)
return result
def __repr__(self):
return "Product(in_features={}, cardinality={}, out_shape={})".format(