# actor loss
actions_tp0 = self.actor(states_t)
atoms_tp0 = [
x(states_t, actions_tp0).squeeze_(dim=2) for x in self.critics
]
q_values_tp0_min = torch.cat(
atoms_tp0, dim=-1
# --- After change ---
actions_tp0 = self.actor(states_t)
# {num_critics} * [bs; num_heads; num_atoms; 1]
atoms_tp0 = [
x(states_t, actions_tp0).squeeze_(dim=2).unsqueeze_(-1)
for x in self.critics
]
# [bs; num_heads; num_atoms; num_critics] -> many-heads view transform
# [{bs * num_heads}; num_atoms; num_critics] -> quantile value
# [{bs * num_heads}; num_critics] -> min over all critics
# [{bs * num_heads};]
q_values_tp0_min = (
torch.cat(atoms_tp0, dim=-1)
.view(-1, self.num_atoms, self._num_critics)
.mean(dim=1)
.min(dim=1)[0]