8cebf285e6067215bf6d94d9022da4a92ae93ba5,slm_lab/agent/algorithm/reinforce.py,Reinforce,train,#Reinforce#,132

Before Change



    @lab_api
    def train(self):
        if util.get_lab_mode() in ["enjoy", "eval"]:
            self.body.entropies = []
            self.body.log_probs = []
            return np.nan
        clock = self.body.env.clock
        if self.to_train == 1:
            batch = self.sample()
            loss = self.calc_policy_loss(batch)
            self.net.training_step(loss=loss, lr_clock=clock)
            # reset
            self.to_train = 0
            self.body.entropies = []
            self.body.log_probs = []
            logger.debug(f"Trained {self.name} at epi: {clock.get("epi")}, total_t: {clock.get("total_t")}, t: {clock.get("t")}, total_reward so far: {self.body.memory.total_reward}, loss: {loss:.8f}")

            return loss.item()
        else:

After Change



    @lab_api
    def train(self):
        if util.get_lab_mode() in ("enjoy", "eval"):
            self.body.flush()
            return np.nan
        clock = self.body.env.clock
        if self.to_train == 1:
            batch = self.sample()
            loss = self.calc_policy_loss(batch)
            self.net.training_step(loss=loss, lr_clock=clock)
            # reset
            self.to_train = 0
            self.body.flush()
            logger.debug(f"Trained {self.name} at epi: {clock.get("epi")}, total_t: {clock.get("total_t")}, t: {clock.get("t")}, total_reward so far: {self.body.memory.total_reward}, loss: {loss:.8f}")

            return loss.item()
        else:
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 28

Instances


Project Name: kengz/SLM-Lab
Commit Name: 8cebf285e6067215bf6d94d9022da4a92ae93ba5
Time: 2018-12-11
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/reinforce.py
Class Name: Reinforce
Method Name: train


Project Name: kengz/SLM-Lab
Commit Name: 8cebf285e6067215bf6d94d9022da4a92ae93ba5
Time: 2018-12-11
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: VanillaDQN
Method Name: train


Project Name: kengz/SLM-Lab
Commit Name: 8cebf285e6067215bf6d94d9022da4a92ae93ba5
Time: 2018-12-11
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/reinforce.py
Class Name: Reinforce
Method Name: train


Project Name: kengz/SLM-Lab
Commit Name: 8cebf285e6067215bf6d94d9022da4a92ae93ba5
Time: 2018-12-11
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/sarsa.py
Class Name: SARSA
Method Name: train