:return: A tensor of shape ``truth.shape``.
:rtype: torch.Tensor
if pred.dim() != 1 + truth.dim() or pred.shape[1:] != truth.shape:
    raise ValueError("Expected pred to have one extra sample dim on left. "
                     "Actual shapes: {} versus {}".format(pred.shape, truth.shape))
opts = dict(device=pred.device, dtype=pred.dtype)
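The pre-change check accepts `pred` only when it has exactly one more dimension than `truth` and every dimension after the leading sample dim matches `truth.shape` exactly. The sketch below models that rule on plain tuples, since `torch.Size` is a `tuple` subclass. `old_shape_ok` is a hypothetical helper written for illustration and is not part of the original code.

```python
def old_shape_ok(pred_shape, truth_shape):
    """Hypothetical mirror of the pre-change shape check, using plain tuples."""
    pred_shape, truth_shape = tuple(pred_shape), tuple(truth_shape)
    # pred must have exactly one extra (sample) dim compared to truth...
    if len(pred_shape) != 1 + len(truth_shape):
        return False
    # ...and all dims after the sample dim must equal truth's shape exactly.
    return pred_shape[1:] == truth_shape


print(old_shape_ok((100, 3, 4), (3, 4)))     # True
print(old_shape_ok((100, 1, 3, 4), (3, 4)))  # False: the extra singleton dim is rejected
```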
After Change
:return: A tensor of shape ``truth.shape``.
:rtype: torch.Tensor
if pred.shape[1:] != (1,) * (pred.dim() - truth.dim() - 1) + truth.shape:
    raise ValueError("Expected pred to have one extra sample dim on left. "
                     "Actual shapes: {} versus {}".format(pred.shape, truth.shape))
opts = dict(device=pred.device, dtype=pred.dtype)
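After the change, the check left-pads `truth.shape` with `pred.dim() - truth.dim() - 1` singleton dims before comparing. As a result, `pred` may take the form `(num_samples, 1, ..., 1, *truth.shape)`. The sketch below reproduces that comparison on plain tuples. `new_shape_ok` is a hypothetical helper for illustration only.

```python
def new_shape_ok(pred_shape, truth_shape):
    """Hypothetical mirror of the post-change shape check, using plain tuples."""
    pred_shape, truth_shape = tuple(pred_shape), tuple(truth_shape)
    # Number of singleton dims allowed between the sample dim and truth's dims.
    # If this is negative, (1,) * num_pad is just () and no padding is added.
    num_pad = len(pred_shape) - len(truth_shape) - 1
    return pred_shape[1:] == (1,) * num_pad + truth_shape


print(new_shape_ok((100, 3, 4), (3, 4)))     # True: same as before the change
print(new_shape_ok((100, 1, 3, 4), (3, 4)))  # True: singleton padding is now accepted
print(new_shape_ok((100, 2, 3, 4), (3, 4)))  # False: the extra dim must have size 1
```

The `ValueError` message is identical in both versions. It therefore still describes only the single-extra-dim case, even though padded shapes now pass the check.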