447444fd06594e531ae1141afac78051481e4468,catalyst/rl/offpolicy/algorithms/dqn.py,DQN,_quantile_loss,#DQN#Any#Any#Any#Any#Any#,133

Before Change


        # B x num_heads x num_actions: Q-value = mean over quantile atoms
        q_values_tp1 = q_atoms_tp1.mean(dim=-1)
        # B x num_heads x 1: greedy next action per head
        actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
        # B x num_heads x 1 x num_atoms: action index repeated over atoms
        indices_tp1 = actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
        # B x num_heads x num_atoms: atoms of the greedy action
        atoms_tp1 = q_atoms_tp1.gather(-2, indices_tp1).squeeze(-2)
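The Before Change snippet picks, per head, the greedy next action from the mean of its quantile atoms and then gathers that action's atoms. The following is a minimal NumPy sketch of the same shape logic, assuming `q_atoms_tp1` has shape [B, num_heads, num_actions, num_atoms]. `np.take_along_axis` stands in for `torch.gather`, and the function name `select_greedy_atoms` is hypothetical, not part of catalyst.

```python
import numpy as np


def select_greedy_atoms(q_atoms_tp1: np.ndarray) -> np.ndarray:
    """Return the atoms of the greedy action for every batch item and head.

    q_atoms_tp1: [B, num_heads, num_actions, num_atoms]
    returns:     [B, num_heads, num_atoms]
    """
    num_atoms = q_atoms_tp1.shape[-1]
    # [B, num_heads, num_actions]: Q-value is the mean over quantile atoms
    q_values_tp1 = q_atoms_tp1.mean(axis=-1)
    # [B, num_heads, 1]: index of the greedy action per head
    actions_tp1 = q_values_tp1.argmax(axis=-1)[..., None]
    # [B, num_heads, 1, num_atoms]: repeat the index along the atom axis
    indices_tp1 = np.repeat(actions_tp1[..., None], num_atoms, axis=-1)
    # gather along the action axis, then drop it -> [B, num_heads, num_atoms]
    return np.take_along_axis(q_atoms_tp1, indices_tp1, axis=-2).squeeze(-2)


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 3, 4, 5))  # B=2, num_heads=3, num_actions=4, num_atoms=5
print(select_greedy_atoms(q).shape)  # (2, 3, 5)
```

Repeating the index across the atom axis is what lets a single gather along the action axis pull out a whole quantile vector rather than one scalar.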

After Change



        # [bs; num_heads; num_atoms] -> [{bs * num_heads}; num_atoms]:
        # fold the heads into the batch dimension
        atoms_target_t = (
            rewards_t + (1 - done_t) * gammas * atoms_tp1
        ).view(-1, self.num_atoms).detach()

        value_loss = utils.quantile_loss(
            # [{bs * num_heads}; num_atoms]
            atoms_t,
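The After Change snippet builds a per-atom Bellman target, `rewards_t + (1 - done_t) * gammas * atoms_tp1`, folds the heads into the batch, and passes it to `utils.quantile_loss`. The snippet is cut off after the first argument, so catalyst's actual loss is not shown. Below is a sketch of the standard QR-DQN quantile Huber loss, not catalyst's implementation, assuming `rewards`/`done` broadcast as [B, 1, 1], `gammas` as [1, num_heads, 1], quantile midpoints tau_i = (2i + 1) / (2 * num_atoms), and `kappa=1`. Both function names are hypothetical.

```python
import numpy as np


def bellman_quantile_target(rewards, done, gammas, atoms_tp1):
    # rewards, done: [B, 1, 1]; gammas: [1, num_heads, 1] (assumed shapes)
    # atoms_tp1: [B, num_heads, num_atoms]
    target = rewards + (1.0 - done) * gammas * atoms_tp1
    # fold heads into the batch: [B * num_heads, num_atoms]
    return target.reshape(-1, atoms_tp1.shape[-1])


def quantile_huber_loss(atoms_t, atoms_target_t, kappa=1.0):
    # atoms_t, atoms_target_t: [N, num_atoms]
    n_atoms = atoms_t.shape[-1]
    # quantile midpoints tau_i = (2i + 1) / (2 * num_atoms)
    tau = (np.arange(n_atoms) + 0.5) / n_atoms
    # pairwise errors u_ij = target_j - pred_i: [N, num_atoms, num_atoms]
    u = atoms_target_t[:, None, :] - atoms_t[:, :, None]
    abs_u = np.abs(u)
    huber = np.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))
    # asymmetric quantile weight |tau_i - 1{u_ij < 0}|
    weight = np.abs(tau[None, :, None] - (u < 0).astype(float))
    # mean over target atoms j, sum over predicted atoms i, mean over batch
    return (weight * huber / kappa).mean(axis=2).sum(axis=1).mean()


target = bellman_quantile_target(
    np.ones((2, 1, 1)), np.zeros((2, 1, 1)),
    np.full((1, 3, 1), 0.99), np.zeros((2, 3, 5)),
)
print(target.shape)  # (6, 5)
```

The `.detach()` in the snippet plays the role that computing the target outside the gradient path plays here: only `atoms_t` should receive gradients.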
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 4

Instances


Project Name: Scitator/catalyst
Commit Name: 447444fd06594e531ae1141afac78051481e4468
Time: 2019-10-31
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/dqn.py
Class Name: DQN
Method Name: _quantile_loss


Project Name: catalyst-team/catalyst
Commit Name: 447444fd06594e531ae1141afac78051481e4468
Time: 2019-10-31
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/dqn.py
Class Name: DQN
Method Name: _categorical_loss