45499cc575d9e555fd17605104cc2452698ecbad,ml/rl/test/gym/gym_predictor.py,GymDQNPredictor,predict,#GymDQNPredictor#Any#,59

Before Change


            return policies

    def predict(self, states):
        """Run the trainer's internal policy net on ``states`` and return its output.

        Args:
            states: Batch of states; fed into the caffe2 workspace blob named
                ``"states"``.

        Returns:
            The value fetched from the blob named by
            ``self.trainer.internal_policy_output`` after running
            ``self.trainer.internal_policy_model.net`` once.

        Raises:
            NotImplementedError: If ``self.trainer`` is not a
                ``DiscreteActionTrainer``.
        """
        with core.DeviceScope(self.c2_device):
            if not isinstance(self.trainer, DiscreteActionTrainer):
                raise NotImplementedError("Invalid trainer passed to GymPredictor")
            workspace.FeedBlob("states", states)
            workspace.RunNetOnce(self.trainer.internal_policy_model.net)
            policy_output_blob = self.trainer.internal_policy_output
            q_scores = workspace.FetchBlob(policy_output_blob)
            return q_scores


class GymDQNPredictorPytorch(GymPredictor):
    """GymPredictor subclass; only the constructor is visible in this view.

    NOTE(review): the "Pytorch" suffix suggests it wraps PyTorch-based
    trainers — confirm against the rest of the class.
    """

    def __init__(self, trainer):
        """Initialize via ``GymPredictor.__init__`` with ``trainer``.

        No additional state is set in this constructor.
        """
        # Explicit base-class call (not super()); kept as-is so MRO dispatch is unchanged.
        GymPredictor.__init__(self, trainer)

After Change



    def predict(self, states):
        """Return the trainer's predicted Q-scores for ``states``.

        For a ``DQNTrainer`` the states are passed to
        ``self.trainer.internal_prediction`` unchanged.

        For a ``ParametricDQNTrainer`` each state is paired with every one-hot
        action:
          * the model input has ``len(states) * num_actions`` rows;
          * row ``i`` is state ``i // num_actions`` concatenated with one-hot
            action ``i % num_actions``.
        ``states`` must be 2-D for the horizontal stack to succeed.

        Args:
            states: Batch of states (array-like supporting ``len``).

        Returns:
            Whatever ``self.trainer.internal_prediction`` returns for the
            constructed input.

        Raises:
            NotImplementedError: If the trainer is neither a ``DQNTrainer`` nor
                a ``ParametricDQNTrainer``.
        """
        if isinstance(self.trainer, DQNTrainer):
            model_input = states
        elif isinstance(self.trainer, ParametricDQNTrainer):
            num_actions = len(self.trainer.action_normalization_parameters)
            # One-hot action rows cycled once per state: a0, a1, ..., a0, a1, ...
            actions = np.tile(
                np.eye(num_actions, dtype=np.float32), reps=(len(states), 1)
            )
            # Each state repeated num_actions times in a row: s0, s0, ..., s1, s1, ...
            # This lines up with the action cycle above.
            repeated_states = np.repeat(states, repeats=num_actions, axis=0)
            model_input = np.hstack((repeated_states, actions))
        else:
            raise NotImplementedError("Invalid trainer passed to GymPredictor")
        return self.trainer.internal_prediction(model_input)

    def estimate_reward(self, states):
        if isinstance(self.trainer, DQNTrainer):
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 8

Instances


Project Name: facebookresearch/Horizon
Commit Name: 45499cc575d9e555fd17605104cc2452698ecbad
Time: 2018-10-30
Author: jjg@fb.com
File Name: ml/rl/test/gym/gym_predictor.py
Class Name: GymDQNPredictor
Method Name: predict


Project Name: facebookresearch/Horizon
Commit Name: 45499cc575d9e555fd17605104cc2452698ecbad
Time: 2018-10-30
Author: jjg@fb.com
File Name: ml/rl/test/gym/gym_predictor.py
Class Name: GymDQNPredictor
Method Name: policy


Project Name: geomstats/geomstats
Commit Name: a1dd11c68e5911f069a747c917e2a4bfdd5ae4f4
Time: 2020-04-08
Author: hadizaatiti@gmail.com
File Name: geomstats/learning/em_expectation_maximization.py
Class Name: RiemannianEM
Method Name: update_variances