296881376f1d04fd189ec2dfa7878400a59d2b1d,slm_lab/agent/algorithm/dqn.py,DQNBase,calc_q_loss,#DQNBase#Any#,195

Before Change


        # Use online_net to select actions in next state
        online_next_q_preds = self.online_net(batch["next_states"])
        # Use eval_net to calculate next_q_preds for actions chosen by online_net
        next_q_preds = self.eval_net(batch["next_states"])
        max_next_q_preds = next_q_preds.gather(-1, online_next_q_preds.argmax(dim=-1, keepdim=True)).squeeze(-1)
        max_q_targets = batch["rewards"] + self.gamma * (1 - batch["dones"]) * max_next_q_preds
        max_q_targets = max_q_targets.detach()
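These lines compute a Double DQN target. `online_net` chooses the next action (argmax) and `eval_net` scores that action. The sketch below reproduces the same arithmetic in NumPy. It assumes `rewards` and `dones` have shape `(batch,)` and the Q-value arrays have shape `(batch, n_actions)`. The function name `double_dqn_targets` and the sample numbers are hypothetical; this is not SLM-Lab code.

```python
import numpy as np


def double_dqn_targets(rewards, dones, online_next_q, eval_next_q, gamma=0.99):
    """Hypothetical NumPy sketch of the Double DQN target (not SLM-Lab's code).

    Assumed shapes: rewards, dones -> (batch,); online_next_q, eval_next_q -> (batch, n_actions).
    """
    # Action selection: the online network picks the greedy next action
    next_actions = online_next_q.argmax(axis=-1)
    # Action evaluation: read the eval network's Q-value for that action,
    # the NumPy analogue of torch's gather(-1, idx).squeeze(-1)
    max_next_q = np.take_along_axis(eval_next_q, next_actions[:, None], axis=-1).squeeze(-1)
    # Terminal transitions (done == 1) drop the bootstrap term
    return rewards + gamma * (1.0 - dones) * max_next_q


rewards = np.array([1.0, 0.0])
dones = np.array([0.0, 1.0])  # second transition is terminal
online_next_q = np.array([[0.1, 0.9], [0.5, 0.2]])
eval_next_q = np.array([[3.0, 2.0], [4.0, 1.0]])
targets = double_dqn_targets(rewards, dones, online_next_q, eval_next_q, gamma=0.9)
print(targets)  # approximately [2.8, 0.0]
```

For the first transition, the online network picks action 1, so the target uses `eval_next_q[0, 1] = 2.0`. A vanilla DQN target would take the eval network's own maximum of 3.0 and produce about 3.7. In the PyTorch original, `.detach()` stops gradients from flowing through the target. NumPy arrays carry no autograd graph, so the sketch needs no equivalent.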

After Change


        if self.body.env.is_venv:
            states = math_util.venv_unpack(states)
            next_states = math_util.venv_unpack(next_states)
        q_preds = self.net(states)
        # Use online_net to select actions in next state
        online_next_q_preds = self.online_net(next_states)
        # Use eval_net to calculate next_q_preds for actions chosen by online_net
        next_q_preds = self.eval_net(next_states)
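When `self.body.env.is_venv` is set, the after-change code flattens vectorized-environment batches before the forward pass. The sketch below is a hypothetical NumPy stand-in for `math_util.venv_unpack`, not SLM-Lab's implementation. It assumes that sampled data is stored as `(batch, num_envs, *feature_shape)` and that unpacking merges the first two axes.

```python
import numpy as np


def venv_unpack_sketch(batch):
    """Hypothetical stand-in for math_util.venv_unpack (an assumption, not SLM-Lab code).

    Merges the (batch, num_envs) leading axes into a single axis.
    """
    if batch.ndim < 3:
        # Scalar-per-env data such as rewards: (batch, num_envs) -> (batch * num_envs,)
        return batch.reshape(-1)
    # Non-scalar data such as states: (batch, num_envs, *shape) -> (batch * num_envs, *shape)
    return batch.reshape((-1,) + batch.shape[2:])


states = np.zeros((32, 4, 8))  # 32 sampled steps, 4 envs, 8 state features
rewards = np.zeros((32, 4))    # one scalar reward per env per step
print(venv_unpack_sketch(states).shape)   # (128, 8)
print(venv_unpack_sketch(rewards).shape)  # (128,)
```

Flattening both `states` and `next_states` this way gives the network an ordinary `(N, features)` batch, which is why the after-change code passes the local variables rather than `batch["next_states"]` directly.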
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 3

Instances


Project Name: kengz/SLM-Lab
Commit Name: 296881376f1d04fd189ec2dfa7878400a59d2b1d
Time: 2019-04-30
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: DQNBase
Method Name: calc_q_loss


Project Name: NifTK/NiftyNet
Commit Name: 2f50617c8039945bdd0634f6e28cd9728acbab38
Time: 2017-11-03
Author: wenqi.li@ucl.ac.uk
File Name: niftynet/application/label_driven_registration.py
Class Name: RegApp
Method Name: connect_data_and_network


Project Name: kengz/SLM-Lab
Commit Name: 58b42927fc44b779e8c7dd8507d6cfebe344f2ef
Time: 2019-05-01
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: DQNBase
Method Name: calc_q_loss