loss = hook.forward()
if loss[1] is not None:
    g_loss += loss[1]
g_loss.mean().backward()
self.g_optimizer.step()
if self.current_step % 10 == 0:
    self.print_metrics(self.current_step)
After Change
for p, np in zip(self.gan.d_parameters(), d_grads):
    p.grad = np  # assign the precomputed discriminator gradients directly
if len(d_grads) > 0:
    self.d_optimizer.step()
_, g_grads = self.calculate_gradients(["g"])