# Fragment of a forward pass: the enclosing `def` is outside this view and the
# original indentation was lost in this listing.
# Split index for dim 1: floor of half of self.in_ch.
n = self.in_ch // 2
# Run the initial psi transform on x only when self.init_ds is non-zero.
# NOTE(review): presumably only the single line after the `if` is its body;
# confirm against the original file, since indentation is missing here.
if self.init_ds != 0:
x = self.init_psi.forward(x)
# Slice x along dim 1 into a pair: indices [0, n) and [n, end).
# x is indexed with four axes here; presumably (batch, channel, H, W) - TODO confirm.
out = (x[:, :n, :, :], x[:, n:, :, :])
# Feed the pair through each element of self.stack in order; each block's
# forward() result replaces `out` and is indexed as a pair below.
for block in self.stack:
out = block.forward(out)
# Recombine the two parts with `merge` (defined outside this fragment;
# presumably a concatenation along dim 1 - confirm).
out_bij = merge(out[0], out[1])
# Apply self.bn1 and then F.relu; out_bij keeps the pre-bn1 value under its own name.
# NOTE(review): F is presumably torch.nn.functional - confirm from the file's imports.
out = F.relu(self.bn1(out_bij))
After Change
irevnet forward
# Fragment of a forward pass: the enclosing `def` is outside this view and the
# original indentation was lost in this listing.
# Run the initial psi transform on x; no guard condition is visible in this fragment.
x = self.init_psi.forward(x)
# Split x into two chunks along dim 1 via torch.chunk.
# NOTE(review): per the torch.chunk docs, an odd size along dim 1 makes the
# last chunk smaller - confirm callers always pass an even size.
out = torch.chunk(x, chunks=2, dim=1)
# Feed the pair through each element of self.stack in order; each block's
# forward() result replaces `out` and is indexed as a pair below.
for block in self.stack:
out = block.forward(out)
# Concatenate the two parts back together along dim 1.
out_bij = torch.cat((out[0], out[1]), dim=1)