def choose_action(self, curr_state, phase=RunPhase.TRAIN):
    assert not self.env.discrete_controls, "DDPG works only for continuous control problems"
    # convert to a batch of size 1 so we can run it through the network
    observation = np.expand_dims(np.array(curr_state["observation"]), 0)
    result = self.actor_network.online_network.predict(observation)
    action_values = result[0].squeeze()
After Change
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
    assert not self.env.discrete_controls, "DDPG works only for continuous control problems"
    result = self.actor_network.online_network.predict(self.tf_input_state(curr_state))
    action_values = result[0].squeeze()
    # explore during training; act greedily during evaluation
    if phase == RunPhase.TRAIN:
        action = self.exploration_policy.get_action(action_values)
    else:
        action = action_values
    action = np.clip(action, self.env.action_space_low, self.env.action_space_high)
    # get the q value of the chosen action from the critic
    if isinstance(action, np.ndarray):
        action_batch = np.expand_dims(action, 0)
    else:
        action_batch = np.array([[action]])
    inputs = self.tf_input_state(curr_state)
    inputs["action"] = action_batch
    q_value = self.critic_network.online_network.predict(inputs)[0]
    self.q_values.add_sample(q_value)

    action_info = {"action_value": q_value}
    return action, action_info
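
The change replaces the inline observation batching with a call to self.tf_input_state, whose body is not shown in this diff. As a minimal sketch, assuming the networks take a dict of named inputs and the state dict carries a raw "observation" array (both assumptions here, not confirmed by the diff), the helper might look like:

import numpy as np

def tf_input_state(self, state):
    # hypothetical sketch: build the network input dict from a single state,
    # adding a leading batch dimension of size 1 (matching the old inline code)
    return {"observation": np.expand_dims(np.array(state["observation"]), 0)}

Centralizing the batching this way also lets the critic query near the end of choose_action reuse the same input construction (inputs = self.tf_input_state(curr_state)) instead of duplicating the expand_dims logic for each network.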