# Build an `action` tensor and a matching all-zero `logp` tensor.
#
# When `explore` is truthy, `action` is created from the numpy array `a`;
# otherwise it is `action_dist.deterministic_sample()`.
#
# NOTE(review): this block appears to be two revisions of the method spliced
# together (see the bare "After Change" line below). In the visible lines:
#   * indentation has been lost, and the `//`-prefixed lines are not valid
#     Python comments;
#   * `a` and `req` are read but never assigned;
#   * `tensor_fn` is assigned but never used;
#   * there is no `return` statement.
# TODO: confirm against the real file before relying on any of this.
def get_torch_exploration_action(self, action_dist, explore):
# Pick the tensor constructor from the action space: LongTensor when the
# space's exact type is Discrete or MultiDiscrete, FloatTensor otherwise.
# `type(...) in [...]` is an exact-type test, so subclasses of those two
# classes fall through to FloatTensor.
tensor_fn = torch.LongTensor if \
type(self.action_space) in [Discrete, MultiDiscrete] else \
torch.FloatTensor
if explore:
// Unsqueeze will be unnecessary, once we support batch/time-aware
After Change
// Add a batch dimension.
# Compares the rank of `action_dist.inputs` with `len(req) + 1`.
# Presumably `req` is the per-sample input shape, so the extra axis would be
# a batch dimension -- TODO confirm; `req` is not defined in the visible code.
if len(action_dist.inputs.shape) == len(req) + 1:
# Insert a new leading axis of length 1 into `a`.
a = np.expand_dims(a, 0)
# Wrap the numpy array as a torch tensor and move it to `self.device`.
action = torch.from_numpy(a).to(self.device)
else:
# Not exploring: take the distribution's deterministic sample.
action = action_dist.deterministic_sample()
# One float32 zero per entry along the first dimension of `action`,
# allocated on `self.device`.
logp = torch.zeros(
(action.size()[0], ), dtype=torch.float32, device=self.device)