e5a659aa52c02eed0368d917a66cc8afb4c9fbf8,ml/rl/test/gridworld/test_gridworld_parametric.py,TestGridworldParametric,_test_trainer_sarsa_factorized,#TestGridworldParametric#Any#Any#,204

Before Change



    def _test_trainer_sarsa_factorized(self, use_gpu=False, use_all_avail_gpus=False):
        """Train a factorized parametric SARSA trainer on continuous gridworld.

        Requests 100000 samples from ``GridworldContinuous``, preprocesses them
        in minibatches of ``self.minibatch_size``, and trains the trainer on each
        batch. It then asserts that the evaluator's result for the trained
        predictor is below 0.15.

        Args:
            use_gpu: forwarded to ``get_sarsa_trainer`` and to
                ``environment.preprocess_samples``.
            use_all_avail_gpus: forwarded to ``get_sarsa_trainer``.
        """
        environment = GridworldContinuous()
        samples = environment.generate_samples(100000, 1.0, DISCOUNT)
        trainer = self.get_sarsa_trainer(
            environment,
            self.get_sarsa_parameters_factorized(),
            use_gpu=use_gpu,
            use_all_avail_gpus=use_all_avail_gpus,
        )
        evaluator = GridworldContinuousEvaluator(
            environment, False, DISCOUNT, False, samples
        )
        tdps = environment.preprocess_samples(
            samples, self.minibatch_size, use_gpu=use_gpu
        )

        for tdp in tdps:
            trainer.train(tdp)

        # Build the predictor only after training so it reflects the trained
        # trainer. Evaluate it exactly once: the result is what the assertion
        # checks, and re-running the evaluation would only repeat the work.
        predictor = trainer.predictor()
        evaluation = evaluator.evaluate(predictor)

        self.assertLess(evaluation, 0.15)

After Change


    def test_modular_trainer_sarsa_all_gpus(self):
        """Run ``_test_trainer_sarsa`` with GPU, all-available-GPUs and modular enabled."""
        self._test_trainer_sarsa(
            modular=True,
            use_all_avail_gpus=True,
            use_gpu=True,
        )

    def _test_trainer_sarsa_factorized(self, use_gpu=False, use_all_avail_gpus=False):
        """Train and evaluate a factorized parametric SARSA trainer on continuous gridworld.

        Configures the tolerance attributes on ``self``. It then builds a
        trainer/exporter pair from ``get_sarsa_parameters_factorized`` and hands
        training and evaluation off to ``evaluate_gridworld``.

        Args:
            use_gpu: forwarded to ``get_sarsa_trainer_exporter`` and to
                ``evaluate_gridworld``.
            use_all_avail_gpus: forwarded to ``get_sarsa_trainer_exporter``.
        """
        env = GridworldContinuous()

        # NOTE(review): tolerance checking is switched off here, so the 0.15
        # threshold is presumably not enforced by evaluate_gridworld. An earlier
        # version of this test asserted the evaluation was < 0.15. Confirm this
        # relaxation is intentional.
        self.check_tolerance = False
        self.tolerance_threshold = 0.15

        sarsa_params = self.get_sarsa_parameters_factorized()
        trainer, exporter = self.get_sarsa_trainer_exporter(
            env, sarsa_params, use_gpu, use_all_avail_gpus
        )
        evaluator = GridworldContinuousEvaluator(env, False, DISCOUNT, False)

        self.evaluate_gridworld(env, evaluator, trainer, exporter, use_gpu)

    def test_trainer_sarsa_factorized(self):
        """Run the factorized SARSA trainer test with GPU use disabled."""
        # These are the helper's own defaults, spelled out for readability.
        self._test_trainer_sarsa_factorized(use_gpu=False, use_all_avail_gpus=False)
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 13

Instances


Project Name: facebookresearch/Horizon
Commit Name: e5a659aa52c02eed0368d917a66cc8afb4c9fbf8
Time: 2018-10-24
Author: jjg@fb.com
File Name: ml/rl/test/gridworld/test_gridworld_parametric.py
Class Name: TestGridworldParametric
Method Name: _test_trainer_sarsa_factorized


Project Name: facebookresearch/Horizon
Commit Name: e5a659aa52c02eed0368d917a66cc8afb4c9fbf8
Time: 2018-10-24
Author: jjg@fb.com
File Name: ml/rl/test/gridworld/test_gridworld_parametric.py
Class Name: TestGridworldParametric
Method Name: _test_trainer_sarsa_factorized


Project Name: facebookresearch/Horizon
Commit Name: d59156c9157ff3aab5b3cb24b3cc65bb9269004c
Time: 2018-10-24
Author: kittipat@fb.com
File Name: ml/rl/test/gridworld/test_gridworld_sac.py
Class Name: TestGridworldSAC
Method Name: _test_sac_trainer


Project Name: facebookresearch/Horizon
Commit Name: e5a659aa52c02eed0368d917a66cc8afb4c9fbf8
Time: 2018-10-24
Author: jjg@fb.com
File Name: ml/rl/test/gridworld/test_gridworld_ddpg.py
Class Name: TestGridworldContinuous
Method Name: _test_ddpg_trainer