85209170887f87b0efb9477999a86d12562cc63f,ml/rl/test/environment/test_environment.py,TestEnvironment,test_gridworld_continuous_generate_samples,#TestEnvironment#,85
Before Change
samples = env.generate_samples(
num_samples, epsilon=1.0, discount_factor=0.9, multi_steps=num_steps
)
for i in range(num_samples):
if samples.terminals[i][0]:
break
if i < num_samples - 1:
self.assertEqual(samples.mdp_ids[i], samples.mdp_ids[i + 1])
self.assertEqual(
samples.sequence_numbers[i] + 1, samples.sequence_numbers[i + 1]
)
for j in range(len(samples.terminals[i])):
self.assertEqual(samples.rewards[i][j], samples.rewards[i + j][0])
self.assertDictEqual(
samples.next_states[i][j], samples.next_states[i + j][0]
)
self.assertDictEqual(
samples.next_actions[i][j], samples.next_actions[i + j][0]
)
self.assertEqual(samples.terminals[i][j], samples.terminals[i + j][0])
self.assertListEqual(
samples.possible_next_actions[i][j],
samples.possible_next_actions[i + j][0],
)
if samples.terminals[i][j]:
continue
self.assertDictEqual(
samples.next_states[i][j], samples.states[i + j + 1]
)
self.assertDictEqual(
samples.next_actions[i][j], samples.actions[i + j + 1]
)
self.assertListEqual(
samples.possible_next_actions[i][j],
samples.possible_actions[i + j + 1],
)
single_step_samples = samples.to_single_step()
for i in range(num_samples):
if single_step_samples.terminals[i] is True:
break
self.assertEqual(single_step_samples.mdp_ids[i], samples.mdp_ids[i])
self.assertEqual(
single_step_samples.sequence_numbers[i], samples.sequence_numbers[i]
)
self.assertDictEqual(single_step_samples.states[i], samples.states[i])
self.assertDictEqual(single_step_samples.actions[i], samples.actions[i])
self.assertEqual(
single_step_samples.action_probabilities[i],
samples.action_probabilities[i],
)
self.assertEqual(single_step_samples.rewards[i], samples.rewards[i][0])
self.assertListEqual(
single_step_samples.possible_actions[i], samples.possible_actions[i]
)
self.assertDictEqual(
single_step_samples.next_states[i], samples.next_states[i][0]
)
self.assertDictEqual(
single_step_samples.next_actions[i], samples.next_actions[i][0]
)
self.assertEqual(single_step_samples.terminals[i], samples.terminals[i][0])
self.assertListEqual(
single_step_samples.possible_next_actions[i],
samples.possible_next_actions[i][0],
)
def test_open_ai_gym_generate_samples_multi_step(self):
env = OpenAIGymEnvironment(
"CartPole-v0",
epsilon=1.0,  # take random actions to collect training data
After Change
def test_gridworld_continuous_generate_samples(self):
    """Check multi-step sample generation for ``GridworldContinuous``.

    Requests 1000 samples with a 5-step horizon (``epsilon=1.0``,
    ``discount_factor=0.9``) and passes the result to the shared
    ``_check_samples`` helper for validation.
    """
    sample_count, step_count = 1000, 5
    samples = GridworldContinuous().generate_samples(
        sample_count,
        epsilon=1.0,
        discount_factor=0.9,
        multi_steps=step_count,
    )
    # NOTE(review): the trailing ``True`` is a positional flag of
    # ``_check_samples``; its meaning is not visible here -- confirm
    # against the helper's signature.
    self._check_samples(samples, sample_count, step_count, True)
def test_open_ai_gym_generate_samples_multi_step(self):
env = OpenAIGymEnvironment(
"CartPole-v0",
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 54
Instances
Project Name: facebookresearch/Horizon
Commit Name: 85209170887f87b0efb9477999a86d12562cc63f
Time: 2019-06-18
Author: czxttkl@fb.com
File Name: ml/rl/test/environment/test_environment.py
Class Name: TestEnvironment
Method Name: test_gridworld_continuous_generate_samples
Project Name: facebookresearch/Horizon
Commit Name: 85209170887f87b0efb9477999a86d12562cc63f
Time: 2019-06-18
Author: czxttkl@fb.com
File Name: ml/rl/test/environment/test_environment.py
Class Name: TestEnvironment
Method Name: test_gridworld_generate_samples
Project Name: facebookresearch/Horizon
Commit Name: 85209170887f87b0efb9477999a86d12562cc63f
Time: 2019-06-18
Author: czxttkl@fb.com
File Name: ml/rl/test/environment/test_environment.py
Class Name: TestEnvironment
Method Name: test_open_ai_gym_generate_samples_multi_step
Project Name: facebookresearch/Horizon
Commit Name: 85209170887f87b0efb9477999a86d12562cc63f
Time: 2019-06-18
Author: czxttkl@fb.com
File Name: ml/rl/test/environment/test_environment.py
Class Name: TestEnvironment
Method Name: test_gridworld_continuous_generate_samples