Before Change

def clip_grad_norm(self, max_norm):
    """Clips gradient norm."""
    if max_norm > 0:
        return torch.nn.utils.clip_grad_norm_(self.params, max_norm)
    else:
        return torch.sqrt(sum(p.grad.data.norm()**2 for p in self.params if p.grad is not None))

def step(self, closure=None):
After Change
def clip_grad_norm(self, max_norm):
    """Clips gradient norm."""
    return utils.clip_grad_norm_(self.params, max_norm)

def step(self, closure=None):
    """Performs a single optimization step."""
    self.optimizer.step(closure)
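
The change folds both branches of the old method (rescale gradients when max_norm > 0, otherwise just report the total gradient norm) into a single utils.clip_grad_norm_ helper, leaving clip_grad_norm as a one-line wrapper. A minimal sketch of what such a helper could look like, assuming it mirrors the behavior of the removed branches (the helper name and signature here are assumptions, not the project's actual utils module):

import torch

def clip_grad_norm_(params, max_norm):
    # Only parameters that actually received gradients contribute to the norm.
    grads = [p.grad.data for p in params if p.grad is not None]
    if len(grads) == 0:
        return torch.tensor(0.)
    # Total L2 norm across all parameter gradients.
    total_norm = torch.sqrt(sum(g.norm() ** 2 for g in grads))
    # Rescale in place only when a positive max_norm is given and the norm exceeds it;
    # otherwise just return the computed norm (matching the old else branch).
    if max_norm > 0:
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for g in grads:
                g.mul_(clip_coef)
    return total_norm

Centralizing the branching this way keeps the optimizer-facing method a thin delegate and puts the norm computation in one place.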