// The .prod() here is functioning as a logical and.
correct = predictions.eq(gold_labels).prod(dim=1).float()
count = torch.ones(gold_labels.size(0))
self._correct_count += correct.sum()
self._total_count += count.sum()
def get_metric(self, reset: bool = False):
After Change
// We want to skip predictions that are completely masked,
// so we'll keep only the predictions that aren't.
keep = mask.view(batch_size, -1).max(dim=1)[0].float()
else:
keep = torch.ones(batch_size).float()
predictions = predictions.view(batch_size, -1)