2381a50a70559340a0335288d648b4bb9a675588,slm_lab/agent/algorithm/dqn.py,HydraDQN,sample,#HydraDQN#,440
Before Change
@lab_api
def sample(self):
    """Sample one batch per body and bundle the results for the network.

    A batch is drawn from the memory of every body in
    ``self.agent.nanflat_body_a`` and passed through
    ``util.to_torch_batch`` together with ``self.net.gpu``.

    Returns:
        dict with three keys:
            "states": the ``"states"`` entry of each body's batch, in body order.
            "next_states": the ``"next_states"`` entry of each body's batch,
                in body order.
            "batches": the list of per-body batches themselves.
    """
    batches = [body.memory.sample() for body in self.agent.nanflat_body_a]
    # Package data into pytorch variables.
    # NOTE(review): the return value is discarded, so this relies on
    # util.to_torch_batch converting each batch in place - confirm.
    for body_batch in batches:
        util.to_torch_batch(body_batch, self.net.gpu)
    # Collect the per-body states alongside the full batches.
    return {
        "states": [body_batch["states"] for body_batch in batches],
        "next_states": [body_batch["next_states"] for body_batch in batches],
        "batches": batches,
    }
def compute_q_target_values(self, batch):
After Change
for body in self.agent.nanflat_body_a:
body_batch = body.memory.sample()
// one-hot actions to calc q_targets
if body.is_discrete:
body_batch["actions"] = util.to_one_hot(body_batch["actions"], body.action_space.high)
body_batch = util.to_torch_batch(body_batch, self.net.gpu)
batches.append(body_batch)
// collect per body for feedforward to hydra heads
batch = {
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 13
Instances
Project Name: kengz/SLM-Lab
Commit Name: 2381a50a70559340a0335288d648b4bb9a675588
Time: 2018-06-12
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: HydraDQN
Method Name: sample
Project Name: kengz/SLM-Lab
Commit Name: 2381a50a70559340a0335288d648b4bb9a675588
Time: 2018-06-12
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: HydraDQN
Method Name: sample
Project Name: kengz/SLM-Lab
Commit Name: 861657d2c9b321961994c8cdd0e58b6c4fe0645f
Time: 2018-09-03
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: VanillaDQN
Method Name: sample
Project Name: kengz/SLM-Lab
Commit Name: 861657d2c9b321961994c8cdd0e58b6c4fe0645f
Time: 2018-09-03
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/sarsa.py
Class Name: SARSA
Method Name: sample