// If you actually passed in Variables here instead of Tensors, this will be a huge memory
// leak, because it will prevent garbage collection for the computation graph. We'll ensure
// that we're using tensors here first.
if isinstance(predictions, Variable):
predictions = predictions.data
if isinstance(gold_labels, Variable):
gold_labels = gold_labels.data
if isinstance(mask, Variable):
mask = mask.data
if mask is not None:
// We can multiply by the mask up front, because we're just checking equality below, and
// this way every masked position becomes 0 in both tensors and therefore compares equal.
predictions = predictions * mask
gold_labels = gold_labels * mask
batch_size = predictions.size(0)
predictions = predictions.view(batch_size, -1)
gold_labels = gold_labels.view(batch_size, -1)
// The .prod() here is functioning as a logical and.
correct = predictions.eq(gold_labels).prod(dim=1).float()
count = torch.ones(gold_labels.size(0))
self._correct_count += correct.sum()
self._total_count += count.sum()
def get_metric(self, reset: bool = False):