states = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
# Biggest state is best state.
value = tf.reduce_max(input_tensor=states, axis=-1)
value = tf.reshape(value, [-1])
# Biggest action is best action.
q_value = tf.reduce_max(input_tensor=actions, axis=-1)
q_value = tf.reshape(q_value, [-1])
# Q-value is the sum of the state value and the action value.
return value + q_value, ()
class SacAgentTest(test_utils.TestCase):
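In the before-change critic, the Q-value is hand-computed as the max over the state features plus the max over the action features. A quick numeric check (batch shapes assumed for illustration only):

import tensorflow as tf

observation = tf.constant([[1.0, 2.0]])  # max state feature = 2.0
actions = tf.constant([[0.5]])           # max action feature = 0.5
# value = [2.0], q_value = [0.5], so the critic returns [2.5] for this batch element.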
After Change
a_value = self._action_layer(actions)
# Biggest state is best state.
q_value = tf.reshape(s_value + a_value, [-1])
return q_value, network_state
class SacAgentTest(test_utils.TestCase):
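For context, a minimal sketch of how the after-change critic could be wired up. The layer definitions here are assumptions (ones-initialized Dense layers named _obs_layer and _action_layer; the actual test may define them differently), but the sketch shows where s_value comes from and how q_value and network_state are returned:

import tensorflow as tf
from tf_agents.networks import network


class DummyCriticNet(network.Network):
  """Dummy critic whose Q-value is a learned function of states and actions."""

  def __init__(self, name='DummyCriticNet'):
    super(DummyCriticNet, self).__init__(
        input_tensor_spec=(tf.TensorSpec([2], tf.float32),
                           tf.TensorSpec([1], tf.float32)),
        state_spec=(),
        name=name)
    # Assumed layer definitions: ones-initialized Dense layers preserve the
    # "bigger features give bigger values" behavior of the old reduce_max code.
    self._obs_layer = tf.keras.layers.Dense(
        1, use_bias=False, kernel_initializer='ones')
    self._action_layer = tf.keras.layers.Dense(
        1, use_bias=False, kernel_initializer='ones')

  def call(self, inputs, step_type=None, network_state=()):
    observation, actions = inputs
    states = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
    actions = tf.cast(tf.nest.flatten(actions)[0], tf.float32)
    s_value = self._obs_layer(states)
    a_value = self._action_layer(actions)
    # Biggest state is best state.
    q_value = tf.reshape(s_value + a_value, [-1])
    return q_value, network_state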