# GCycleGAN / utils / dataloader.py
# Author: Egrt; commit 95e767b ("init"); original size 1.79 kB.
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from utils.utils import cvtColor, preprocess_input
class CycleGanDataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN-style training.

    Each annotation line is expected to look like ``"<label>;<image path>"``.
    Only the text after the first ``;`` is used, up to its first whitespace
    character.

    Item ``i`` pairs image ``i % len(A)`` from domain A with image
    ``i % len(B)`` from domain B. The smaller domain is therefore cycled so
    that one epoch visits every image of the larger domain.

    Args:
        annotation_lines_A: annotation lines for domain A.
        annotation_lines_B: annotation lines for domain B.
        input_shape: target size as ``(height, width)``; every image is
            resized to it.

    Raises:
        ValueError: if either domain has no annotation lines. Otherwise every
            item access would fail with ``ZeroDivisionError`` (``index % 0``).
    """

    def __init__(self, annotation_lines_A, annotation_lines_B, input_shape):
        super().__init__()
        self.annotation_lines_A = annotation_lines_A
        self.annotation_lines_B = annotation_lines_B
        self.length_A = len(self.annotation_lines_A)
        self.length_B = len(self.annotation_lines_B)
        # Fail early with a clear message instead of a ZeroDivisionError
        # deep inside a DataLoader worker.
        if self.length_A == 0 or self.length_B == 0:
            raise ValueError(
                "CycleGanDataset needs at least one image per domain "
                f"(got {self.length_A} for A and {self.length_B} for B)"
            )
        self.input_shape = input_shape

    def __len__(self):
        # One epoch covers every image of the larger domain once.
        return max(self.length_A, self.length_B)

    def __getitem__(self, index):
        """Return ``(image_A, image_B)`` as channels-first numpy arrays."""
        image_A = self._load_image(self.annotation_lines_A[index % self.length_A])
        image_B = self._load_image(self.annotation_lines_B[index % self.length_B])
        return image_A, image_B

    @staticmethod
    def _image_path(annotation_line):
        """Extract the image path from a ``"<label>;<path>"`` annotation line."""
        # maxsplit=1 keeps any further ';' as part of the path.
        # split()[0] drops the trailing newline and anything after the first
        # whitespace, so paths containing spaces are not supported.
        return annotation_line.split(';', 1)[1].split()[0]

    def _load_image(self, annotation_line):
        """Open, RGB-convert, resize and preprocess one image (HWC -> CHW)."""
        # PIL's resize takes (width, height); input_shape is (height, width).
        size = (self.input_shape[1], self.input_shape[0])
        # Close the file deterministically. resize() returns a new,
        # fully loaded image, so it remains valid after the source is closed.
        with Image.open(self._image_path(annotation_line)) as raw:
            image = cvtColor(raw).resize(size, Image.BICUBIC)
        image = np.array(image, dtype=np.float32)
        # NOTE(review): assumes preprocess_input is elementwise and keeps the
        # HWC shape, so the transpose yields a channels-first array.
        return np.transpose(preprocess_input(image), (2, 0, 1))
def CycleGan_dataset_collate(batch):
    """Collate a batch of ``(image_A, image_B)`` pairs into two tensors.

    Each domain's images are stacked along a new leading batch axis via
    ``np.array(..., np.float32)`` and then wrapped with ``torch.from_numpy``.

    Args:
        batch: iterable of 2-element pairs, one image array per domain.

    Returns:
        ``(tensor_A, tensor_B)`` of dtype float32.
    """
    # Unpacking here (rather than indexing) rejects elements that are not pairs.
    pairs = [(first, second) for first, second in batch]
    stacked_a = np.array([pair[0] for pair in pairs], np.float32)
    stacked_b = np.array([pair[1] for pair in pairs], np.float32)
    return torch.from_numpy(stacked_a), torch.from_numpy(stacked_b)