with torch.no_grad():
x = self(torch.Tensor(observation).unsqueeze(0))
return x.squeeze(0).numpy(), dict()
def get_actions(self, observations):
Get actions given observations.
After Change
observation = self._env_spec.observation_space.flatten(observation)
with torch.no_grad():
observation = torch.Tensor(observation).unsqueeze(0)
action, agent_infos = self.get_actions(observation)
return action[0], {k: v[0] for k, v in agent_infos.items()}
def get_actions(self, observations):
Get actions given observations.