if isinstance(self.env_spec.observation_space, akro.Image) and \
len(observation.shape) < \
len(self.env_spec.observation_space.shape):
observation = self.env_spec.observation_space.unflatten(
observation)
q_vals = self._f_qval([observation])
opt_action = np.argmax(q_vals)
return opt_action, dict()
def get_actions(self, observations):
Get actions from this policy for the input observations.
After Change
opt_actions, agent_infos = self.get_actions([observation])
return opt_actions[0], {k: v[0] for k, v in agent_infos.items()}
def get_actions(self, observations):
Get actions from this policy for the input observations.