// to add back the subtracted with v_pred at n
gammas = calc_gammas(batch, gamma)
final_terms = gammas * v_preds
final_terms = torch.cat([final_terms[n:], torch.zeros((n,))])[:rets_len]
nstep_rets = rets - tail_rets + final_terms
assert not np.isnan(nstep_rets).any(), f"N-step returns has nan: {nstep_rets}"
return nstep_rets
After Change
then add v_pred for n as final term
"""
rets = copy.deepcopy(batch["rewards"])
nstep_rets = np.zeros_like(rets) + rets
cur_gamma = gamma
for i in range(1, n):
// Shift returns by one and pad with zeros
rets[:-1] = rets[1:]
rets[-1] = 0
nstep_rets += cur_gamma * rets
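// Update current gamma for the next shifted step
// NOTE(review): `cur_gamma *= cur_gamma` squares the factor (gamma, gamma^2, gamma^4, ...), so for n >= 3 the discount no longer equals gamma^i; `cur_gamma *= gamma` is presumably intended. Confirm before relying on n >= 3.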
cur_gamma *= cur_gamma
// Add final terms. Note there is no next state when the episode is done, so the term is masked to zero via (1 - dones)
final_terms = cur_gamma * next_v_preds * (1 - batch["dones"])