# Number of reversed steps (per batch element) that still lie in padding
# before the real sequence begins; a step is "real" once i > rev_len.
# NOTE(review): assumes `seq_len` is the padded length and `lengths` holds
# per-element sequence lengths (likely a tensor) — confirm against caller.
rev_len = seq_len - lengths - 1
for i, backpointer_t in enumerate(reversed(backpointers)):
    # Get new best tag candidate: follow each row's backpointer from its
    # current best tag to the best tag at the previous step.
    new_best_tag_id = backpointer_t[batch_range, best_tag_id]
    # We are going backwards now: while i <= rev_len we have not yet
    # reached the real results (still in padding). Once i has passed the
    # flipped length, mask is True and the new candidate should be used.
    mask = (i > rev_len)
    # Keep the old tag where mask is False and take the new candidate where
    # mask is True; the two masked_fill terms are zero on complementary
    # positions, so their sum selects exactly one of them per element.
    best_tag_id = best_tag_id.masked_fill(mask, 0) + new_best_tag_id.masked_fill(mask == 0, 0)
    best_path.append(best_tag_id)
# After change:
best_path = [best_tag_id]
for i, backpointer_t in enumerate(reversed(backpointers)):
    # Get new best tag candidate: for each row r, read
    # backpointer_t[r, best_tag_id[r]] via gather along dim 1.
    # NOTE(review): assumes backpointer_t is 2-D (batch, num_tags) and
    # best_tag_id is 1-D (batch,) — confirm.
    new_best_tag_id = backpointer_t.gather(1, best_tag_id.unsqueeze(1)).squeeze(1)
    # We are going backwards now: while i <= rev_len we are not yet in the
    # real results (still in padding). Once i has passed the flipped
    # length, mask is True and the new candidate should be used.
    mask = (i > rev_len)
    # Keep the old tag where mask is False and take the new candidate where
    # mask is True; the two masked_fill terms are zero on complementary
    # positions, so their sum selects exactly one of them per element.
    best_tag_id = best_tag_id.masked_fill(mask, 0) + new_best_tag_id.masked_fill(mask == 0, 0)
    # NOTE(review): no append to best_path is visible in this span;
    # presumably it follows here — confirm the loop body continues.