"global_avg"],
default="global_avg")
if reduction_type == "flatten":
output_node = Flatten().build(hp, output_node)
elif reduction_type == "global_max":
output_node = layer_utils.get_global_max_pooling(
output_node.shape)()(output_node)
elif reduction_type == "global_avg":
output_node = layer_utils.get_global_average_pooling(
output_node.shape)()(output_node)
return output_node
class TemporalReduction(block_module.Block):
Reduce the dimension of a temporal tensor, e.g. output of RNN, to a vector.