""" Checks that the output of the net is not zero or nan """
# Run a forward pass of the first test net on an all-ones input whose
# leading dimension is 2 (presumably the batch size); the checks that
# follow inspect ``out``.
net = test_nets[0]
# ``in_dim`` is either a single int or an iterable of dimensions that is
# unpacked into the tensor shape. isinstance() is the idiomatic type test.
in_shape = (net.in_dim,) if isinstance(net.in_dim, int) else net.in_dim
dummy_input = Variable(torch.ones(2, *in_shape))
out = net(dummy_input)
flag = True
After Change
flag = True
# Nets whose class name contains "MultiMLPNet" produce an iterable of
# outputs, so the checks are reduced over every element; any other net is
# treated as producing a single tensor.
if "MultiMLPNet" in net.__class__.__name__:
    zero_test = sum([torch.sum(torch.abs(x.data)) for x in out])
    # NOTE(review): NaN is detected on the summed value, so an output that
    # contains both +inf and -inf also reads as NaN -- confirm acceptable.
    nan_test = np.isnan(sum([torch.sum(x.data) for x in out]))
else:
    zero_test = torch.sum(torch.abs(out.data))
    nan_test = np.isnan(torch.sum(out.data))
if zero_test < SMALL_NUM: