# Gather the minibatch selected by `indices` from the rollout storage.
# Each buffer is first flattened with `.view(-1, ...)` so that a single index
# into the leading dimension picks one sample.
# NOTE(review): presumably the collapsed leading dims are
# (num_steps, num_processes) -- confirm against the rollout storage class.

# `[:-1]` drops the final entry along the first dimension before flattening;
# the action and log-prob buffers are used in full.
observations_batch = rollouts.observations[:-1].view(-1, *obs_shape)[indices]
actions_batch = rollouts.actions.view(-1, action_shape)[indices]
return_batch = rollouts.returns[:-1].view(-1, 1)[indices]
old_action_log_probs_batch = rollouts.action_log_probs.view(-1, 1)[indices]

# Reshape to do in a single forward pass for all steps
# NOTE(review): `Variable` is kept as-is for compatibility with the PyTorch
# version this file targets; on PyTorch >= 0.4 it is presumably redundant.
values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(
    Variable(observations_batch),
    Variable(actions_batch),
)