def choose_action(self, curr_state, phase=RunPhase.TRAIN):
# convert to batch so we can run it through the network
observation = np.expand_dims(np.array(curr_state["observation"]), 0)
if self.tp.agent.use_measurements:
measurements = np.expand_dims(np.array(curr_state["measurements"]), 0)
prediction = self.main_network.online_network.predict([observation, measurements])
else:
After Change
def choose_action(self, curr_state, phase=RunPhase.TRAIN):
# convert to batch so we can run it through the network
prediction = self.main_network.online_network.predict(self.tf_input_state(curr_state))
# get action values and extract the best action from it
action_values = self.extract_action_values(prediction)