Before Change
loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
loss = weights * loss  # Weight losses by the importance-sampling weights from prioritised experience replay (the original/non-distributional version applies this weighting afterwards)
self.online_net.zero_grad()
loss.mean().backward()  # Backpropagate minibatch loss
self.optimiser.step()
nn.utils.clip_grad_norm_(self.online_net.parameters(), self.norm_clip)  # Clip gradients by L2 norm
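For scale, a self-contained toy sketch of this loss and weighting on made-up tensors (the batch size, atom count and all values here are assumed for illustration, not taken from the repository):

import torch

batch_size, atoms = 4, 51  # Assumed toy sizes
logits = torch.randn(batch_size, atoms, requires_grad=True)
log_ps_a = torch.log_softmax(logits, dim=1)           # Stand-in for the log-probabilities log p(s_t, a_t) over atoms
m = torch.softmax(torch.randn(batch_size, atoms), 1)  # Stand-in for the projected target distribution m
weights = torch.rand(batch_size)                      # Stand-in importance-sampling weights

loss = -torch.sum(m * log_ps_a, 1)                    # Per-transition cross-entropy, shape (batch_size,)
loss = weights * loss                                 # Importance-weight before reduction, as above
loss.mean().backward()                                # Gradients flow back to the logits
print(logits.grad.shape)                              # torch.Size([4, 51])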
After Change
loss = -torch.sum(m * log_ps_a, 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
self.online_net.zero_grad()
(weights * loss).mean().backward()  # Backpropagate importance-weighted minibatch loss
self.optimiser.step()
mem.update_priorities(idxs, loss.detach())  # Update priorities of sampled transitions
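The rewritten code keeps the per-transition loss unreduced so that, after detaching, it can double as the new priority for each sampled transition; the backpropagated loss is importance-weighted while the priorities stay unweighted, which is the usual prioritised experience replay convention. A minimal array-backed stand-in for mem.update_priorities (this class and its priority-exponent handling are hypothetical, not the repository's replay memory):

import numpy as np

class ToyPrioritisedMemory:
    # Illustrative only: flat-array priorities; real implementations typically use a sum tree
    def __init__(self, capacity, priority_exponent=0.5):
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.priority_exponent = priority_exponent

    def update_priorities(self, idxs, priorities):
        # Store the detached per-transition losses, raised to the priority exponent
        self.priorities[idxs] = np.asarray(priorities, dtype=np.float64) ** self.priority_exponent

mem = ToyPrioritisedMemory(capacity=8)
mem.update_priorities(np.array([1, 3]), [0.5, 2.0])
print(mem.priorities)  # Non-zero priorities at indices 1 and 3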