5800dcf1950c370dc1a5ba8faf691ea76f549585,slm_lab/agent/algorithm/dqn.py,DQNBase,init_nets,#DQNBase#Any#,181

Before Change


        """Initialize networks"""
        if self.algorithm_spec["name"] == "DQNBase":
            assert all(k not in self.net_spec for k in ["update_type", "update_frequency", "polyak_coef"]), "Network update not available for DQNBase; use DQN."
        if global_nets is None:
            in_dim = self.body.state_dim
            out_dim = net_util.get_out_dim(self.body)
            NetClass = getattr(net, self.net_spec["type"])
            self.net = NetClass(self.net_spec, in_dim, out_dim)
            self.target_net = NetClass(self.net_spec, in_dim, out_dim)
            self.net_names = ["net", "target_net"]
        else:
            util.set_attr(self, global_nets)
            self.net_names = list(global_nets.keys())
        // init net optimizer and its lr scheduler
        self.optim = net_util.get_optim(self.net, self.net.optim_spec)
        self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
        self.post_init_nets()
        self.online_net = self.target_net
        self.eval_net = self.target_net

After Change


    """

    @lab_api
    def init_nets(self, global_nets=None):
        """Initialize the networks, optimizer and learning-rate scheduler.

        Builds two networks of the class named by ``self.net_spec["type"]``:
        ``self.net`` and ``self.target_net``. Both have input size
        ``self.body.state_dim`` and output size ``net_util.get_out_dim(self.body)``.
        An optimizer and an LR scheduler are then created for ``self.net`` only.

        Args:
            global_nets: Optional object forwarded to ``net_util.set_global_nets``
                after the local nets and optimizer have been built. NOTE(review):
                presumably a mapping of shared networks for distributed training —
                confirm against ``net_util.set_global_nets``.
        """
        if self.algorithm_spec["name"] == "DQNBase":
            # Target-network update settings are not supported by plain DQNBase
            # (see the assertion message); reject them instead of ignoring them.
            assert all(k not in self.net_spec for k in ["update_type", "update_frequency", "polyak_coef"]), "Network update not available for DQNBase; use DQN."
        in_dim = self.body.state_dim
        out_dim = net_util.get_out_dim(self.body)
        NetClass = getattr(net, self.net_spec["type"])
        self.net = NetClass(self.net_spec, in_dim, out_dim)
        self.target_net = NetClass(self.net_spec, in_dim, out_dim)
        self.net_names = ["net", "target_net"]
        # init net optimizer and its lr scheduler
        # (built on the local self.net, before any global nets are attached below)
        self.optim = net_util.get_optim(self.net, self.net.optim_spec)
        self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
        if global_nets is not None:
            net_util.set_global_nets(self, global_nets)
        self.post_init_nets()
        # DQNBase both acts (online_net) and evaluates (eval_net) with target_net;
        # NOTE(review): subclasses presumably rebind these — confirm in dqn.py.
        self.online_net = self.target_net
        self.eval_net = self.target_net
In pattern: SUPERPATTERN

Frequency: 5

Non-data size: 28

Instances


Project Name: kengz/SLM-Lab
Commit Name: 5800dcf1950c370dc1a5ba8faf691ea76f549585
Time: 2019-05-17
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: DQNBase
Method Name: init_nets


Project Name: kengz/SLM-Lab
Commit Name: 5800dcf1950c370dc1a5ba8faf691ea76f549585
Time: 2019-05-17
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/reinforce.py
Class Name: Reinforce
Method Name: init_nets


Project Name: kengz/SLM-Lab
Commit Name: 5800dcf1950c370dc1a5ba8faf691ea76f549585
Time: 2019-05-17
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: VanillaDQN
Method Name: init_nets


Project Name: kengz/SLM-Lab
Commit Name: 5800dcf1950c370dc1a5ba8faf691ea76f549585
Time: 2019-05-17
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/actor_critic.py
Class Name: ActorCritic
Method Name: init_nets


Project Name: kengz/SLM-Lab
Commit Name: 5800dcf1950c370dc1a5ba8faf691ea76f549585
Time: 2019-05-17
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/dqn.py
Class Name: DQNBase
Method Name: init_nets


Project Name: kengz/SLM-Lab
Commit Name: 5800dcf1950c370dc1a5ba8faf691ea76f549585
Time: 2019-05-17
Author: kengzwl@gmail.com
File Name: slm_lab/agent/algorithm/sarsa.py
Class Name: SARSA
Method Name: init_nets