# FontDiffuser-Gradio / dataset / font_dataset.py
# Author: yeungchenwa
# Commit: [Update] Add files and checkpoint (508b842)
# File size: 2.66 kB
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
def get_nonorm_transform(resolution):
    """Build the preprocessing pipeline that does not normalize pixel values.

    The pipeline resizes the input to a ``resolution`` x ``resolution``
    square using bilinear interpolation and then applies ``ToTensor``.
    No ``Normalize`` step is appended.

    Args:
        resolution: Edge length, in pixels, of the square output image.

    Returns:
        A ``transforms.Compose`` holding the resize and to-tensor steps.
    """
    square_size = (resolution, resolution)
    steps = [
        transforms.Resize(
            square_size,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
class FontDataset(Dataset):
    """Dataset of (content, style, target) glyph images for font generation.

    Directory layout read by this class::

        {data_root}/{phase}/TargetImage/{style}/{style}+{content}.<ext>
        {data_root}/{phase}/ContentImage/{content}.jpg

    Every file found under ``TargetImage`` is one sample. For each sample
    the matching content image is loaded, and a style reference is drawn at
    random from the other target images of the same style.
    """

    def __init__(self, args, phase, transforms=None):
        """Index the target images and set up the transforms.

        Args:
            args: Object exposing ``data_root`` (dataset root directory) and
                ``resolution`` (edge length used for the un-normalized
                target image).
            phase: Name of the sub-directory of ``data_root`` to read.
            transforms: Optional indexable of three callables applied to the
                content, style and target image respectively (indices 0, 1
                and 2). When ``None`` the PIL images are returned as loaded.
        """
        super().__init__()
        self.root = args.data_root
        self.phase = phase

        # Get Data path
        self.get_path()
        self.transforms = transforms
        self.nonorm_transforms = get_nonorm_transform(args.resolution)

    def get_path(self):
        """Scan ``TargetImage`` and populate the sample indexes.

        Sets:
            self.target_images: Flat list of all target image paths.
            self.style_to_images: Maps each style directory name to the list
                of target image paths inside that directory.
        """
        self.target_images = []
        # images with related style
        self.style_to_images = {}
        target_image_dir = f"{self.root}/{self.phase}/TargetImage"
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping identical across machines and runs.
        for style in sorted(os.listdir(target_image_dir)):
            style_dir = f"{target_image_dir}/{style}"
            if not os.path.isdir(style_dir):
                # Stray files next to the style folders are not styles.
                continue
            images_related_style = [
                f"{style_dir}/{img}" for img in sorted(os.listdir(style_dir))
            ]
            self.target_images.extend(images_related_style)
            self.style_to_images[style] = images_related_style

    @staticmethod
    def _parse_style_and_content(target_image_path):
        """Split a ``.../{style}+{content}.<ext>`` path into (style, content)."""
        target_image_name = os.path.basename(target_image_path)
        style, content = target_image_name.split('.')[0].split('+')
        return style, content

    def _pick_style_image_path(self, style, target_image_path):
        """Return a random image path of ``style`` other than the target.

        Falls back to the target itself when the style has no other image,
        instead of failing on an empty candidate list.
        """
        candidates = [
            path for path in self.style_to_images[style]
            if path != target_image_path
        ]
        if not candidates:
            return target_image_path
        return random.choice(candidates)

    @staticmethod
    def _load_rgb(image_path):
        """Load an image as RGB and close the underlying file right away."""
        # convert() returns a new, fully loaded image, so it stays valid
        # after the file handle is closed by the context manager.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            A 5-tuple ``(content_image, style_image, target_image,
            nonorm_target_image, target_image_path)``. The first three are
            passed through ``self.transforms`` when it is set;
            ``nonorm_target_image`` is always the target image passed
            through the un-normalized pipeline.
        """
        target_image_path = self.target_images[index]
        style, content = self._parse_style_and_content(target_image_path)

        # Read content image
        content_image_path = f"{self.root}/{self.phase}/ContentImage/{content}.jpg"
        content_image = self._load_rgb(content_image_path)

        # Random sample used for style image
        style_image_path = self._pick_style_image_path(style, target_image_path)
        style_image = self._load_rgb(style_image_path)

        # Read target image
        target_image = self._load_rgb(target_image_path)
        # Computed before the optional transforms so it always comes from
        # the raw PIL image.
        nonorm_target_image = self.nonorm_transforms(target_image)

        if self.transforms is not None:
            content_image = self.transforms[0](content_image)
            style_image = self.transforms[1](style_image)
            target_image = self.transforms[2](target_image)

        return content_image, style_image, target_image, nonorm_target_image, target_image_path

    def __len__(self):
        """Return the number of target images (one sample per target)."""
        return len(self.target_images)