f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191,catalyst/rl/offpolicy/algorithms/td3.py,TD3,_init,#TD3#Any#Any#Any#,18

Before Change


                self._hyperbolic_constant,
                self._num_heads
            )
        self._gammas = torch.Tensor(self._gammas).to(self._device)
        assert critic_distribution in [None, "categorical", "quantile"]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self.num_atoms
            )
            self.z = self._to_tensor(z)
            self._loss_fn = self._categorical_loss
        elif critic_distribution == "quantile":
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self.num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self.num_atoms
            )
            self.tau = self._to_tensor(tau)
            self._loss_fn = self._quantile_loss

    def _add_noise_to_actions(self, actions):
        action_noise = torch.normal(

After Change


        self._num_heads = self.critic.num_heads
        self._num_critics = len(self.critics)
        self._hyperbolic_constant = self.critic.hyperbolic_constant
        self._gammas = \
            utils.hyperbolic_gammas(
                self._gamma,
                self._hyperbolic_constant,
                self._num_heads
            )
        self._gammas = utils.any2device(self._gammas, device=self._device)
        assert critic_distribution in [None, "categorical", "quantile"]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self.num_atoms
            )
            self.z = utils.any2device(z, device=self._device)
            self._loss_fn = self._categorical_loss
        elif critic_distribution == "quantile":
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self.num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self.num_atoms
            )
            self.tau = utils.any2device(tau, device=self._device)
            self._loss_fn = self._quantile_loss
        else:
            assert self.critic_criterion is not None
In pattern: SUPERPATTERN

Frequency: 8

Non-data size: 18

Instances


Project Name: catalyst-team/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/td3.py
Class Name: TD3
Method Name: _init


Project Name: catalyst-team/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/td3.py
Class Name: TD3
Method Name: _init


Project Name: Scitator/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/dqn.py
Class Name: DQN
Method Name: _init


Project Name: catalyst-team/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/sac.py
Class Name: SAC
Method Name: _init


Project Name: catalyst-team/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/ddpg.py
Class Name: DDPG
Method Name: _init