from __future__ import annotations

import json
import math
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset


class EditDataset(Dataset):
    """Dataset of (source image, edited image, edit prompt) triples.

    The dataset root ``path`` must contain a ``seeds.json`` file holding a
    sequence of ``(name, seeds)`` pairs.  For every pair, the directory
    ``path/name`` must contain ``prompt.json`` (an object with an ``"edit"``
    key) and, for each seed, the images ``{seed}_0.jpg`` and ``{seed}_1.jpg``.

    Args:
        path: Root directory of the dataset.
        split: Which partition to expose: ``"train"``, ``"val"`` or ``"test"``.
        splits: Fractions of ``seeds.json`` assigned to train / val / test,
            taken in file order.  Must sum to 1 (within floating-point
            tolerance).
        min_resize_res: Smallest side length (inclusive) an image pair may be
            resized to.
        max_resize_res: Largest side length (inclusive) an image pair may be
            resized to.
        crop_res: Side length of the random square crop taken after resizing.
        flip_prob: Probability of horizontally flipping the image pair.

    Raises:
        ValueError: If ``split`` is unknown, ``splits`` does not sum to 1,
            ``min_resize_res > max_resize_res`` or
            ``crop_res > min_resize_res``.
    """

    def __init__(
        self,
        path: str,
        split: str = "train",
        splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
    ):
        # Explicit exceptions instead of `assert`: asserts vanish under
        # `python -O`, which would let invalid arguments through silently.
        if split not in ("train", "val", "test"):
            raise ValueError(
                f"split must be 'train', 'val' or 'test', got {split!r}"
            )
        # Exact float equality rejects valid fractions such as
        # (0.7, 0.2, 0.1), whose float sum is 0.9999999999999999.
        if not math.isclose(sum(splits), 1.0):
            raise ValueError(f"splits must sum to 1, got {splits!r}")
        # Fail fast: these would otherwise only surface (possibly
        # intermittently) inside __getitem__.
        if min_resize_res > max_resize_res:
            raise ValueError(
                f"min_resize_res ({min_resize_res}) must not exceed "
                f"max_resize_res ({max_resize_res})"
            )
        if crop_res > min_resize_res:
            raise ValueError(
                f"crop_res ({crop_res}) must not exceed "
                f"min_resize_res ({min_resize_res})"
            )

        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob

        with open(Path(self.path, "seeds.json"), encoding="utf-8") as f:
            self.seeds = json.load(f)

        # Fractional [start, end) window of seeds.json owned by each split.
        split_0, split_1 = {
            "train": (0.0, splits[0]),
            "val": (splits[0], splits[0] + splits[1]),
            "test": (splits[0] + splits[1], 1.0),
        }[split]

        idx_0 = math.floor(split_0 * len(self.seeds))
        idx_1 = math.floor(split_1 * len(self.seeds))
        self.seeds = self.seeds[idx_0:idx_1]

    def __len__(self) -> int:
        """Return the number of prompt entries in the selected split."""
        return len(self.seeds)

    @staticmethod
    def _load_image(image_path: Path, res: int) -> torch.Tensor:
        """Load an image, resize it to ``res`` x ``res``, return a ``c h w`` tensor.

        Pixel values are mapped with ``2 * x / 255 - 1`` (i.e. to [-1, 1] for
        8-bit images).
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been read by `resize`.
        with Image.open(image_path) as image:
            resized = image.resize((res, res), Image.Resampling.LANCZOS)
        pixels = torch.tensor(np.array(resized)).float()
        return rearrange(2 * pixels / 255 - 1, "h w c -> c h w")

    def __getitem__(self, i: int) -> dict[str, Any]:
        """Return one randomly augmented training example for entry ``i``.

        A seed is drawn uniformly at random from the entry's seed list, so
        repeated calls with the same ``i`` may return different image pairs.

        Returns:
            A dict with ``edited`` (the ``{seed}_1.jpg`` image tensor) and
            ``edit``, itself a dict with ``c_concat`` (the ``{seed}_0.jpg``
            image tensor) and ``c_crossattn`` (the edit prompt read from
            ``prompt.json``).
        """
        name, seeds = self.seeds[i]
        prompt_dir = Path(self.path, name)
        seed = seeds[torch.randint(0, len(seeds), ()).item()]

        with open(prompt_dir.joinpath("prompt.json"), encoding="utf-8") as fp:
            prompt = json.load(fp)["edit"]

        # One resolution is sampled per example and shared by both images.
        resize_res = torch.randint(
            self.min_resize_res, self.max_resize_res + 1, ()
        ).item()
        image_0 = self._load_image(prompt_dir.joinpath(f"{seed}_0.jpg"), resize_res)
        image_1 = self._load_image(prompt_dir.joinpath(f"{seed}_1.jpg"), resize_res)

        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        # Stack both images along the channel axis so a single random crop
        # and flip is applied identically to both, then split them again.
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)

        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))