# Take the keep_topk largest logits along dim 1; the last column of
# top_values is the k-th best score, used as the inclusion threshold.
# NOTE(review): the view/repeat/compare below assume logits is 2-D
# (presumably batch x vocab) — confirm against the caller.
top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
kth_best = top_values[:, -1].view([-1, 1])
# Tile the per-row threshold across every column so it matches logits.
# .float() converts dtype but keeps the tensor on its current device; the
# former .type(torch.cuda.FloatTensor) forced a GPU copy, which fails on
# CPU-only runs or mismatches devices with CPU logits in torch.ge below.
kth_best = kth_best.repeat([1, logits.shape[1]]).float()
# 0/1 float mask of logits >= the k-th best. Ties with the k-th best are kept,
# so a row may keep more than keep_topk entries. Presumably the following code
# uses this mask to set all other logits to -1000, which puts their
# probabilities close to 0.
keep = torch.ge(logits, kth_best).float()
# --- After change ---
# Only apply top-k filtering when a positive k is requested.
if keep_topk > 0:
    # keep_topk largest logits along dim 1; the last column of top_values is
    # the k-th best score, used as the inclusion threshold.
    # NOTE(review): the view/repeat/compare below assume logits is 2-D
    # (presumably batch x vocab) — confirm against the caller.
    top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
    kth_best = top_values[:, -1].view([-1, 1])
    # Tile the threshold across all columns; .float() keeps the device.
    kth_best = kth_best.repeat([1, logits.shape[1]]).float()
    # 0/1 float mask of logits >= the k-th best (ties are kept, so a row may
    # keep more than keep_topk entries). Presumably the following code uses it
    # to set all other logits to -1000, putting their probabilities close to 0.
    keep = torch.ge(logits, kth_best).float()