7704d54a65086803c9a3258d5d65a21e04db5d04,ml/rl/training/sac_trainer.py,SACTrainer,train,#SACTrainer#Any#,100

Before Change


                target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        self.value_network_optimizer.zero_grad()
        value_loss.backward()
        self.value_network_optimizer.step()

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (
                self.value_network_target(learning_input.next_state.float_features)
                * not_done_mask.float()
            )

            target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        self.q1_network_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_network_optimizer.step()
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            self.q2_network_optimizer.zero_grad()
            q2_loss.backward()
            self.q2_network_optimizer.step()

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=state))

        state_actor_action = rlt.StateAction(
            state=state, action=rlt.FeatureVector(float_features=actor_output.action)
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        self.actor_network_optimizer.zero_grad()
        actor_loss_mean.backward()
        self.actor_network_optimizer.step()

        # Use the soft update rule to update both target networks
        self._soft_update(self.value_network, self.value_network_target, self.tau)
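For reference, `_soft_update` here is the standard Polyak-averaging target update; the helper itself is not part of this snippet, so the following is only a minimal sketch of that rule, assuming the `(network, target_network, tau)` signature used in the call above.

import torch
import torch.nn as nn

def soft_update(network: nn.Module, target_network: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * source + (1 - tau) * target
    with torch.no_grad():
        for param, target_param in zip(network.parameters(), target_network.parameters()):
            target_param.copy_(tau * param + (1.0 - tau) * target_param)

With a small tau, the target value network trails the online value network slowly, which is what stabilizes the bootstrapped target r + discount * V'(next_s) used for the Q losses above.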

After Change


            components += ["q2_network", "q2_network_optimizer"]
        return components

    def train(self, training_batch) -> None:
        """
        IMPORTANT: the input action here is assumed to be preprocessed to match the
        range of the output of the actor.
        """
        if hasattr(training_batch, "as_parametric_sarsa_training_batch"):
            training_batch = training_batch.as_parametric_sarsa_training_batch()

        learning_input = training_batch.training_input
        self.minibatch += 1

        state = learning_input.state
        action = learning_input.action
        reward = learning_input.reward
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = learning_input.not_terminal

        if self._should_scale_action_in_train():
            action = rlt.FeatureVector(
                rescale_torch_tensor(
                    action.float_features,
                    new_min=self.min_action_range_tensor_training,
                    new_max=self.max_action_range_tensor_training,
                    prev_min=self.min_action_range_tensor_serving,
                    prev_max=self.max_action_range_tensor_serving,
                )
            )
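        # NOTE: `rescale_torch_tensor` is defined elsewhere in the repo; presumably it
        # applies an elementwise linear map from [prev_min, prev_max] to [new_min, new_max],
        # i.e. roughly
        #   (tensor - prev_min) / (prev_max - prev_min) * (new_max - new_min) + new_min
        # so that logged serving-range actions are expressed in the actor's training range.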

        current_state_action = rlt.StateAction(state=state, action=action)

        q1_value = self.q1_network(current_state_action).q_value
        min_q_value = q1_value

        if self.q2_network:
            q2_value = self.q2_network(current_state_action).q_value
            min_q_value = torch.min(q1_value, q2_value)

        # Use the minimum as target, ensure no gradient going through
        min_q_value = min_q_value.detach()

        #
        # First, optimize value network; minimizing MSE between
        # V(s) & Q(s, a) - log(pi(a|s))
        #

        state_value = self.value_network(state.float_features)  # .q_value

        if self.logged_action_uniform_prior:
            log_prob_a = torch.zeros_like(min_q_value)
            target_value = min_q_value
        else:
            with torch.no_grad():
                log_prob_a = self.actor_network.get_log_prob(
                    state, action.float_features
                )
                log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                target_value = min_q_value - self.entropy_temperature * log_prob_a

        value_loss = F.mse_loss(state_value, target_value)
        value_loss.backward()
        self._maybe_run_optimizer(
            self.value_network_optimizer, self.minibatches_per_step
        )

        #
        # Second, optimize Q networks; minimizing MSE between
        # Q(s, a) & r + discount * V'(next_s)
        #

        with torch.no_grad():
            next_state_value = (
                self.value_network_target(learning_input.next_state.float_features)
                * not_done_mask.float()
            )

            target_q_value = reward + discount * next_state_value

        q1_loss = F.mse_loss(q1_value, target_q_value)
        q1_loss.backward()
        self._maybe_run_optimizer(self.q1_network_optimizer, self.minibatches_per_step)
        if self.q2_network:
            q2_loss = F.mse_loss(q2_value, target_q_value)
            q2_loss.backward()
            self._maybe_run_optimizer(
                self.q2_network_optimizer, self.minibatches_per_step
            )

        #
        # Lastly, optimize the actor; minimizing KL-divergence between action propensity
        # & softmax of value. Due to reparameterization trick, it ends up being
        # log_prob(actor_action) - Q(s, actor_action)
        #

        actor_output = self.actor_network(rlt.StateInput(state=state))

        state_actor_action = rlt.StateAction(
            state=state, action=rlt.FeatureVector(float_features=actor_output.action)
        )
        q1_actor_value = self.q1_network(state_actor_action).q_value
        min_q_actor_value = q1_actor_value
        if self.q2_network:
            q2_actor_value = self.q2_network(state_actor_action).q_value
            min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

        actor_loss = (
            self.entropy_temperature * actor_output.log_prob - min_q_actor_value
        )
        # Do this in 2 steps so we can log histogram of actor loss
        actor_loss_mean = actor_loss.mean()
        actor_loss_mean.backward()
        self._maybe_run_optimizer(
            self.actor_network_optimizer, self.minibatches_per_step
        )

        # Use the soft update rule to update both target networks
        self._maybe_soft_update(
            self.value_network,
            self.value_network_target,
            self.tau,
            self.minibatches_per_step,
        )
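The substantive change in this diff is that each `zero_grad()` / `backward()` / `step()` triple is replaced by `backward()` followed by `self._maybe_run_optimizer(optimizer, self.minibatches_per_step)`, and the unconditional `_soft_update` becomes `_maybe_soft_update`. Neither helper appears in the snippet; a plausible minimal sketch, assuming they gate on the `self.minibatch` counter incremented at the top of `train`, is shown below (as methods on the trainer or its base class).

def _maybe_run_optimizer(self, optimizer, minibatches_per_step: int) -> None:
    # Accumulate gradients across minibatches; step and clear them only on every
    # `minibatches_per_step`-th call.
    if self.minibatch % minibatches_per_step == 0:
        optimizer.step()
        optimizer.zero_grad()

def _maybe_soft_update(self, network, target_network, tau, minibatches_per_step: int) -> None:
    # Only move the target network on the minibatches where the optimizers stepped.
    if self.minibatch % minibatches_per_step == 0:
        self._soft_update(network, target_network, tau)

With `minibatches_per_step = 1` this reduces to the Before Change behavior; larger values effectively train on larger batches without increasing peak memory.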
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 15

Instances


Project Name: facebookresearch/Horizon
Commit Name: 7704d54a65086803c9a3258d5d65a21e04db5d04
Time: 2019-04-24
Author: lucasadams@fb.com
File Name: ml/rl/training/sac_trainer.py
Class Name: SACTrainer
Method Name: train



Project Name: facebookresearch/Horizon
Commit Name: 7704d54a65086803c9a3258d5d65a21e04db5d04
Time: 2019-04-24
Author: lucasadams@fb.com
File Name: ml/rl/training/parametric_dqn_trainer.py
Class Name: ParametricDQNTrainer
Method Name: train


Project Name: facebookresearch/Horizon
Commit Name: 7704d54a65086803c9a3258d5d65a21e04db5d04
Time: 2019-04-24
Author: lucasadams@fb.com
File Name: ml/rl/training/ddpg_trainer.py
Class Name: DDPGTrainer
Method Name: train


Project Name: facebookresearch/Horizon
Commit Name: 7704d54a65086803c9a3258d5d65a21e04db5d04
Time: 2019-04-24
Author: lucasadams@fb.com
File Name: ml/rl/training/dqn_trainer.py
Class Name: DQNTrainer
Method Name: calculate_cpes