Takes a single training step: one forward and one backward pass. Both x and y are lists of the same length, one x and y per environment
"""
self.train()
self.zero_grad()
self.optim.zero_grad()
if loss is None:
    outs = self(xs)
    total_loss = torch.tensor(0.0, device=self.device)
    for out, y in zip(outs, ys):
        loss = self.loss_fn(out, y)
        total_loss += loss
    loss = total_loss
assert not torch.isnan(loss).any(), loss
if net_util.to_assert_trained():
    assert_trained = net_util.gen_assert_trained(self)
loss.backward(retain_graph=retain_graph)
if self.clip_grad:
    logger.debug(f"Clipping gradient: {self.clip_grad_val}")
    nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
if global_net is None:
    self.optim.step()
else:  # distributed training with global net
    net_util.push_global_grad(self, global_net)
    self.optim.step()
    net_util.pull_global_param(self, global_net)
if net_util.to_assert_trained():
    assert_trained(self, loss)
self.store_grad_norms()
logger.debug(f"Net training_step loss: {loss}")
After Change
"""
Takes a single training step: one forward and one backward pass. Both x and y are lists of the same length, one x and y per environment
"""
self.lr_scheduler.step(epoch=lr_t)
self.train()
self.optim.zero_grad()
if loss is None:
    outs = self(xs)