import math
from typing import List

from PIL import Image


def get_image_grid(images: List[Image.Image]) -> Image.Image:
    """Tile *images* into a single, roughly square RGB grid image.

    Images are placed left-to-right, top-to-bottom in the order given.
    The grid has ``ceil(sqrt(n))`` columns and as many rows as are needed
    to hold all ``n`` images.

    Every cell has the size of the first image. Other images are pasted
    at their cell's top-left corner without resizing, so the result is
    only a clean grid when all images share the first image's size.
    Trailing cells of the last row that receive no image are left as
    ``Image.new`` initialised them.

    Args:
        images: Non-empty list of PIL images to arrange.

    Returns:
        A new ``'RGB'`` image of size ``(cols * width, rows * height)``.

    Raises:
        ValueError: If *images* is empty.
    """
    if not images:
        # Without this guard cols would be 0 and the row computation would
        # fail with an unhelpful ZeroDivisionError.
        raise ValueError("images must contain at least one image")

    num_images = len(images)
    # math.ceil already returns an int on Python 3.
    cols = math.ceil(math.sqrt(num_images))
    rows = math.ceil(num_images / cols)

    # The first image dictates the size of every grid cell.
    width, height = images[0].size
    grid_image = Image.new('RGB', (cols * width, rows * height))

    for index, img in enumerate(images):
        row, col = divmod(index, cols)
        grid_image.paste(img, (col * width, row * height))

    return grid_image