45499cc575d9e555fd17605104cc2452698ecbad,ml/rl/test/gym/gym_predictor.py,GymDQNPredictor,predict,#GymDQNPredictor#Any#,59
Before Change
return policies
def predict(self, states):
    """Feed ``states`` to the trainer's policy net and return the fetched output.

    Args:
        states: Value written to the ``"states"`` workspace blob before the
            net is run.

    Returns:
        Whatever ``workspace.FetchBlob`` returns for the blob named by
        ``self.trainer.internal_policy_output`` after the net has run.

    Raises:
        NotImplementedError: If ``self.trainer`` is not a
            ``DiscreteActionTrainer``.
    """
    with core.DeviceScope(self.c2_device):
        # Guard clause: only discrete-action trainers are supported here.
        if not isinstance(self.trainer, DiscreteActionTrainer):
            raise NotImplementedError("Invalid trainer passed to GymPredictor")
        workspace.FeedBlob("states", states)
        workspace.RunNetOnce(self.trainer.internal_policy_model.net)
        # The blob name is no longer printed: that was debug output emitted
        # on every single prediction call.
        return workspace.FetchBlob(self.trainer.internal_policy_output)
class GymDQNPredictorPytorch(GymPredictor):
    # NOTE(review): only the constructor is visible in this excerpt; the class
    # presumably defines further methods beyond it -- confirm against the
    # full file.
    def __init__(self, trainer):
        """Initialize the predictor by delegating to ``GymPredictor.__init__``.

        Args:
            trainer: Passed through unchanged to the base-class initializer.
        """
        GymPredictor.__init__(self, trainer)
After Change
def predict(self, states):
    """Return the trainer's predictions (Q-scores) for ``states``.

    For a ``DQNTrainer`` the states are passed to the trainer unchanged.
    For a ``ParametricDQNTrainer`` each state row is paired with every
    one-hot action row, so the model input has
    ``len(states) * num_actions`` rows of the form
    ``[state | one-hot action]``, grouped by state.

    Args:
        states: Batch of states. The parametric branch assumes a 2-D array
            with one state per row -- TODO confirm against callers.

    Returns:
        The result of ``self.trainer.internal_prediction`` on the
        assembled model input.

    Raises:
        NotImplementedError: If ``self.trainer`` is neither a ``DQNTrainer``
            nor a ``ParametricDQNTrainer``.
    """
    # ``model_input`` rather than ``input`` so the builtin is not shadowed.
    if isinstance(self.trainer, DQNTrainer):
        model_input = states
    elif isinstance(self.trainer, ParametricDQNTrainer):
        num_actions = len(self.trainer.action_normalization_parameters)
        # Stack one identity block per state: rows cycle through actions
        # 0..num_actions-1 for each state in turn.
        one_hot_actions = np.tile(
            np.eye(num_actions, dtype=np.float32), reps=(len(states), 1)
        )
        # Repeat every state row num_actions times so that row
        # i * num_actions + j pairs state i with action j.
        repeated_states = np.repeat(states, repeats=num_actions, axis=0)
        model_input = np.hstack((repeated_states, one_hot_actions))
    else:
        raise NotImplementedError("Invalid trainer passed to GymPredictor")
    return self.trainer.internal_prediction(model_input)
def estimate_reward(self, states):
if isinstance(self.trainer, DQNTrainer):
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 8
Instances
Project Name: facebookresearch/Horizon
Commit Name: 45499cc575d9e555fd17605104cc2452698ecbad
Time: 2018-10-30
Author: jjg@fb.com
File Name: ml/rl/test/gym/gym_predictor.py
Class Name: GymDQNPredictor
Method Name: predict
Project Name: facebookresearch/Horizon
Commit Name: 45499cc575d9e555fd17605104cc2452698ecbad
Time: 2018-10-30
Author: jjg@fb.com
File Name: ml/rl/test/gym/gym_predictor.py
Class Name: GymDQNPredictor
Method Name: policy
Project Name: geomstats/geomstats
Commit Name: a1dd11c68e5911f069a747c917e2a4bfdd5ae4f4
Time: 2020-04-08
Author: hadizaatiti@gmail.com
File Name: geomstats/learning/em_expectation_maximization.py
Class Name: RiemannianEM
Method Name: update_variances