means = np.zeros((1, self.mdp_info["action_space"].n))
sigmas = np.zeros(means.shape)
actions = self.mdp_info["action_space"].values
for i, a in enumerate(actions):
sa_n = [next_state, np.array([a])]
sa_n_idx = tuple(np.concatenate((next_state, np.array([a])),
axis=1).astype(np.int).ravel())
means[0, i] = self.approximator.predict(sa_n)
sigmas[0, i] = self._sigma[sa_n_idx]
if self.sampling:
samples = np.random.normal(np.repeat(means, self.precision, 0),
np.repeat(sigmas, self.precision, 0))
max_idx = np.argmax(samples, axis=1)
max_idx, max_count = np.unique(max_idx, return_counts=True)
count = np.zeros(actions.shape[0])
count[max_idx] = max_count
w = count / self.precision
else:
raise NotImplementedError
sa = [np.repeat(next_state, actions.shape[0], axis=0), actions]
W = np.dot(w, self.approximator.predict(sa))
return W
After Change
means = self.approximator.predict_all(next_state)
sigmas = np.zeros((1, self.approximator.discrete_actions.shape[0]))
for a in xrange(sigmas.size):
sa_n_idx = tuple(np.concatenate((next_state, np.array([[a]])),
axis=1).astype(np.int).ravel())