|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from torch.utils.data.dataset import Dataset |
|
|
|
from utils.utils import cvtColor, preprocess_input |
|
|
|
|
|
class CycleGanDataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN training.

    Each annotation line is parsed as ``"<anything>;<path> ..."``: the image
    path is the first whitespace-separated token after the first ``;``.
    Every item is a pair ``(image_A, image_B)`` of channel-first ``float32``
    arrays, resized to ``input_shape`` and passed through ``preprocess_input``.

    Args:
        annotation_lines_A: Annotation lines for domain A.
        annotation_lines_B: Annotation lines for domain B.
        input_shape: Target size indexed as ``(height, width)``
            (``input_shape[0]`` is the height, ``input_shape[1]`` the width).

    Raises:
        ValueError: If either annotation list is empty. No pair could ever be
            formed, and indexing would otherwise fail with a modulo-by-zero.
    """

    def __init__(self, annotation_lines_A, annotation_lines_B, input_shape):
        super(CycleGanDataset, self).__init__()

        self.annotation_lines_A = annotation_lines_A
        self.annotation_lines_B = annotation_lines_B
        self.length_A = len(self.annotation_lines_A)
        self.length_B = len(self.annotation_lines_B)

        # Fail fast with a clear message rather than a ZeroDivisionError from
        # ``index % 0`` deep inside a DataLoader worker.
        if self.length_A == 0 or self.length_B == 0:
            raise ValueError(
                "CycleGanDataset needs at least one annotation line per domain "
                "(got {} for A, {} for B)".format(self.length_A, self.length_B)
            )

        self.input_shape = input_shape

    def __len__(self):
        # An epoch covers the larger domain; the smaller one wraps around.
        return max(self.length_A, self.length_B)

    def __getitem__(self, index):
        # Modulo indexing lets the shorter domain cycle so every index yields a pair.
        image_A = self._load_image(self.annotation_lines_A[index % self.length_A])
        image_B = self._load_image(self.annotation_lines_B[index % self.length_B])
        return image_A, image_B

    @staticmethod
    def _image_path(annotation_line):
        """Return the image path: the first token after the first ``;``.

        ``str.split()`` also discards a trailing newline or any extra fields.
        """
        return annotation_line.split(';')[1].split()[0]

    def _load_image(self, annotation_line):
        """Load, colour-convert, resize and normalise one annotated image.

        Returns a ``float32`` array transposed to channel-first layout.
        """
        height, width = self.input_shape[0], self.input_shape[1]
        # The context manager closes the file deterministically. ``resize``
        # returns a new, fully loaded image, so it stays valid after close.
        with Image.open(self._image_path(annotation_line)) as raw:
            # PIL's resize takes (width, height).
            resized = cvtColor(raw).resize([width, height], Image.BICUBIC)
        array = np.array(resized, dtype=np.float32)
        # (H, W, C) -> (C, H, W), the layout PyTorch convolutions expect.
        return np.transpose(preprocess_input(array), (2, 0, 1))
|
|
|
def CycleGan_dataset_collate(batch):
    """Collate ``(image_A, image_B)`` samples into a pair of float32 tensors.

    Args:
        batch: Iterable of two-element ``(image_A, image_B)`` samples,
            consumed exactly once.

    Returns:
        Tuple ``(images_A, images_B)``. Each is a ``torch`` tensor built from
        a ``numpy`` float32 array of the corresponding samples, stacked along
        a new leading axis.
    """
    # Unpack every sample once, so malformed samples fail here, as before.
    pairs = [(sample_a, sample_b) for sample_a, sample_b in batch]

    def _to_tensor(arrays):
        # Build one float32 ndarray, then wrap it as a tensor without another copy.
        return torch.from_numpy(np.array(arrays, np.float32))

    return (
        _to_tensor([sample_a for sample_a, _ in pairs]),
        _to_tensor([sample_b for _, sample_b in pairs]),
    )
|
|