756e0cd08ee8c661e2c6d0adee360c4792b97ca0,ml/rl/training/dqn_trainer.py,DQNTrainer,train,#DQNTrainer#Any#Any#,153

Before Change


            )
            next_q_values = self.get_next_action_q_values(next_states, next_actions)

        filtered_next_q_vals = next_q_values.reshape(-1, 1) * not_done_mask

        if self.use_reward_burnin and self.minibatch < self.reward_burnin:
            target_q_values = rewards

After Change


            next_actions = training_samples.next_actions
            next_q_values = self.get_next_action_q_values(next_states, next_actions)

        filtered_next_q_vals = next_q_values * not_done_mask

        if self.use_reward_burnin and self.minibatch < self.reward_burnin:
            target_q_values = rewards
        else:
            target_q_values = rewards + (discount_tensor * filtered_next_q_vals)

        # Get Q-value of action taken
        all_q_values = self.q_network(states)
        self.all_action_scores = deepcopy(all_q_values.detach())
        q_values = torch.sum(all_q_values * actions, 1, keepdim=True)

        logger.info(q_values.shape)
        logger.info(target_q_values.shape)
        logger.info(rewards.shape)
        logger.info(next_q_values.shape)
        loss = self.q_network_loss(q_values, target_q_values)
        self.loss = loss.detach()
[image placeholder]
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 3

Instances


Project Name: facebookresearch/Horizon
Commit Name: 756e0cd08ee8c661e2c6d0adee360c4792b97ca0
Time: 2018-09-18
Author: jjg@fb.com
File Name: ml/rl/training/dqn_trainer.py
Class Name: DQNTrainer
Method Name: train


Project Name: ilastik/ilastik
Commit Name: fe073644f6a8f37e9ce57df903bf12b560690fc3
Time: 2012-09-14
Author: christoph.straehle@iwr.uni-heidelberg.de
File Name: lazyflow/operators/obsolete/classifierOperators.py
Class Name: OpPredictRandomForest
Method Name: execute


Project Name: markovmodel/PyEMMA
Commit Name: d1de2f98ec8f851228e28bf28a0f800e90bd9bcf
Time: 2017-12-21
Author: m.scherer@fu-berlin.de
File Name: pyemma/msm/estimators/maximum_likelihood_msm.py
Class Name: AugmentedMarkovModel
Method Name: _estimate