f66da063f023c162b1d8dc2e191a202ff5c84843,ml/rl/training/dqn_trainer.py,DQNTrainer,train,#DQNTrainer#Any#Any#,169

Before Change


        reward_loss.backward()
        self.reward_network_optimizer.step()

        if evaluator is not None:
            self.evaluate(
                evaluator,
                training_samples.actions,
                training_samples.propensities,
                boosted_rewards,
                training_samples.episode_values,
            )

    def evaluate(
        self,
        evaluator: Evaluator,
        logged_actions: torch.Tensor,

After Change


        training_metadata = {}
        if evaluator is not None:

            model_propensities = Evaluator.softmax(
                self.all_action_scores.cpu().numpy(), self.rl_temperature
            )

            cpe_stats = BatchStatsForCPE(
                td_loss=self.loss.cpu().numpy(),
                logged_actions=training_samples.actions.cpu().numpy(),
                logged_propensities=training_samples.propensities.cpu().numpy(),
                logged_rewards=rewards.cpu().numpy(),
                logged_values=None,  # Compute at end of each epoch for CPE
                model_propensities=model_propensities,
                model_rewards=self.reward_estimates.cpu().numpy(),
                model_values=self.all_action_scores.cpu().numpy(),
                model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
                model_action_idxs=self.all_action_scores.argmax(dim=1, keepdim=True)
                .cpu()
                .numpy(),
            )
            evaluator.report(cpe_stats)
            training_metadata["model_rewards"] = self.reward_estimates.cpu().numpy()

        return training_metadata
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 6

Non-data size: 3

Instances


Project Name: facebookresearch/Horizon
Commit Name: f66da063f023c162b1d8dc2e191a202ff5c84843
Time: 2018-10-25
Author: edoardoc@fb.com
File Name: ml/rl/training/dqn_trainer.py
Class Name: DQNTrainer
Method Name: train


Project Name: metagenome-atlas/atlas
Commit Name: 54ca329617d504d886a65fe8433949db5df00db7
Time: 2018-04-05
Author: SilasK@users.noreply.github.com
File Name: atlas/report/qc_report.py
Class Name:
Method Name:


Project Name: facebookresearch/ParlAI
Commit Name: ec0b3e8c47876f4700c69d8ad19d2fe6aafbd1ff
Time: 2020-04-15
Author: jase@fb.com
File Name: parlai/scripts/detect_offensive_language.py
Class Name:
Method Name: detect


Project Name: facebookresearch/Horizon
Commit Name: 6c61a45895e1b6fbdf468a88d565ce8c299aac5a
Time: 2017-11-17
Author: nishadsingh@fb.com
File Name: ml/rl/training/rl_trainer.py
Class Name: RLTrainer
Method Name: stream


Project Name: tensorflow/minigo
Commit Name: 76af726f4d58fd445496734a41e22ff36afe3657
Time: 2018-02-06
Author: brian.kihoon.lee@gmail.com
File Name: dual_net.py
Class Name: DualNetworkTrainer
Method Name: train


Project Name: pfnet/optuna
Commit Name: 32773979b38206603c46eee9fb2bf9f4fe6f0556
Time: 2020-04-15
Author: crissman@preferred.jp
File Name: examples/pytorch_simple.py
Class Name:
Method Name: objective