import os
import random

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
def get_nonorm_transform(resolution):
    """Resize to the target resolution and convert to a tensor, without normalization."""
    nonorm_transform = transforms.Compose(
        [transforms.Resize((resolution, resolution),
                           interpolation=transforms.InterpolationMode.BILINEAR),
         transforms.ToTensor()])
    return nonorm_transform
class FontDataset(Dataset):
    """Dataset for font generation.

    Each sample pairs a content image, a randomly chosen style reference from the
    same font, and the target image that combines that content and style.
    """

    def __init__(self, args, phase, transforms=None):
        super().__init__()
        self.root = args.data_root
        self.phase = phase
        # Collect target image paths and group them by style.
        self.get_path()
        self.transforms = transforms
        self.nonorm_transforms = get_nonorm_transform(args.resolution)
    def get_path(self):
        self.target_images = []
        # Map each style to the list of target images rendered in that style.
        self.style_to_images = {}
        target_image_dir = f"{self.root}/{self.phase}/TargetImage"
        for style in os.listdir(target_image_dir):
            images_related_style = []
            for img in os.listdir(f"{target_image_dir}/{style}"):
                img_path = f"{target_image_dir}/{style}/{img}"
                self.target_images.append(img_path)
                images_related_style.append(img_path)
            self.style_to_images[style] = images_related_style
    def __getitem__(self, index):
        target_image_path = self.target_images[index]
        target_image_name = target_image_path.split('/')[-1]
        # Target filenames follow the pattern "{style}+{content}.jpg".
        style, content = target_image_name.split('.')[0].split('+')

        # Read the content image.
        content_image_path = f"{self.root}/{self.phase}/ContentImage/{content}.jpg"
        content_image = Image.open(content_image_path).convert('RGB')

        # Randomly sample a style reference image from the same style,
        # excluding the target image itself.
        images_related_style = self.style_to_images[style].copy()
        images_related_style.remove(target_image_path)
        style_image_path = random.choice(images_related_style)
        style_image = Image.open(style_image_path).convert("RGB")

        # Read the target image and keep an un-normalized tensor copy.
        target_image = Image.open(target_image_path).convert("RGB")
        nonorm_target_image = self.nonorm_transforms(target_image)

        if self.transforms is not None:
            content_image = self.transforms[0](content_image)
            style_image = self.transforms[1](style_image)
            target_image = self.transforms[2](target_image)

        return content_image, style_image, target_image, nonorm_target_image, target_image_path
    def __len__(self):
        return len(self.target_images)
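

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how FontDataset might be constructed and iterated.
# It assumes `args` is any object exposing `data_root` and `resolution`, that the
# data directory follows the layout read by `get_path`, and that `transforms` is a
# list of three callables applied to the content, style, and target images.
# The path and resolution below are placeholders, not values from the original code.
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch.utils.data import DataLoader

    args = SimpleNamespace(data_root="path/to/data", resolution=96)
    norm_transform = transforms.Compose(
        [transforms.Resize((args.resolution, args.resolution),
                           interpolation=transforms.InterpolationMode.BILINEAR),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # The same transform is reused here for the content, style, and target images.
    dataset = FontDataset(args, phase="train",
                          transforms=[norm_transform, norm_transform, norm_transform])
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    content, style, target, nonorm_target, target_paths = next(iter(loader))
    print(content.shape, style.shape, target.shape, nonorm_target.shape)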