>>> output.shape
torch.Size([1, 3, 4, 4])
if not isinstance(input, torch.Tensor):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(input)))
if not len(input.shape) == 4:
raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
.format(input.shape))
// comput the x/y gradients
edges: torch.Tensor = spatial_gradient(input, normalized=normalized)
// unpack the edges
gx: torch.Tensor = edges[:, :, 0]
gy: torch.Tensor = edges[:, :, 1]
// compute gradient maginitude
magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
return magnitude
class SpatialGradient(nn.Module):
rComputes the first order image derivative in both x and y using a Sobel