f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191,catalyst/rl/offpolicy/algorithms/td3.py,TD3,_init,#TD3#Any#Any#Any#,18
Before Change
self._hyperbolic_constant,
self._num_heads
)
self._gammas = torch.Tensor(self._gammas).to(self._device)
assert critic_distribution in [None, "categorical", "quantile"]
if critic_distribution == "categorical":
self.num_atoms = self.critic.num_atoms
values_range = self.critic.values_range
self.v_min, self.v_max = values_range
self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
z = torch.linspace(
start=self.v_min, end=self.v_max, steps=self.num_atoms
)
self.z = self._to_tensor(z)
self._loss_fn = self._categorical_loss
elif critic_distribution == "quantile":
self.num_atoms = self.critic.num_atoms
tau_min = 1 / (2 * self.num_atoms)
tau_max = 1 - tau_min
tau = torch.linspace(
start=tau_min, end=tau_max, steps=self.num_atoms
)
self.tau = self._to_tensor(tau)
self._loss_fn = self._quantile_loss
def _add_noise_to_actions(self, actions):
action_noise = torch.normal(
After Change
self._num_heads = self.critic.num_heads
self._num_critics = len(self.critics)
self._hyperbolic_constant = self.critic.hyperbolic_constant
self._gammas = \
utils.hyperbolic_gammas(
self._gamma,
self._hyperbolic_constant,
self._num_heads
)
self._gammas = utils.any2device(self._gammas, device=self._device)
assert critic_distribution in [None, "categorical", "quantile"]
if critic_distribution == "categorical":
self.num_atoms = self.critic.num_atoms
values_range = self.critic.values_range
self.v_min, self.v_max = values_range
self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
z = torch.linspace(
start=self.v_min, end=self.v_max, steps=self.num_atoms
)
self.z = utils.any2device(z, device=self._device)
self._loss_fn = self._categorical_loss
elif critic_distribution == "quantile":
self.num_atoms = self.critic.num_atoms
tau_min = 1 / (2 * self.num_atoms)
tau_max = 1 - tau_min
tau = torch.linspace(
start=tau_min, end=tau_max, steps=self.num_atoms
)
self.tau = utils.any2device(tau, device=self._device)
self._loss_fn = self._quantile_loss
else:
assert self.critic_criterion is not None
In pattern: SUPERPATTERN
Frequency: 8
Non-data size: 18
Instances
Project Name: catalyst-team/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/td3.py
Class Name: TD3
Method Name: _init
Project Name: catalyst-team/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/td3.py
Class Name: TD3
Method Name: _init
Project Name: Scitator/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/dqn.py
Class Name: DQN
Method Name: _init
Project Name: catalyst-team/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/sac.py
Class Name: SAC
Method Name: _init
Project Name: catalyst-team/catalyst
Commit Name: f2a7ac7952dfd93abdfdb3a2e1dbed96066ab191
Time: 2019-06-25
Author: scitator@gmail.com
File Name: catalyst/rl/offpolicy/algorithms/ddpg.py
Class Name: DDPG
Method Name: _init