tiled_cells = tf.tile(cells, (n_cells, 1))
// Lists of n_cells tensors of shape (N, 1)
tiled_centers = tf.split(tiled_centers, n_cells)
tiled_cells = tf.split(tiled_cells, n_cells)
// Lists of length n_cells
coords_rel = [
tf.to_float(cells) - tf.to_float(centers)
for (cells, centers) in zip(tiled_centers, tiled_cells)
]
coords_norm = [tf.reduce_sum(rel**2, axis=1) for rel in coords_rel]
// Lists of length n_cells
// Get indices of n_nbrs atoms closest to each cell point
// n_cells tensors of shape (n_nbrs,)
closest_inds = tf.stack([tf.nn.top_k(norm, k=n_nbrs)[1] for norm in coords_norm])
return closest_inds