ec849adaf4ceb42ed52ca142c839f627c34b9434,slm_lab/agent/algorithm/reinforce.py,Reinforce,train,#Reinforce#,95
Before Change
            logger.debug2(f"Training...")
            # We only care about the rewards from the batch
            rewards = self.sample()["rewards"]
            logger.debug3(f"Length first epi: {len(rewards[0])}")
            logger.debug3(f"Len log probs: {len(self.saved_log_probs)}")
            self.net.optim.zero_grad()
            policy_loss = self.get_policy_loss(rewards)
            loss = policy_loss.data[0]
            policy_loss.backward()
After Change
            self.saved_log_probs = []
            self.entropy = []
            logger.debug(f"Policy loss: {loss}")
            return loss.item()
        else:
            return np.nan

    def calc_policy_loss(self, batch):
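For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of what the refactored train step computes. The toy policy network, optimizer, discount factor gamma, and return normalization below are illustrative assumptions, not SLM-Lab's exact implementation; the real Reinforce class draws these from its net and agent spec.

# Sketch of a REINFORCE train step in the post-change style.
# policy, optimizer, and gamma are assumed names for illustration only.
import torch
import torch.nn as nn

gamma = 0.99  # assumed discount factor
policy = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

def calc_policy_loss(log_probs, rewards):
    # Discounted returns-to-go for one episode
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns = torch.tensor(list(reversed(returns)))
    # Normalizing returns is a common REINFORCE stabilizer (an assumption here)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    # REINFORCE loss: negative log-probability weighted by return
    return -(torch.stack(log_probs) * returns).sum()

def train(log_probs, rewards):
    if len(log_probs) > 0:  # stands in for the class's self.to_train == 1 check
        optimizer.zero_grad()
        loss = calc_policy_loss(log_probs, rewards)
        loss.backward()
        optimizer.step()
        return loss.item()  # .item() replaces the pre-0.4 loss.data[0]
    return float("nan")

# Usage with one fake 3-step episode
states = torch.randn(3, 4)
dist = torch.distributions.Categorical(logits=policy(states))
actions = dist.sample()
print(train(list(dist.log_prob(actions)), [1.0, 0.0, 1.0]))

The substantive change the pattern captures is the return value: loss.item() replaces the PyTorch 0.3-era loss.data[0], which stopped working on 0-dimensional tensors with the PyTorch 0.4 release (April 2018, shortly before these commits).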
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 4
Instances
Project Name: kengz/SLM-Lab
Commit Name: ec849adaf4ceb42ed52ca142c839f627c34b9434
Time: 2018-05-21
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/reinforce.py
Class Name: Reinforce
Method Name: train
Project Name: kengz/SLM-Lab
Commit Name: ec849adaf4ceb42ed52ca142c839f627c34b9434
Time: 2018-05-21
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/sarsa.py
Class Name: SARSA
Method Name: train
Project Name: kengz/SLM-Lab
Commit Name: 2381a50a70559340a0335288d648b4bb9a675588
Time: 2018-06-12
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: HydraDQN
Method Name: train