if explore:
// Unsqueeze will be unnecessary once we support batch/time-aware
// Spaces.
action = tensor_fn(self.action_space.sample()).unsqueeze(0)
else:
action = tensor_fn(action_dist.deterministic_sample())
logp = torch.zeros((action.size()[0], ), dtype=torch.float32)
return action, logp
After Change
// Wrapping an int sample in a list will be unnecessary once we support
// batch/time-aware Spaces.
a = self.action_space.sample()
action = tensor_fn([a] if isinstance(a, int) else a)
else:
action = tensor_fn(action_dist.deterministic_sample())
logp = torch.zeros((action.size()[0], ), dtype=torch.float32)