// prepare kernel
b, c, h, w = input.shape
tmp_kernel: torch.Tensor = self.kernel.to(input.device).to(input.dtype)
kernel: torch.Tensor = tmp_kernel.repeat(c, 1, 1, 1, 1)
// convolve input tensor with sobel kernel
kernel_flip: torch.Tensor = kernel.flip(-3)
// Pad with "replicate" for spatial dims, but with zeros for channel
After Change
// prepare kernel
b, c, h, w = input.shape
tmp_kernel: torch.Tensor = self.kernel.to(input.device).to(input.dtype)
kernel: torch.Tensor = tmp_kernel.unsqueeze(1).unsqueeze(1)
// convolve input tensor with sobel kernel
kernel_flip: torch.Tensor = kernel.flip(-3)
// Pad with "replicate" for spatial dims, but with zeros for channel