e5a659aa52c02eed0368d917a66cc8afb4c9fbf8,ml/rl/test/gridworld/test_gridworld_ddpg.py,TestGridworldContinuous,_test_ddpg_trainer,#TestGridworldContinuous#Any#Any#,55

Before Change



    def _test_ddpg_trainer(self, use_gpu=False, use_all_avail_gpus=False):
        environment = GridworldContinuous()
        samples = environment.generate_samples(100000, 0.25, DISCOUNT)
        trainer = DDPGTrainer(
            self.get_ddpg_parameters(),
            environment.normalization,
            environment.normalization_action,
            environment.min_action_range,
            environment.max_action_range,
            use_gpu=use_gpu,
            use_all_avail_gpus=use_all_avail_gpus,
        )
        evaluator = GridworldDDPGEvaluator(environment, True, DISCOUNT, False, samples)
        tdps = environment.preprocess_samples(
            samples, self.minibatch_size, use_gpu=use_gpu
        )

        critic_predictor = trainer.predictor(actor=False)
        evaluator.evaluate_critic(critic_predictor)
        for tdp in tdps:
            tdp.rewards = tdp.rewards.reshape(-1, 1)
            tdp.not_terminals = tdp.not_terminals.reshape(-1, 1)
            trainer.train(tdp)

        # Make sure actor predictor works
        actor = trainer.predictor(actor=True)
        evaluator.evaluate_actor(actor)

        # Evaluate critic predictor for correctness

After Change


            ),
        )

    def _test_ddpg_trainer(self, use_gpu=False, use_all_avail_gpus=False):
        """Build a DDPG trainer on a continuous gridworld and evaluate it.

        Sets ``check_tolerance`` and ``tolerance_threshold`` on the test case,
        constructs the environment, trainer and evaluator, then hands them to
        ``self.evaluate_gridworld``.

        Args:
            use_gpu: Forwarded to ``DDPGTrainer`` and ``evaluate_gridworld``.
            use_all_avail_gpus: Forwarded to ``DDPGTrainer``.
        """
        # NOTE(review): presumably these are read by evaluate_gridworld to
        # decide whether/how strictly to assert on error -- confirm there.
        self.check_tolerance, self.tolerance_threshold = False, 1.0

        env = GridworldContinuous()
        trainer_args = (
            self.get_ddpg_parameters(),
            env.normalization,
            env.normalization_action,
            env.min_action_range,
            env.max_action_range,
        )
        ddpg_trainer = DDPGTrainer(
            *trainer_args,
            use_gpu=use_gpu,
            use_all_avail_gpus=use_all_avail_gpus,
        )
        ddpg_evaluator = GridworldDDPGEvaluator(env, True, DISCOUNT, False)

        # The same trainer object is deliberately passed in both positions.
        self.evaluate_gridworld(
            env, ddpg_evaluator, ddpg_trainer, ddpg_trainer, use_gpu
        )

    def test_ddpg_trainer(self):
        """Run the DDPG gridworld test via ``_test_ddpg_trainer`` with its default arguments."""
        self._test_ddpg_trainer()
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 11

Instances


Project Name: facebookresearch/Horizon
Commit Name: e5a659aa52c02eed0368d917a66cc8afb4c9fbf8
Time: 2018-10-24
Author: jjg@fb.com
File Name: ml/rl/test/gridworld/test_gridworld_ddpg.py
Class Name: TestGridworldContinuous
Method Name: _test_ddpg_trainer


Project Name: facebookresearch/Horizon
Commit Name: e5a659aa52c02eed0368d917a66cc8afb4c9fbf8
Time: 2018-10-24
Author: jjg@fb.com
File Name: ml/rl/test/gridworld/test_gridworld_parametric.py
Class Name: TestGridworldParametric
Method Name: _test_trainer_sarsa


Project Name: facebookresearch/Horizon
Commit Name: e5a659aa52c02eed0368d917a66cc8afb4c9fbf8
Time: 2018-10-24
Author: jjg@fb.com
File Name: ml/rl/test/gridworld/test_gridworld_parametric.py
Class Name: TestGridworldParametric
Method Name: _test_trainer_sarsa_factorized


Project Name: facebookresearch/Horizon
Commit Name: e5a659aa52c02eed0368d917a66cc8afb4c9fbf8
Time: 2018-10-24
Author: jjg@fb.com
File Name: ml/rl/test/gridworld/test_gridworld_ddpg.py
Class Name: TestGridworldContinuous
Method Name: _test_ddpg_trainer