import torch
import torch.nn.functional as F

# Free variables assumed from context: unary holds emission scores shaped
# [T, B, N]; trans is the transition tensor, apparently [1, N, N] with
# trans[0, i, j] scoring the move from tag j to tag i (implied by the
# broadcasting below); lengths is a [B] tensor of true sequence lengths.
seq_len, batch_size, tag_size = unary.size()
min_length = torch.min(lengths)
batch_range = torch.arange(batch_size, dtype=torch.int64)
backpointers = []
# Alphas: [B, 1, N]
alphas = torch.Tensor(batch_size, 1, tag_size).fill_(-1e4).to(unary.device)
alphas[:, 0, start_idx] = 0
alphas = F.log_softmax(alphas, dim=-1) if norm else alphas
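# The -1e4 fill acts as an effective -inf, forcing every decoded path to begin
# at start_idx. When norm is set, log_softmax renormalizes these initial
# scores, presumably to keep path scores comparable with a normalized
# forward pass.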
for i, unary_t in enumerate(unary):
    # Broadcast [B, 1, N] + [1, N, N] -> [B, N, N]; entry [b, i, j] is the
    # score of moving from tag j to tag i at this step
    next_tag_var = alphas + trans
    # Best previous tag (and its score) for each current tag
    viterbi, best_tag_ids = torch.max(next_tag_var, 2)
    backpointers.append(best_tag_ids.data)
    new_alphas = viterbi + unary_t
    new_alphas.unsqueeze_(1)
    if i >= min_length:
        # Sequences that have already ended keep their old alphas;
        # the rest take the freshly computed ones
        mask = (i < lengths).view(-1, 1, 1)
        alphas = alphas.masked_fill(mask, 0) + new_alphas.masked_fill(mask == 0, 0)
    else:
        alphas = new_alphas
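# Note: x.masked_fill(mask, 0) + y.masked_fill(mask == 0, 0) is an
# element-wise select, equivalent to torch.where(mask, y, x); the same
# idiom reappears in the backward pass below.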
# Add the end-tag transition scores
terminal_var = alphas.squeeze(1) + trans[:, end_idx]
_, best_tag_id = torch.max(terminal_var, 1)
path_score = terminal_var[batch_range, best_tag_id]  # Best terminal score for each batch element
best_path = [best_tag_id]
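# best_path collects tags newest-first; the backward walk below appends
# earlier tags, so the finished list is in reverse time order.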
# Flip lengths: rev_len counts how many reversed steps are padding per sequence
rev_len = seq_len - lengths - 1
for i, backpointer_t in enumerate(reversed(backpointers)):
    # Get the new best tag candidate
    new_best_tag_id = backpointer_t[batch_range, best_tag_id]
    # Walking backwards: until we pass a sequence's flipped length we are still
    # in its padding, so hold the current tag instead of following the backpointer
    mask = (i > rev_len)
    best_tag_id = best_tag_id.masked_fill(mask, 0) + new_best_tag_id.masked_fill(mask == 0, 0)
    best_path.append(best_tag_id)
After Change (the backward pass is identical except that the advanced-indexing backpointer lookup is replaced with gather):
# Flip lengths: rev_len counts how many reversed steps are padding per sequence
rev_len = seq_len - lengths - 1
for i, backpointer_t in enumerate(reversed(backpointers)):
    # Get the new best tag candidate, now looked up with gather
    new_best_tag_id = backpointer_t.gather(1, best_tag_id.unsqueeze(1)).squeeze(1)
    # Walking backwards: until we pass a sequence's flipped length we are still
    # in its padding, so hold the current tag instead of following the backpointer
    mask = (i > rev_len)
    best_tag_id = best_tag_id.masked_fill(mask, 0) + new_best_tag_id.masked_fill(mask == 0, 0)
    best_path.append(best_tag_id)
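The only behavioral difference between the two versions is how each batch element's backpointer entry is looked up. Below is a minimal, self-contained sketch (shapes and names here are illustrative, not from the original) showing that the two lookups agree; a likely motivation for the change is that gather avoids building the auxiliary batch_range index and is generally friendlier to tracing and export than advanced indexing.

import torch

B, N = 4, 5  # hypothetical batch size and tag-set size
backpointer_t = torch.randint(0, N, (B, N))  # one backpointer row per batch element
best_tag_id = torch.randint(0, N, (B,))      # current best tag per batch element
batch_range = torch.arange(B, dtype=torch.int64)

# Before: advanced indexing with an explicit per-batch row index
via_index = backpointer_t[batch_range, best_tag_id]

# After: the same per-row lookup expressed with gather
via_gather = backpointer_t.gather(1, best_tag_id.unsqueeze(1)).squeeze(1)

assert torch.equal(via_index, via_gather)

Note that best_path still ends newest-first after either loop; the excerpt stops before the (presumed) finalization step that pops the start tag, reverses the list, and stacks it into a [T, B] tensor.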