a18d8941f663eea55488781c804e6305a36f1b58,ml/rl/training/parametric_dqn_trainer.py,ParametricDQNTrainer,train,#ParametricDQNTrainer#Any#,57

Before Change


                training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        reward = learning_input.reward
        not_done_mask = learning_input.not_terminal

        discount_tensor = torch.full_like(reward, self.gamma)
        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma, learning_input.time_diff.float())
        if self.multi_steps is not None:
            discount_tensor = torch.pow(self.gamma, learning_input.step.float())

        if self.maxq_learning:
            all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
                learning_input.tiled_next_state, learning_input.possible_next_actions
            )
            # Compute max_a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values_with_target(
                all_next_q_values.q_value,
                all_next_q_values_target.q_value,
                learning_input.possible_next_actions_mask.float(),
            )
        else:
            # SARSA (use the target network)
            _, next_q_values = self.get_detached_q_values(
                learning_input.next_state, learning_input.next_action
            )
            next_q_values = next_q_values.q_value

        filtered_max_q_vals = next_q_values * not_done_mask.float()

        if self.minibatch < self.reward_burnin:
            target_q_values = reward
        else:
            target_q_values = reward + (discount_tensor * filtered_max_q_vals)

        # Get Q-value of action taken
        current_state_action = rlt.StateAction(
            state=learning_input.state, action=learning_input.action
        )
        q_values = self.q_network(current_state_action).q_value
        self.all_action_scores = q_values.detach()

        value_loss = self.q_network_loss(q_values, target_q_values)
        self.loss = value_loss.detach()

        self.q_network_optimizer.zero_grad()
        value_loss.backward()
        self.q_network_optimizer.step()

        # TODO: Maybe soft_update should belong to the target network
        if self.minibatch < self.reward_burnin:
            # Reward burn-in: force target network to match the online network
            self._soft_update(self.q_network, self.q_network_target, 1.0)
        else:
            # Use the soft update rule to update the target network
            self._soft_update(self.q_network, self.q_network_target, self.tau)

        # Get reward estimates
        reward_estimates = self.reward_network(current_state_action).q_value
        reward_loss = F.mse_loss(reward_estimates, reward)
        self.reward_network_optimizer.zero_grad()
        reward_loss.backward()

After Change


        self.q_network_optimizer.step()

        # Use the soft update rule to update the target network
        self._soft_update(self.q_network, self.q_network_target, self.tau)

        # Get reward estimates
        reward_estimates = self.reward_network(current_state_action).q_value
        reward_loss = F.mse_loss(reward_estimates, reward)
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 17

Instances


Project Name: facebookresearch/Horizon
Commit Name: a18d8941f663eea55488781c804e6305a36f1b58
Time: 2019-04-18
Author: jjg@fb.com
File Name: ml/rl/training/parametric_dqn_trainer.py
Class Name: ParametricDQNTrainer
Method Name: train


Project Name: facebookresearch/Horizon
Commit Name: a18d8941f663eea55488781c804e6305a36f1b58
Time: 2019-04-18
Author: jjg@fb.com
File Name: ml/rl/training/sac_trainer.py
Class Name: SACTrainer
Method Name: train


Project Name: facebookresearch/Horizon
Commit Name: a18d8941f663eea55488781c804e6305a36f1b58
Time: 2019-04-18
Author: jjg@fb.com
File Name: ml/rl/training/parametric_dqn_trainer.py
Class Name: ParametricDQNTrainer
Method Name: train


Project Name: facebookresearch/Horizon
Commit Name: a18d8941f663eea55488781c804e6305a36f1b58
Time: 2019-04-18
Author: jjg@fb.com
File Name: ml/rl/training/ddpg_trainer.py
Class Name: DDPGTrainer
Method Name: train


Project Name: facebookresearch/Horizon
Commit Name: a18d8941f663eea55488781c804e6305a36f1b58
Time: 2019-04-18
Author: jjg@fb.com
File Name: ml/rl/training/dqn_trainer.py
Class Name: DQNTrainer
Method Name: train