# Double-DQN target computation (original '//' comment lines were Python
# syntax errors; replaced with '#').
# Use online_net to select the argmax action in the next state.
online_next_q_preds = self.online_net(batch["next_states"])
# Use eval_net to evaluate the Q-value of the action chosen by online_net
# (decoupling selection from evaluation is the Double-DQN trick).
next_q_preds = self.eval_net(batch["next_states"])
max_next_q_preds = next_q_preds.gather(-1, online_next_q_preds.argmax(dim=-1, keepdim=True)).squeeze(-1)
# Bellman target: r + gamma * (1 - done) * Q(s', a*); (1 - dones) zeroes the
# bootstrap term on terminal transitions.
max_q_targets = batch["rewards"] + self.gamma * (1 - batch["dones"]) * max_next_q_preds
# Detach so no gradient flows back through the target computation.
max_q_targets = max_q_targets.detach()
# After change:
# Original '//' comment lines were Python syntax errors (replaced with '#'),
# and the if-body below had lost its indentation — restored here.
# When running vectorized envs, unpack the (venv, batch) leading dims first.
# NOTE(review): presumably math_util.venv_unpack flattens the venv dimension
# so the networks see a flat batch — confirm against math_util.
if self.body.env.is_venv:
    states = math_util.venv_unpack(states)
    next_states = math_util.venv_unpack(next_states)
q_preds = self.net(states)
# Use online_net to select the argmax action in the next state.
online_next_q_preds = self.online_net(next_states)
# Use eval_net to evaluate the Q-value of the action chosen by online_net
# (Double-DQN: selection and evaluation use different networks).
next_q_preds = self.eval_net(next_states)