447444fd06594e531ae1141afac78051481e4468,catalyst/rl/offpolicy/algorithms/td3.py,TD3,_categorical_loss,#TD3#Any#Any#Any#Any#Any#,143

Before Change


        policy_loss = -torch.mean(q_values_tp0_min)

        # critic loss (KL-divergence between categorical distributions)
        actions_tp1 = self.target_actor(states_tp1).detach()
        actions_tp1 = self._add_noise_to_actions(actions_tp1)
        logits_t = [
            x(states_t, actions_t).squeeze_(dim=2) for x in self.critics
        ]
        logits_tp1 = [
            x(states_tp1, actions_tp1).squeeze_(dim=2)
            for x in self.target_critics
        ]
        probs_tp1 = [torch.softmax(x, dim=-1) for x in logits_tp1]
        q_values_tp1 = [
            torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp1
        ]
        probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)
        # B x num_heads

        logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
        # B x num_heads x num_atoms x num_critics
        # @TODO: smarter way to do this (other than reshaping)?
        probs_ids_tp1_min = probs_ids_tp1_min.view(-1)
        logits_tp1 = logits_tp1.view(-1, self.num_atoms, self._num_critics)

        logits_tp1 = \
            logits_tp1[range(len(logits_tp1)), :, probs_ids_tp1_min].\
            view(-1, self._num_heads, self.num_atoms).detach()

        atoms_target_t = rewards_t + (1 - done_t) * gammas * self.z
        value_loss = [
            utils.categorical_loss(

After Change



        # critic loss (KL-divergence between categorical distributions)
        # [bs; action_size]
        actions_tp1 = self.target_actor(states_tp1)
        actions_tp1 = self._add_noise_to_actions(actions_tp1).detach()

        # {num_critics} * [bs; num_heads; num_atoms]
        # -> many-heads view transform
        # {num_critics} * [{bs * num_heads}; num_atoms]
        logits_t = [
            x(states_t, actions_t).squeeze_(dim=2).view(-1, self.num_atoms)
            for x in self.critics
        ]

        # {num_critics} * [bs; num_heads; num_atoms]
        logits_tp1 = [
            x(states_tp1, actions_tp1).squeeze_(dim=2)
            for x in self.target_critics
        ]
        # {num_critics} * [bs; num_heads; num_atoms]
        probs_tp1 = [torch.softmax(x, dim=-1) for x in logits_tp1]
        # {num_critics} * [bs; num_heads; 1]
        q_values_tp1 = [
            torch.sum(x * self.z, dim=-1, keepdim=True) for x in probs_tp1
        ]
        # [bs; num_heads; num_critics] -> argmin over all critics
        # [bs; num_heads]
        probs_ids_tp1_min = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)

        # [bs; num_heads; num_atoms; num_critics]
        logits_tp1 = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)
        # @TODO: smarter way to do this (other than reshaping)?
        probs_ids_tp1_min = probs_ids_tp1_min.view(-1)
        # [bs; num_heads; num_atoms; num_critics] -> many-heads view transform
        # [{bs * num_heads}; num_atoms; num_critics] -> min over all critics
        # [{bs * num_heads}; num_atoms] -> target view transform
        # [{bs * num_heads}; num_atoms]
        logits_tp1 = (
            logits_tp1
            .view(-1, self.num_atoms, self._num_critics)[
                range(len(probs_ids_tp1_min)), :, probs_ids_tp1_min]
            .view(-1, self.num_atoms)
        ).detach()
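
Both versions select, for every (batch, head) pair, the atom logits of the critic whose expected Q-value is minimal. Below is a standalone sketch of that indexing pattern with random tensors and placeholder sizes; the gather-based variant at the end is only one possible reshape-free alternative hinted at by the @TODO, not the project's chosen solution:

import torch

bs, num_heads, num_atoms, num_critics = 4, 2, 51, 2  # placeholder sizes
z = torch.linspace(-10.0, 10.0, num_atoms)           # fixed support (atoms)

# {num_critics} * [bs; num_heads; num_atoms] -- stand-ins for target-critic logits
logits_tp1 = [torch.randn(bs, num_heads, num_atoms) for _ in range(num_critics)]

# expected Q per critic: softmax over atoms, then sum(p * z)
q_values_tp1 = [
    torch.sum(torch.softmax(x, dim=-1) * z, dim=-1, keepdim=True)
    for x in logits_tp1
]
# [bs; num_heads] -- index of the critic with the minimal Q-value
min_ids = torch.cat(q_values_tp1, dim=-1).argmin(dim=-1)

# [bs; num_heads; num_atoms; num_critics]
stacked = torch.cat([x.unsqueeze(-1) for x in logits_tp1], dim=-1)

# flatten (bs, num_heads) so one advanced-indexing lookup picks, per row,
# the num_atoms slice of the minimal critic
ids_flat = min_ids.view(-1)                          # [bs * num_heads]
flat = stacked.view(-1, num_atoms, num_critics)      # [{bs * num_heads}; num_atoms; num_critics]
selected = flat[range(len(ids_flat)), :, ids_flat]   # [{bs * num_heads}; num_atoms]

# one possible reshape-free alternative using gather
idx = min_ids[..., None, None].expand(bs, num_heads, num_atoms, 1)
selected_gather = stacked.gather(-1, idx).squeeze(-1).view(-1, num_atoms)
assert torch.allclose(selected, selected_gather)

Flattening (bs, num_heads) into one leading dimension lets a single advanced-indexing lookup pick one critic per row, which is why the refactored code reshapes before the selection and restores the head layout afterwards.
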
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 9

Instances


Project Name: catalyst-team/catalyst
Commit Name: 447444fd06594e531ae1141afac78051481e4468
Time: 2019-10-31
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/td3.py
Class Name: TD3
Method Name: _categorical_loss


Project Name: catalyst-team/catalyst
Commit Name: 447444fd06594e531ae1141afac78051481e4468
Time: 2019-10-31
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/td3.py
Class Name: TD3
Method Name: _quantile_loss