066fda55de8b2a65d57b6201d5b37daa6f7f399d,catalyst/rl/utils.py,PolicyHandler,_get_q_values,#PolicyHandler#Any#Any#Any#,416

Before Change


    def _get_q_values(self, critic: CriticSpec, state: np.ndarray, device):
        states = torch.Tensor(state).to(device).unsqueeze(0)
        if self.value_distribution == "categorical":
            probs = torch.softmax(critic(states)[0], dim=-1)
            q_values = torch.sum(probs * self.z, dim=-1)
        elif self.value_distribution == "quantile":
            q_values = torch.mean(critic(states)[0], dim=-1)

After Change


        elif self.value_distribution == "quantile":
            q_values = torch.mean(output[0, -1, :, :], dim=-1)
        else:
            q_values = output[0, -1, :]
        return q_values.cpu().numpy()

    @torch.no_grad()
    def _sample_from_actor(
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 3

Instances


Project Name: catalyst-team/catalyst
Commit Name: 066fda55de8b2a65d57b6201d5b37daa6f7f399d
Time: 2019-06-10
Author: khrulkov.v@gmail.com
File Name: catalyst/rl/utils.py
Class Name: PolicyHandler
Method Name: _get_q_values


Project Name: facebookresearch/Horizon
Commit Name: 8dcbf002d0058939802920c11f74791708838791
Time: 2018-04-06
Author: edoardoc@fb.com
File Name: ml/rl/training/ddpg_predictor.py
Class Name: DDPGPredictor
Method Name: predict_q_value