447444fd06594e531ae1141afac78051481e4468,catalyst/rl/offpolicy/algorithms/td3.py,TD3,_quantile_loss,#TD3#Any#Any#Any#Any#Any#,203

Before Change


        # actor loss
        actions_tp0 = self.actor(states_t)
        atoms_tp0 = [
            x(states_t, actions_tp0).squeeze_(dim=2) for x in self.critics
        ]
        q_values_tp0_min = torch.cat(
            atoms_tp0, dim=-1

After Change


        actions_tp0 = self.actor(states_t)
        # {num_critics} * [bs; num_heads; num_atoms; 1]
        atoms_tp0 = [
            x(states_t, actions_tp0).squeeze_(dim=2).unsqueeze_(-1)
            for x in self.critics
        ]
        # [bs; num_heads; num_atoms; num_critics] -> many-heads view transform
        # [{bs * num_heads}; num_atoms; num_critics] -> quantile value
        # [{bs * num_heads}; num_critics] -> min over all critics
        # [{bs * num_heads};]
        q_values_tp0_min = (
            torch.cat(atoms_tp0, dim=-1)
            .view(-1, self.num_atoms, self._num_critics)
            .mean(dim=1)
            .min(dim=1)[0]
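
The shape comments in the changed code can be checked with a small standalone sketch. It is not the catalyst implementation: the sizes are hypothetical, and random tensors stand in for the critic outputs, which are assumed to have shape [bs; num_heads; num_atoms; 1] after `.squeeze_(dim=2).unsqueeze_(-1)`.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
bs, num_heads, num_atoms, num_critics = 4, 2, 5, 3

# Stand-ins for the per-critic outputs (assumption: each one is
# [bs; num_heads; num_atoms; 1] after squeeze/unsqueeze).
torch.manual_seed(0)
atoms_tp0 = [
    torch.randn(bs, num_heads, num_atoms, 1) for _ in range(num_critics)
]

# [bs; num_heads; num_atoms; num_critics]
stacked = torch.cat(atoms_tp0, dim=-1)
# [bs * num_heads; num_atoms; num_critics]
flat = stacked.view(-1, num_atoms, num_critics)
# Average over atoms: [bs * num_heads; num_critics]
q_values = flat.mean(dim=1)
# Minimum over critics: [bs * num_heads]
q_values_tp0_min = q_values.min(dim=1)[0]

print(stacked.shape, flat.shape, q_values.shape, q_values_tp0_min.shape)
```

Adding the trailing dimension before `torch.cat` keeps each critic in its own slot on the last axis, so the mean over atoms and the min over critics reduce separate axes. Concatenating the squeezed tensors directly along `dim=-1`, as in the earlier version, would instead merge atoms and critics into one axis.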
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 3

Instances


Project Name: catalyst-team/catalyst
Commit Name: 447444fd06594e531ae1141afac78051481e4468
Time: 2019-10-31
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/td3.py
Class Name: TD3
Method Name: _quantile_loss


Project Name: Scitator/catalyst
Commit Name: 447444fd06594e531ae1141afac78051481e4468
Time: 2019-10-31
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/ddpg.py
Class Name: DDPG
Method Name: _base_loss