# Choose the side length of the sprite grid: the smallest integer whose
# square is at least the number of images, so every image gets a cell.
base_size = int(math.ceil(label_img.size(0) ** 0.5))
# Pad the batch with all-zero images until it holds exactly base_size ** 2
# entries. new_zeros creates the padding with the same dtype and device as
# label_img, so torch.cat does not fail (or silently change dtype) when
# label_img is not a float32 CPU tensor.
label_img = torch.cat(
    (
        label_img,
        label_img.new_zeros(
            (base_size ** 2 - label_img.size(0), *label_img.size()[1:])
        ),
    ),
    0,
)
# save_image calls make_grid internally; with nrow=base_size and a batch of
# base_size ** 2 images, the resulting grid is guaranteed to be square.
torchvision.utils.save_image(label_img, os.path.join(save_path, "sprite.png"), nrow=base_size, padding=0)