2381a50a70559340a0335288d648b4bb9a675588,slm_lab/agent/algorithm/dqn.py,MultitaskDQN,sample,#MultitaskDQN#,333
Before Change
@lab_api
def sample(self):
// NOTE the purpose of multi-body is to parallelize and get more batch_sizes
batches = [body.memory.sample() for body in self.agent.nanflat_body_a]
// Package data into pytorch variables
for batch_b in batches:
util.to_torch_batch(batch_b, self.net.gpu)
// Concat state
combined_states = torch.cat(
[batch_b["states"] for batch_b in batches], dim=1)
combined_next_states = torch.cat(
[batch_b["next_states"] for batch_b in batches], dim=1)
After Change
Note that multitask's bodies are parallelized copies with similar envs, just to get more batch sizes
"""
batches = []
for body in self.agent.nanflat_body_a:
body_batch = body.memory.sample()
// one-hot actions to calc q_targets
if body.is_discrete:
body_batch["actions"] = util.to_one_hot(body_batch["actions"], body.action_space.high)
body_batch = util.to_torch_batch(body_batch, self.net.gpu)
batches.append(body_batch)
// Concat states at dim=1 for feedforward
batch = {
"states": torch.cat([body_batch["states"] for body_batch in batches], dim=1),
"next_states": torch.cat([body_batch["next_states"] for body_batch in batches], dim=1),
}
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 20
Instances
Project Name: kengz/SLM-Lab
Commit Name: 2381a50a70559340a0335288d648b4bb9a675588
Time: 2018-06-12
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: MultitaskDQN
Method Name: sample
Project Name: kengz/SLM-Lab
Commit Name: 2381a50a70559340a0335288d648b4bb9a675588
Time: 2018-06-12
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: HydraDQN
Method Name: sample
Project Name: kengz/SLM-Lab
Commit Name: 2381a50a70559340a0335288d648b4bb9a675588
Time: 2018-06-12
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: VanillaDQN
Method Name: sample
Project Name: kengz/SLM-Lab
Commit Name: 2381a50a70559340a0335288d648b4bb9a675588
Time: 2018-06-12
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: MultitaskDQN
Method Name: sample