loss = 0
if self._sample is not None:
    self._sample = self.criterion.prepare(self.model, self._sample)
    net_output = self.model(**self._sample["net_input"])
    loss_ = self.criterion(net_output, self._sample)
    if grad_denom is not None:
        loss_ /= grad_denom
    loss_.backward()
    loss = loss_.data[0]

# flatten grads into a contiguous block of memory
if self.flat_grads is None:
    self.flat_grads = self._flatten_grads_(self.model)

# all-reduce grads
nccl.all_reduce(self.flat_grads)

# clip grads
grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)

# take an optimization step
self.optimizer.step()

return loss, grad_norm

def _flatten_grads_(self, model):
    num_params = sum(p.data.numel() for p in model.parameters())
    flat_grads = next(model.parameters()).data.new(num_params)
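The _flatten_grads_ helper is shown only partially in this excerpt (and is the same before and after the change). As a rough guide, a minimal sketch of how such a helper typically finishes, copying each parameter's gradient into the pre-allocated flat buffer, is below; the loop body and return value are assumptions, not part of the original snippet.

def _flatten_grads_(self, model):
    num_params = sum(p.data.numel() for p in model.parameters())
    flat_grads = next(model.parameters()).data.new(num_params)
    # assumed continuation: pack every gradient into the flat buffer
    offset = 0
    for p in model.parameters():
        grad = p.grad.data.view(-1)
        flat_grads[offset:offset + grad.numel()].copy_(grad)
        offset += grad.numel()
    return flat_grads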
After Change
# calculate loss and grads
loss = 0
loss_dict = {}
if self._sample is not None:
    loss_dict = self.criterion(self.model, self._sample, grad_denom)
    loss_dict["loss"].backward()
    loss = loss_dict["loss"].data[0]

# flatten grads into a contiguous block of memory
if self.flat_grads is None:
    self.flat_grads = self._flatten_grads_(self.model)

# all-reduce grads
nccl.all_reduce(self.flat_grads)

# clip grads
loss_dict["gnorm"] = self._clip_grads_(self.flat_grads, self.args.clip_norm)

# take an optimization step
self.optimizer.step()

return loss_dict

def _flatten_grads_(self, model):
    num_params = sum(p.data.numel() for p in model.parameters())
    flat_grads = next(model.parameters()).data.new(num_params)
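For context on the new call site, self.criterion(self.model, self._sample, grad_denom) is now expected to run the forward pass itself and return a dictionary keyed by "loss", to which the trainer later adds "gnorm". A minimal sketch of a criterion matching that assumed interface follows; the class name, loss choice, and sample fields are illustrative and not from the original.

import torch.nn.functional as F

class SketchCriterion(object):
    """Hypothetical criterion matching loss_dict = criterion(model, sample, grad_denom)."""

    def __call__(self, model, sample, grad_denom=None):
        # the forward pass moves inside the criterion
        net_output = model(**sample["net_input"])
        # illustrative cross-entropy loss; the real criterion is task-specific
        loss = F.cross_entropy(net_output.view(-1, net_output.size(-1)),
                               sample["target"].view(-1))
        if grad_denom is not None:
            # normalize before backward, replacing the old loss_ /= grad_denom
            loss = loss / grad_denom
        return {"loss": loss}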