Before Change
# Merge batch and beam dimensions.
original_shape = tf.shape(target_tokens)
target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
attention = tf.reshape(alignment, [-1, tf.shape(alignment)[2], tf.shape(alignment)[3]])
// We don"t have attention for </s> but ensure that the attention time dimension matches
// the tokens time dimension.
attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
After Change
# Merge batch and beam dimensions.
original_shape = tf.shape(target_tokens)
target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
align_shape = shape_list(alignment)
attention = tf.reshape(
alignment, [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]])
// We don"t have attention for </s> but ensure that the attention time dimension matches
// the tokens time dimension.
attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
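
The updated version calls a shape_list helper that is not defined in this snippet. A minimal sketch of such a helper, following the common TensorFlow idiom found in OpenNMT-tf and similar codebases (the exact implementation here is an assumption): it returns the static dimension wherever it is known at graph-construction time and falls back to the dynamic tf.shape value otherwise.

import tensorflow as tf

def shape_list(x):
    # Return the tensor's dimensions, preferring static values where
    # they are known and falling back to dynamic ones otherwise.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dim if dim is not None else dynamic[i]
            for i, dim in enumerate(static)]

With a helper like this, align_shape[0] * align_shape[1] stays a plain Python integer whenever both the batch and beam sizes are static, which indexing into tf.shape(alignment) alone cannot provide, so the reshape keeps more static shape information.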
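
reducer.align_in_time then makes the attention's time axis match the tokens' time axis, since the </s> position carries no attention weights. A hedged sketch of the pad-or-slice behavior such a function provides; this is an illustration, not the actual OpenNMT-tf implementation:

def align_in_time(x, length):
    # Pad the time axis (axis 1) with zeros when it is too short,
    # or slice it down when it is too long, so it equals `length`.
    time_dim = tf.shape(x)[1]
    return tf.cond(
        time_dim < length,
        lambda: tf.pad(x, [[0, 0], [0, length - time_dim], [0, 0]]),
        lambda: x[:, :length])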