756e0cd08ee8c661e2c6d0adee360c4792b97ca0,ml/rl/training/dqn_trainer.py,DQNTrainer,train,#DQNTrainer#Any#Any#,153
Before Change
)
next_q_values = self.get_next_action_q_values(next_states, next_actions)
filtered_next_q_vals = next_q_values.reshape(-1, 1) * not_done_mask
if self.use_reward_burnin and self.minibatch < self.reward_burnin:
target_q_values = rewards
After Change
next_actions = training_samples.next_actions
next_q_values = self.get_next_action_q_values(next_states, next_actions)
filtered_next_q_vals = next_q_values * not_done_mask
if self.use_reward_burnin and self.minibatch < self.reward_burnin:
target_q_values = rewards
else:
target_q_values = rewards + (discount_tensor * filtered_next_q_vals)
// Get Q-value of action taken
all_q_values = self.q_network(states)
self.all_action_scores = deepcopy(all_q_values.detach())
q_values = torch.sum(all_q_values * actions, 1, keepdim=True)
logger.info(q_values.shape)
logger.info(target_q_values.shape)
logger.info(rewards.shape)
logger.info(next_q_values.shape)
loss = self.q_network_loss(q_values, target_q_values)
self.loss = loss.detach()
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 3
Instances
Project Name: facebookresearch/Horizon
Commit Name: 756e0cd08ee8c661e2c6d0adee360c4792b97ca0
Time: 2018-09-18
Author: jjg@fb.com
File Name: ml/rl/training/dqn_trainer.py
Class Name: DQNTrainer
Method Name: train
Project Name: ilastik/ilastik
Commit Name: fe073644f6a8f37e9ce57df903bf12b560690fc3
Time: 2012-09-14
Author: christoph.straehle@iwr.uni-heidelberg.de
File Name: lazyflow/operators/obsolete/classifierOperators.py
Class Name: OpPredictRandomForest
Method Name: execute
Project Name: markovmodel/PyEMMA
Commit Name: d1de2f98ec8f851228e28bf28a0f800e90bd9bcf
Time: 2017-12-21
Author: m.scherer@fu-berlin.de
File Name: pyemma/msm/estimators/maximum_likelihood_msm.py
Class Name: AugmentedMarkovModel
Method Name: _estimate