c27dbde9ccec2920f3825538aff07e8533e086ba,catalyst/rl/offpolicy/algorithms/dqn.py,DQN,_quantile_loss,#DQN#Any#Any#Any#Any#Any#,115

Before Change


        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):

        gammas = (self._gammas**self._n_step)[None, :, None]
        # 1 x num_heads x 1

        done_t = done_t[:, None, :]  # B x 1 x 1
        rewards_t = rewards_t[:, None, :]  # B x 1 x 1
        actions_t = actions_t[:, None, None, :]  # B x 1 x 1 x 1
        indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
        # B x num_heads x 1 x num_atoms

        # critic loss (quantile regression)

        atoms_t = self.critic(states_t).gather(-2, indices_t).squeeze(-2)
        # B x num_heads x num_atoms

        all_atoms_tp1 = self.target_critic(states_tp1).detach()
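
Both versions compute the critic target with quantile regression: the critic outputs num_atoms quantile values per head and action, and the TD targets are compared against them with an asymmetrically weighted Huber loss (the utils.quantile_loss call in the After Change). That helper is not part of this diff; below is a minimal sketch of such a quantile Huber loss, where the kappa threshold and the use of quantile midpoints for tau are assumptions rather than details taken from the catalyst source.

    import torch

    def quantile_huber_loss(atoms_t, atoms_target_t, tau, kappa=1.0):
        # atoms_t:        B x num_atoms, predicted quantile values
        # atoms_target_t: B x num_atoms, Bellman targets (treated as constants)
        # tau:            num_atoms quantile midpoints, e.g. (2 * i + 1) / (2 * num_atoms)
        atoms_target_t = atoms_target_t.detach()

        # pairwise TD errors: B x num_atoms (target) x num_atoms (predicted)
        td_error = atoms_target_t[:, :, None] - atoms_t[:, None, :]

        # smooth (Huber) kernel around zero
        huber = torch.where(
            td_error.abs() <= kappa,
            0.5 * td_error.pow(2),
            kappa * (td_error.abs() - 0.5 * kappa),
        )

        # asymmetric quantile weighting: each predicted atom i is pushed
        # toward the tau[i]-quantile of the target distribution
        weight = (tau[None, None, :] - (td_error.detach() < 0).float()).abs()
        return (weight * huber).mean()

In the method itself the head dimension is folded into the batch via atoms_t.view(-1, self.num_atoms) before the loss is applied, so a sketch like this only ever needs to handle two-dimensional inputs.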

After Change


        return value_loss

    def _quantile_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas, done_t, rewards_t = self._process_components(done_t, rewards_t)

        actions_t = actions_t[:, None, None, :]  # B x 1 x 1 x 1
        # B x num_heads x 1 x num_atoms
        indices_t = actions_t.repeat(1, self._num_heads, 1, self.num_atoms)
        # B x num_heads x num_actions x num_atoms
        q_atoms_t = self.critic(states_t)
        # B x num_heads x num_atoms
        atoms_t = q_atoms_t.gather(-2, indices_t).squeeze(-2)

        # B x num_heads x num_actions x num_atoms
        q_atoms_tp1 = self.target_critic(states_tp1).detach()
        # B x num_heads x num_actions
        q_values_tp1 = q_atoms_tp1.mean(dim=-1)
        # B x num_heads x 1
        actions_tp1 = torch.argmax(q_values_tp1, dim=-1, keepdim=True)
        # B x num_heads x 1 x num_atoms
        indices_tp1 = actions_tp1.unsqueeze(-1).repeat(1, 1, 1, self.num_atoms)
        # B x num_heads x num_atoms
        atoms_tp1 = q_atoms_tp1.gather(-2, indices_tp1).squeeze(-2)

        atoms_target_t = rewards_t + (1 - done_t) * gammas * atoms_tp1
        value_loss = utils.quantile_loss(
            atoms_t.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.tau, self.num_atoms,
            self.critic_criterion
        )

        if self.entropy_regularization is not None:
            q_values_t = torch.mean(q_atoms_t, dim=-1)
            value_loss -= \
                self.entropy_regularization * self._compute_entropy(q_values_t)

        return value_loss

    def update_step(self, value_loss, critic_update=True):
        # critic update
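
The refactor replaces the per-method broadcasting of gammas, done_t and rewards_t with a shared self._process_components call that this excerpt does not show. A plausible reconstruction, assuming the helper does exactly what the three lines it replaces in the Before Change did and nothing more:

    def _process_components(self, done_t, rewards_t):
        # hypothetical helper, reconstructed from the Before Change broadcasting
        gammas = (self._gammas**self._n_step)[None, :, None]  # 1 x num_heads x 1
        done_t = done_t[:, None, :]        # B x 1 x 1
        rewards_t = rewards_t[:, None, :]  # B x 1 x 1
        return gammas, done_t, rewards_t

Extracting it pays off because the same broadcasting preamble is needed in several value losses; the Instances below show the pattern applied in _categorical_loss as well as _quantile_loss.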
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 27

Instances


Project Name: Scitator/catalyst
Commit Name: c27dbde9ccec2920f3825538aff07e8533e086ba
Time: 2019-07-24
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/dqn.py
Class Name: DQN
Method Name: _quantile_loss


Project Name: catalyst-team/catalyst
Commit Name: c27dbde9ccec2920f3825538aff07e8533e086ba
Time: 2019-07-24
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/dqn.py
Class Name: DQN
Method Name: _categorical_loss


Project Name: catalyst-team/catalyst
Commit Name: c27dbde9ccec2920f3825538aff07e8533e086ba
Time: 2019-07-24
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/dqn.py
Class Name: DQN
Method Name: _quantile_loss