        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas = (self._gammas**self._n_step)[None, :, None]
        # 1 x num_heads x 1
        done_t = done_t[:, None, :]  # B x 1 x 1
        rewards_t = rewards_t[:, None, :]  # B x 1 x 1
        actions_t = actions_t[:, None, None, :]  # B x 1 x 1 x 1
        indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
        # B x num_heads x 1 x num_atoms

        # critic loss (quantile regression)
        atoms_t = self.critic(states_t).gather(-2, indices_t).squeeze(-2)
        # B x num_heads x num_atoms
        all_atoms_tp1 = self.target_critic(states_tp1).detach()
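In the refactored version below, the gamma/done/reward broadcasting above is pulled out into a `_process_components` helper whose body does not appear in this diff. A hedged sketch of what it presumably returns, reconstructed only from the three broadcasting lines above (illustrative, not a verified implementation):

    def _process_components(self, done_t, rewards_t):
        # hypothetical reconstruction of the factored-out broadcasting
        gammas = (self._gammas**self._n_step)[None, :, None]  # 1 x num_heads x 1
        done_t = done_t[:, None, :]  # B x 1 x 1
        rewards_t = rewards_t[:, None, :]  # B x 1 x 1
        return gammas, done_t, rewards_t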
After Change
        return value_loss

    def _quantile_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        actions_t = actions_t[:, None, None, :]  # B x 1 x 1 x 1
        # B x num_heads x 1 x num_atoms
        indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
        # B x num_heads x num_actions x num_atoms
        q_atoms_t = self.critic(states_t)
        # B x num_heads x num_atoms
        atoms_t = q_atoms_t.gather(-2, indices_t).squeeze(-2)
        # B x num_heads x num_actions x num_atoms
        q_atoms_tp1 = self.target_critic(states_tp1).detach()
        # B x num_heads x num_actions
        q_values_tp1 = q_atoms_tp1.mean(dim=-1)
        # B x num_heads x 1
        actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
        # B x num_heads x 1 x num_atoms
        indices_tp1 = actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
        # B x num_heads x num_atoms
        atoms_tp1 = q_atoms_tp1.gather(-2, indices_tp1).squeeze(-2)

        atoms_target_t = rewards_t + (1 - done_t) * gammas * atoms_tp1

        value_loss = utils.quantile_loss(
            atoms_t.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.tau, self.num_atoms,
            self.critic_criterion
        )

        if self.entropy_regularization is not None:
            q_values_t = torch.mean(q_atoms_t, dim=-1)
            value_loss -= \
                self.entropy_regularization * self._compute_entropy(q_values_t)

        return value_loss
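The helper `utils.quantile_loss` called above is also outside this diff. Below is a self-contained sketch of a standard quantile-regression (Huber/pinball) loss over the flattened atoms, assuming `tau` is a 1-D tensor of quantile midpoints of length `num_atoms`; the real helper also takes a `criterion`, which this sketch replaces with an inlined smooth-L1 term, so treat it as illustrative rather than the library's exact implementation.

    import torch
    import torch.nn.functional as F

    def quantile_huber_loss(atoms_t, atoms_target_t, tau, num_atoms):
        # atoms_t, atoms_target_t: (N, num_atoms); tau: (num_atoms,) quantile midpoints
        # pairwise TD errors between every target atom j and predicted atom i
        diff = atoms_target_t[:, None, :] - atoms_t[:, :, None]  # N x num_atoms x num_atoms
        # asymmetric quantile weights |tau_i - 1{diff < 0}|, gradient stopped on the indicator
        weights = torch.abs(tau[None, :, None] - diff.lt(0).float().detach())
        huber = F.smooth_l1_loss(
            atoms_t[:, :, None].expand_as(diff),
            atoms_target_t[:, None, :].expand_as(diff),
            reduction="none",
        )
        # average over target atoms, sum over predicted quantiles, mean over batch
        return (weights * huber).mean(dim=-1).sum(dim=-1).mean()

With inputs shaped via `.view(-1, self.num_atoms)` as above (one row per batch element and head), the result is a scalar loss.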
    def update_step(self, value_loss, critic_update=True):
        # critic update