// repeat coords along channel dim if not given that way
if coords.dim() == 2:
coords = coords.unsqueeze(0).repeat(input.size(0),1,1)
// take clamp of coords so they"re in the image bounds
xc = torch.clamp(coords[:,:,0], 0, input.size(1)-1)
yc = torch.clamp(coords[:,:,1], 0, input.size(2)-1)
zc = torch.clamp(coords[:,:,2], 0, input.size(3)-1)
// round to nearest coordinate
coords = torch.stack([xc.round().long(),
yc.round().long(),
zc.round().long()], 2)
// gather image values at coordinates
mapped_vals = torch.stack([th_gather_nd(input[i], coords[i])
for i in range(input.size(0))], 0)
return mapped_vals.view_as(input)