8cebf285e6067215bf6d94d9022da4a92ae93ba5,slm_lab/agent/algorithm/reinforce.py,Reinforce,train,#Reinforce#,132
Before Change
@lab_api
def train(self):
if util.get_lab_mode() in ["enjoy", "eval"]:
self.body.entropies = []
self.body.log_probs = []
return np.nan
clock = self.body.env.clock
if self.to_train == 1:
batch = self.sample()
loss = self.calc_policy_loss(batch)
self.net.training_step(loss=loss, lr_clock=clock)
// reset
self.to_train = 0
self.body.entropies = []
self.body.log_probs = []
logger.debug(f"Trained {self.name} at epi: {clock.get("epi")}, total_t: {clock.get("total_t")}, t: {clock.get("t")}, total_reward so far: {self.body.memory.total_reward}, loss: {loss:.8f}")
return loss.item()
else:
After Change
@lab_api
def train(self):
if util.get_lab_mode() in ("enjoy", "eval"):
self.body.flush()
return np.nan
clock = self.body.env.clock
if self.to_train == 1:
batch = self.sample()
loss = self.calc_policy_loss(batch)
self.net.training_step(loss=loss, lr_clock=clock)
// reset
self.to_train = 0
self.body.flush()
logger.debug(f"Trained {self.name} at epi: {clock.get("epi")}, total_t: {clock.get("total_t")}, t: {clock.get("t")}, total_reward so far: {self.body.memory.total_reward}, loss: {loss:.8f}")
return loss.item()
else:
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 28
Instances
Project Name: kengz/SLM-Lab
Commit Name: 8cebf285e6067215bf6d94d9022da4a92ae93ba5
Time: 2018-12-11
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/reinforce.py
Class Name: Reinforce
Method Name: train
Project Name: kengz/SLM-Lab
Commit Name: 8cebf285e6067215bf6d94d9022da4a92ae93ba5
Time: 2018-12-11
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: VanillaDQN
Method Name: train
Project Name: kengz/SLM-Lab
Commit Name: 8cebf285e6067215bf6d94d9022da4a92ae93ba5
Time: 2018-12-11
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/reinforce.py
Class Name: Reinforce
Method Name: train
Project Name: kengz/SLM-Lab
Commit Name: 8cebf285e6067215bf6d94d9022da4a92ae93ba5
Time: 2018-12-11
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/sarsa.py
Class Name: SARSA
Method Name: train