state = policy_util.update_online_stats_and_normalize_state(body, state)
action, action_pd = self.action_policy(state, self, body)
# sum for single and multi-action
body.entropies.append(action_pd.entropy().sum(dim=0))
body.log_probs.append(action_pd.log_prob(action.float()).sum(dim=0))
assert not torch.isnan(body.log_probs[-1])
if len(action.shape) == 0:  # scalar
    return action.cpu().numpy().astype(body.action_space.dtype).item()
After Change
if self.normalize_state:
    state = policy_util.update_online_stats_and_normalize_state(body, state)
action, action_pd = self.action_policy(state, self, body)
body.action_tensor, body.action_pd = action, action_pd  # used for body.action_pd_update later
if len(action.shape) == 0:  # scalar
    return action.cpu().numpy().astype(body.action_space.dtype).item()
else: