44d2847610944f56a06b7cfa54faadb66e130a83,allennlp/training/metrics/categorical_accuracy.py,CategoricalAccuracy,__call__,#CategoricalAccuracy#Any#Any#Any#,31
Before Change
if self._top_k == 1:
top_k = predictions.max(-1)[1].unsqueeze(-1)
else:
top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
// This is of shape (batch_size, ..., top_k).
correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
else:
After Change
self.total_count += gold_labels.numel()
self.correct_count += correct.sum()
if is_distributed():
_correct_count = torch.tensor(self.correct_count).to(device)
_total_count = torch.tensor(self.total_count).to(device)
dist.all_reduce(_correct_count, op=dist.ReduceOp.SUM)
dist.all_reduce(_total_count, op=dist.ReduceOp.SUM)
self.correct_count = _correct_count.item()
self.total_count = _total_count.item()
def get_metric(self, reset: bool = False):
// Returns
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 3
Instances
Project Name: allenai/allennlp
Commit Name: 44d2847610944f56a06b7cfa54faadb66e130a83
Time: 2020-08-12
Author: akshita23bhagia@gmail.com
File Name: allennlp/training/metrics/categorical_accuracy.py
Class Name: CategoricalAccuracy
Method Name: __call__
Project Name: open-mmlab/mmdetection
Commit Name: cd0d37ccb398b7e80dd7af06925246e133d8178c
Time: 2019-12-18
Author: yhcao6@gmail.com
File Name: mmdet/apis/train.py
Class Name:
Method Name: parse_losses
Project Name: mapillary/inplace_abn
Commit Name: e6bbf54046cf4567e88cb130300b6b78ec88cb27
Time: 2018-11-28
Author: samuel@mapillary.com
File Name: test_imagenet.py
Class Name:
Method Name: validate