grid_dim = grid.size(-1)
grid_data = torch.zeros(int(pow(grid_size, grid_dim)), grid_dim, device=grid.device)
prev_points = None
for i in range(grid_dim):
for j in range(grid_size):
grid_data[j * grid_size ** i : (j + 1) * grid_size ** i, i].fill_(grid[j, i])
if prev_points is not None:
After Change
Returns the set of points on the grid, enumerated in column-major order
(for legacy reasons).
if torch.is_tensor(grid):
grid = convert_legacy_grid(grid)
ndims = len(grid)
assert all(axis.dim() == 1 for axis in grid)
projections = torch.meshgrid(*grid)
grid_tensor = torch.stack(projections, axis=-1)