# Pick, for every batch entry and head, the atoms of the greedy next-step
# action. Per the shape notes below, q_atoms_tp1 is expected to be
# B x num_heads x num_actions x num_atoms -- TODO confirm against the caller.

# B x num_heads x num_actions: mean over the trailing (atom) dimension
q_values_tp1 = q_atoms_tp1.mean(dim=-1)
# B x num_heads x 1: index of the action with the highest mean value
actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
# B x num_heads x 1 x num_atoms: the same action index for every atom.
# expand() returns a broadcast view, so no index copy is made just for gather.
indices_tp1 = actions_tp1.unsqueeze(-1).expand(-1, -1, -1, self.num_atoms)
# B x num_heads x num_atoms: gather along the action dim, then drop that
# now-singleton dim
atoms_tp1 = q_atoms_tp1.gather(-2, indices_tp1).squeeze(-2)