import math | |
from typing import List | |
from PIL import Image | |
def get_image_grid(images: List[Image.Image]) -> Image.Image:
    """Tile ``images`` row-major into a near-square RGB grid.

    The grid has ``ceil(sqrt(n))`` columns and just enough rows to hold all
    ``n`` images. Every cell is the size of ``images[0]``, and each image is
    pasted, without resizing, at the top-left corner of its cell.

    Args:
        images: Non-empty list of images, presumably all the same size.
            The cell size comes from the first image only. A larger image
            spills into the neighbouring cells and may be overwritten by later
            pastes. A smaller image leaves part of its cell unfilled.

    Returns:
        A new ``"RGB"`` image of size ``(cols * width, rows * height)``.
        Cells after the last image keep ``Image.new``'s default fill.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        # Without this guard, cols would be 0 and the row computation would
        # fail with an opaque ZeroDivisionError.
        raise ValueError("get_image_grid() requires at least one image")

    num_images = len(images)
    cols = int(math.ceil(math.sqrt(num_images)))
    # Integer ceiling division: exact, with no float rounding.
    rows = -(-num_images // cols)
    width, height = images[0].size

    grid_image = Image.new("RGB", (cols * width, rows * height))
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        grid_image.paste(img, (col * width, row * height))
    return grid_image