// End condition is the top beam reached end_token.
finished = topk_ids[:, 0].eq(end_token)
finished_count = finished.sum()
// Save result of finished sentences.
if finished_count > 0 or step + 1 == self.max_length:
predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
After Change
end_condition = topk_ids[:, 0].eq(end_token)
if step + 1 == self.max_length:
end_condition.fill_(1)
finished = end_condition.nonzero().view(-1)
// Save result of finished sentences.
if len(finished) > 0:
predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
scores = topk_scores.view(-1, beam_size)
for i in finished: