policy_loss = -torch.mean(q_values_tp0_min)
# critic loss (KL-divergence between categorical distributions)
actions_tp1 = self.target_actor(states_tp1).detach()
actions_tp1 = self._add_noise_to_actions(actions_tp1)
logits_t = [
x(states_t, actions_t).squeeze_(dim=2) for x in self.critics
]
logits_tp1 = [
x(states_tp1, actions_tp1).squeeze_(dim=2)
for x in self.target_critics
]
probs_tp1 = [torch.softmax(x, dim=-1) for x in logits_tp1]
q_values_tp1 = [
torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp1
]
probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)
# B x num_heads
logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
# B x num_heads x num_atoms x num_critics
# @TODO: smarter way to do this (other than reshaping)?
probs_ids_tp1_min = probs_ids_tp1_min.view(-1)
logits_tp1 = logits_tp1.view(-1, self.num_atoms, self._num_critics)
logits_tp1 = \
logits_tp1[range(len(logits_tp1)), :, probs_ids_tp1_min].\
view(-1, self._num_heads, self.num_atoms).detach()
atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z
value_loss = [
utils.categorical_loss(
# ===== After Change =====
# critic loss (KL-divergence between categorical distributions)
# [bs; action_size]
actions_tp1 = self.target_actor(states_tp1)
actions_tp1 = self._add_noise_to_actions(actions_tp1).detach()
# {num_critics} * [bs; num_heads; num_atoms]
# -> many-heads view transform
# {num_critics} * [{bs * num_heads}; num_atoms]
logits_t = [
x(states_t, actions_t).squeeze_(dim=2).view(-1, self.num_atoms)
for x in self.critics
]
# {num_critics} * [bs; num_heads; num_atoms]
logits_tp1 = [
x(states_tp1, actions_tp1).squeeze_(dim=2)
for x in self.target_critics
]
# {num_critics} * [{bs * num_heads}; num_atoms]
probs_tp1 = [torch.softmax(x, dim=-1) for x in logits_tp1]
# {num_critics} * [bs; num_heads; 1]
q_values_tp1 = [
torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp1
]
# [{bs * num_heads}; num_critics] -> argmin over all critics
# [{bs * num_heads}]
probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)
# [bs; num_heads; num_atoms; num_critics]
logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
# @TODO: smarter way to do this (other than reshaping)?
probs_ids_tp1_min = probs_ids_tp1_min.view(-1)
# [bs; num_heads; num_atoms; num_critics] -> many-heads view transform
# [{bs * num_heads}; num_atoms; num_critics] -> min over all critics
# [{bs * num_heads}; num_atoms; 1] -> target view transform
# [{bs * num_heads}; num_atoms]
logits_tp1 = (
logits_tp1
.view(-1, self.num_atoms, self._num_critics)[
range(len(probs_ids_tp1_min)), :, probs_ids_tp1_min]
.view(-1, self.num_atoms)
).detach()