if self.normalize_state:
state = policy_util.update_online_stats_and_normalize_state(body, state)
states.append(state)
xs = [torch.from_numpy(state).float() for state in states]
pdparam = self.calc_pdparam(xs)
# use multi-policy. note arg change
action_a = self.action_policy(states, self, self.agent.nanflat_body_a, pdparam)
# --- After change ---
if self.normalize_state:
state = policy_util.update_online_stats_and_normalize_state(body, state)
states.append(state)
xs = [torch.from_numpy(state.astype(np.float32)) for state in states]
pdparam = self.calc_pdparam(xs)
# use multi-policy. note arg change
action_a = self.action_policy(states, self, self.agent.nanflat_body_a, pdparam)