split_x = tf.split(x, split_indices, axis=1)
// Structured or flattened (by single action component) input.
else:
split_x = tree.flatten(x)
def map_(val, dist):
    """Score ``val`` under the child distribution ``dist``.

    When ``dist`` is a ``Categorical``, the last axis of ``val`` is removed
    with ``tf.squeeze(..., axis=-1)``, and the result is cast to ``tf.int32``
    before being passed to ``dist.logp``. Any other distribution receives
    ``val`` unchanged.

    Args:
        val: The action component matching ``dist``.
        dist: One child distribution of the multi-action distribution.

    Returns:
        Whatever ``dist.logp`` returns for the (possibly squeezed) input.
    """
    if not isinstance(dist, Categorical):
        return dist.logp(val)
    # Remove extra categorical dimension before scoring integer class ids.
    squeezed = tf.squeeze(val, axis=-1)
    return dist.logp(tf.cast(squeezed, tf.int32))
// Remove extra categorical dimension and take the logp of each
// component.
flat_logps = tree.map_structure(map_, split_x,
self.flat_child_distributions)
return functools.reduce(lambda a, b: a + b, flat_logps)
@override(ActionDistribution)