# NOTE(review): `outer_rank` is bound before this excerpt (not visible here);
# presumably it is the count of leading batch dims of `observation`.
batch_squash = utils.BatchSquash(outer_rank)
# Get single observation out regardless of nesting: only the first leaf of
# the (possibly nested) observation is used, cast to float32.
states = tf.cast(nest.flatten(observation)[0], tf.float32)
# Flattening via BatchSquash is opt-in here, gated on `self._batch_squash`.
if self._batch_squash:
    states = batch_squash.flatten(states)
After Change
# Number of outer dims of `observation` relative to `self.input_tensor_spec`.
# NOTE(review): presumably the leading batch (and/or time) dims not described
# by the spec — confirm against nest_utils.get_outer_rank.
outer_rank = nest_utils.get_outer_rank(
observation, self.input_tensor_spec)
batch_squash = utils.BatchSquash(outer_rank)
# Apply batch_squash.flatten to every leaf of the observation nest (not just
# the first leaf), and unconditionally (no `self._batch_squash` check).
# NOTE(review): assumes flatten collapses the `outer_rank` leading dims into a
# single batch dim — confirm against utils.BatchSquash.
observation = nest.map_structure(batch_squash.flatten, observation)
if self._preprocessing_layers is None:
processed = observation
else:
processed = []
for obs, layer in zip(
nest.flatten_up_to(self.input_tensor_spec, observation),