for label, pred in zip(labels, preds):
pred = pred.max(dim=self.axis)[1]
label = label.max(dim=self.axis)[1]
self.sum_metric += pred.eq(label).sum()
self.num_inst += pred.numel()
def reset(self):
After Change
check_label_shapes(labels, preds)
if self.on_cpu:
if self.sparse_label:
label_imask = labels.cpu().numpy().astype(np.int32)
else:
label_imask = torch.argmax(labels, dim=self.axis).cpu().numpy().astype(np.int32)
pred_imask = torch.argmax(preds, dim=self.axis).cpu().numpy().astype(np.int32)