# Turn example i's ranked candidate indices (`ordering`) into a lazy stream
# of candidate objects drawn from that example's candidate list.
# NOTE(review): assumes `ordering` holds candidate indices sorted by score
# for example i — confirm against the caller.
cand_list = cands[i]
if len(ordering) != len(cand_list):
    # Ignore padding: the ordering's length differs from this example's
    # candidate count, so drop any index that does not point into cand_list.
    true_ordering = [x for x in ordering if x < len(cand_list)]
    ordering = true_ordering
# Using a generator instead of a list comprehension allows the consumer
# to cap the number of elements it actually materializes.
# NOTE(review): `ordering` is evaluated immediately, but `cand_list` is looked
# up only when the generator is consumed; consume it before `cand_list` is
# rebound (e.g. on the next loop iteration).
cand_preds_generator = (cand_list[rank] for rank in ordering)
After Change
# Accumulate per-example ranking metrics: for each example in the batch,
# locate its gold label index within the ranked list and add the 1-based
# position to the "rank" total and its reciprocal to the "mrr" total.
self.metrics["examples"] += batchsize
for b in range(batchsize):
    # All positions at which this example's ranking equals its label index.
    rank = (ranks[b] == label_inds[b]).nonzero()
    if len(rank) != 1:
        # No unique match: fall back to scores.size(1) (presumably the
        # number of scored candidates — confirm), i.e. rank the label last.
        rank = scores.size(1)
    else:
        rank = rank.item()
    self.metrics["rank"] += rank + 1
    self.metrics["mrr"] += 1.0 / (rank + 1)