# Give non-batch inputs a leading singleton dimension at position 0 so the
# code that follows can treat both cases the same way.
# NOTE(review): presumably the kernel call downstream expects batched
# (leading-dimension) inputs; confirm against the caller.
if not self.is_batch:
    x1 = self.x1.unsqueeze(0)
    x2 = self.x2.unsqueeze(0)
else:
    # Inputs already carry a batch dimension; use them unchanged.
    x1 = self.x1
    x2 = self.x2
After Change
# TODO: until kernels support multi-batch mode, we have to remove any artificial batch dimension
# that we've added
if not is_batch:
    # squeeze(0) removes dimension 0 only when it has size 1 (assuming these
    # are torch tensors), so this undoes a singleton leading dimension added
    # with unsqueeze(0) and leaves other shapes untouched.
    x1 = x1.squeeze(0)
    x2 = x2.squeeze(0)
    res = res.squeeze(0)
# Did this Kernel eat the diag option?