# Tail of a method whose header is outside this excerpt ("before" version).
# self.inj_pad.inverse is applied only when pad is non-zero and stride is 1.
if self.pad != 0 and self.stride == 1:
# inj_pad.inverse takes a single value, so x1 and x2 are combined first.
# NOTE(review): merge/split are defined elsewhere; presumably they join and
# re-split along one axis -- confirm against their definitions.
x = merge(x1, x2)
x = self.inj_pad.inverse(x)
x1, x2 = split(x)
x = (x1, x2)
else:
# Otherwise x1 and x2 are returned as given, packed into a tuple.
x = (x1, x2)
# Both branches return a 2-tuple (x1, x2).
return x
After Change
# Tail of a method whose header is outside this excerpt ("after" version).
# self.inj_pad.inverse is applied only when pad is non-zero and stride is 1.
if self.pad != 0 and self.stride == 1:
# Concatenate x1 and x2 along dim 1 so inj_pad.inverse sees one tensor.
# NOTE(review): dim 1 is presumably the channel axis -- confirm with caller.
x = torch.cat((x1, x2), dim=1)
x = self.inj_pad.inverse(x)
# Split the result back into two chunks along the same dim.
# NOTE(review): torch.chunk yields unequal chunks if the dim-1 size is odd;
# assumes the size is even here -- confirm.
x1, x2 = torch.chunk(x, chunks=2, dim=1)
x = (x1, x2)
else:
# Otherwise x1 and x2 are returned as given, packed into a tuple.
x = (x1, x2)
# Both branches return a 2-tuple (x1, x2).
return x