# Gradients of the discriminator loss w.r.t. the latent input z (torch_grad is torch.autograd.grad).
d1_z_grads = torch_grad(outputs=loss, inputs=z, retain_graph=True)
# d1_grads (the gradients w.r.t. the discriminator input x) is computed analogously; that line is not shown in this fragment.
# L2 norm of each flattened gradient tensor.
d1_norm = [torch.norm(_d1_grads.view(-1).cuda(), p=2, dim=0) for _d1_grads in d1_grads]
d1_z_norm = [torch.norm(_d1_grads.reshape(-1).cuda(), p=2, dim=0) for _d1_grads in d1_z_grads]
# Gradient-norm penalty: squared norms collected over both sets of gradients.
reg_d1 = [(_d1_norm ** 2).cuda() for _d1_norm in d1_norm]
reg_d1 += [(_d1_norm ** 2).cuda() for _d1_norm in d1_z_norm]
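The same pattern can be written as a small self-contained helper. This is only an illustrative sketch, not the project's code: the names gradient_norm_penalty, D, G, criterion, and gamma are hypothetical, and create_graph=True is added on the assumption that the penalty itself is backpropagated through.

import torch

def gradient_norm_penalty(loss, inputs):
    # Sum of squared L2 norms of d(loss)/d(input) over the given inputs.
    # `loss` must be a scalar computed from `inputs`, each with requires_grad=True.
    grads = torch.autograd.grad(outputs=loss, inputs=inputs,
                                create_graph=True,  # needed if the penalty is part of the training loss
                                retain_graph=True)
    return sum((g.reshape(-1).norm(p=2) ** 2) for g in grads)

# Hypothetical usage:
# x = real_batch.clone().requires_grad_(True)
# z = latent_batch.clone().requires_grad_(True)
# loss = criterion(D(x, z), D(G(z)))
# total_loss = loss + gamma * gradient_norm_penalty(loss, [x, z])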
After Change
def regularize_gradient_norm(self, calculate_loss):
    # Wrap the real sample so gradients can be taken w.r.t. it (Variable is torch.autograd.Variable).
    x = Variable(self.x, requires_grad=True).cuda()
    # Discriminator logits for the real sample (conditioned on z) and for the fake sample.
    d1_logits = self.discriminator(x, context={"z": self.z})
    d2_logits = self.d_fake
    loss = calculate_loss(d1_logits, d2_logits)