# Total input channel count: each (m, l) pair in Rs_in contributes
# m * (2 * l + 1) channels (presumably multiplicity m of an order-l
# representation -- confirm against where Rs_in is defined).
n_in = sum(m * (2 * l + 1) for m, l in Rs_in)
# Random input of shape (batch, n_in, input_size, input_size, input_size).
x = torch.randn(batch, n_in, input_size, input_size, input_size)
# .item() converts the 0-dim statistic tensors to Python floats so "{:.3f}"
# formats a plain float, matching how the output statistics are reported.
# .data is unnecessary: torch.randn is called without requires_grad.
print("x Number = {} Mean = {:.3f} Std = {:.3f}".format(x.numel(), x.mean().item(), x.std().item()))
y = conv(x)
# The layer must produce exactly n_out entries along dimension 1
# (presumably the channel axis, mirroring n_in on the input).
assert y.size(1) == n_out
After Change
assert y.size(1) == n_out
print("y Number = {} Mean = {:.3f} Std = {:.3f}".format(y.numel(), y.mean().item(), y.std().item()))
return y