def test_output(self, test_nets):
""" Checks that the output of the net is not zero or nan """
net = test_nets[0]
dummy_input = self.init_dummy_input(net)dummy_output = self.init_dummy_output(net)
out = net(dummy_input)
flag = True
if net.__class__.__name__.find("MultiMLPNet") != -1:
zero_test = sum([torch.sum(torch.abs(x.data)) for x in out])
nan_test = np.isnan(sum([torch.sum(x.data) for x in out]))
else:
zero_test = torch.sum(torch.abs(out.data))
nan_test = np.isnan(torch.sum(out.data))
if zero_test < SMALL_NUM: