lp_constant=params.decode_constant)
# Set the inputs back to the unexpanded inputs so as not to confuse the Estimator.
features["source"] = inputs_old
features["source_length"] = inputs_length_old
# Return `top_beams` decoding
After Change
normalized_scores = scores / output_length
if params.decode_normalize:
    scores, indices = tf.nn.top_k(normalized_scores, k=top_beams)
else:
    # Fallback assumed here so that `indices` is always defined:
    # without length normalization, select the top beams by raw score.
    scores, indices = tf.nn.top_k(scores, k=top_beams)
# shape of ids: [batch, beam_size, max_length]
# shape of coordinates: [batch, beam_size, 2]
batch_pos = compute_batch_indices(batch_size, beam_size)
coordinates = tf.stack([batch_pos, indices], axis=2)
ids = tf.gather_nd(ids, coordinates)
# Return `top_beams` decoding
# (also remove initial id from the beam search)
return ids[:, :top_beams, 1:], scores[:, :top_beams]
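For context, `compute_batch_indices` is called above but its body is not part of this change. A minimal sketch of how such a helper is typically written, assuming TF 1.x-style graph ops (the actual implementation in the repository may differ):

import tensorflow as tf

def compute_batch_indices(batch_size, beam_size):
    # Build a [batch_size, beam_size] tensor in which row i is filled with i,
    # so it can be stacked with the per-beam top_k indices to form the
    # [batch, beam, 2] coordinates expected by tf.gather_nd.
    batch_pos = tf.range(batch_size * beam_size) // beam_size
    return tf.reshape(batch_pos, [batch_size, beam_size])

Stacking this with the `top_k` indices along axis 2, as done above, gives each (batch, beam) pair the coordinate of its selected hypothesis in `ids`.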