f66da063f023c162b1d8dc2e191a202ff5c84843,ml/rl/training/dqn_trainer.py,DQNTrainer,train,#DQNTrainer#Any#Any#,169

Before Change


        self.reward_network_optimizer.step()

        if evaluator is not None:
            self.evaluate(
                evaluator,
                training_samples.actions,
                training_samples.propensities,
                boosted_rewards,
                training_samples.episode_values,
            )

    def evaluate(
        self,
        evaluator: Evaluator,

After Change


                self.all_action_scores.cpu().numpy(), self.rl_temperature
            )

            cpe_stats = BatchStatsForCPE(
                td_loss=self.loss.cpu().numpy(),
                logged_actions=training_samples.actions.cpu().numpy(),
                logged_propensities=training_samples.propensities.cpu().numpy(),
                logged_rewards=rewards.cpu().numpy(),
                logged_values=None,  // Compute at end of each epoch for CPE
                model_propensities=model_propensities,
                model_rewards=self.reward_estimates.cpu().numpy(),
                model_values=self.all_action_scores.cpu().numpy(),
                model_values_on_logged_actions=None,  // Compute at end of each epoch for CPE
                model_action_idxs=self.all_action_scores.argmax(dim=1, keepdim=True)
                .cpu()
                .numpy(),
            )
            evaluator.report(cpe_stats)
            training_metadata["model_rewards"] = self.reward_estimates.cpu().numpy()

        return training_metadata
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 9

Instances


Project Name: facebookresearch/Horizon
Commit Name: f66da063f023c162b1d8dc2e191a202ff5c84843
Time: 2018-10-25
Author: edoardoc@fb.com
File Name: ml/rl/training/dqn_trainer.py
Class Name: DQNTrainer
Method Name: train


Project Name: facebookresearch/Horizon
Commit Name: f66da063f023c162b1d8dc2e191a202ff5c84843
Time: 2018-10-25
Author: edoardoc@fb.com
File Name: ml/rl/training/_parametric_dqn_trainer.py
Class Name: _ParametricDQNTrainer
Method Name: train


Project Name: facebookresearch/Horizon
Commit Name: 7d671d971440e63f58c6097ffcb46ffe3852ce11
Time: 2018-10-30
Author: kittipat@fb.com
File Name: ml/rl/training/sac_trainer.py
Class Name: SACTrainer
Method Name: train