torch.Tensor: a tensor with the confusion matrix with shape
:math:`(N, C, C)` where C is the number of classes.
if not torch.is_tensor(y_true):
raise TypeError("Input y_true type is not a torch.Tensor. Got {}"
.format(type(y_true)))
if not torch.is_tensor(y_pred):
raise TypeError("Input y_pred type is not a torch.Tensor. Got {}"
.format(type(y_pred)))
if not y_true.shape == y_pred.shape:
raise ValueError("Inputs y_true and y_pred must have the same shape. "
"Got: {}".format(y_true.shape, y_pred.shape))
if not y_true.device == y_pred.device:
raise ValueError("Inputs must be in the same device. "
"Got: {} - {}".format(y_true.device, y_pred.device))
if not isinstance(num_classes, int) or num_classes < 2:
raise ValueError("The number of classes must be an intenger bigger "
"than two. Got: {}".format(num_classes))
batch_size: int = y_true.shape[0]
y_true_vec: torch.Tensor = y_true.view(batch_size, -1)
y_pred_vec: torch.Tensor = y_pred.view(batch_size, -1)
// NOTE: torch.bincount does not implement batched version
pre_bincount: torch.Tensor = y_true_vec * num_classes + y_pred_vec
confusion_vec: torch.Tensor = torch.stack([
torch.bincount(pb) for pb in pre_bincount
])
After Change
torch.Tensor: a tensor with the confusion matrix with shape
:math:`(B, K, K)` where K is the number of classes.
if not isinstance(input, torch.LongTensor):
raise TypeError("Input input type is not a torch.LongTensor. Got {}"
.format(type(input)))
if not isinstance(target, torch.LongTensor):
raise TypeError("Input target type is not a torch.LongTensor. Got {}"
.format(type(target)))
if not input.shape == target.shape:
raise ValueError("Inputs input and target must have the same shape. "
"Got: {}".format(input.shape, target.shape))
if not input.device == target.device:
raise ValueError("Inputs must be in the same device. "
"Got: {} - {}".format(input.device, target.device))
if not isinstance(num_classes, int) or num_classes < 2:
raise ValueError("The number of classes must be an intenger bigger "
"than two. Got: {}".format(num_classes))
batch_size: int = input.shape[0]
// hack for bincounting two arrays together: encode each (target, input) pair as one index
// NOTE: torch.bincount does not implement batched version
pre_bincount: torch.Tensor = input + target * num_classes
pre_bincount_vec: torch.Tensor = pre_bincount.view(batch_size, -1)
confusion_vec: torch.Tensor = torch.stack([
torch.bincount(pb, minlength=num_classes**2) for pb in pre_bincount_vec