// Compute advantages and returns.
// advantages: len x num_heads x num_atoms
# One utils.geometric_cumsum call per entry of self._gammas: the entry at
# position head_idx is paired with column head_idx of deltas. The per-head
# results are then joined along a new axis 1.
# NOTE(review): geometric_cumsum is presumably a gamma-discounted cumulative
# sum -- confirm against utils.
per_head_advantages = [
    utils.geometric_cumsum(discount, deltas[:, head_idx])
    for head_idx, discount in enumerate(self._gammas)
]
advantages = np.stack(per_head_advantages, axis=1)
// returns: len x num_heads
returns = np.stack([
After Change
// returns: len x num_heads
# One series per entry of self._gammas: rewards gets a unit axis inserted at
# position 1 before utils.geometric_cumsum is applied, and index 0 along
# axis 1 of the result is taken afterwards. The per-gamma series are then
# joined along a new axis 1.
# NOTE(review): geometric_cumsum is presumably a gamma-discounted cumulative
# sum -- confirm against utils.
per_head_returns = [
    utils.geometric_cumsum(discount, rewards[:, np.newaxis])[:, 0]
    for discount in self._gammas
]
returns = np.stack(per_head_returns, axis=1)