|
|
|
|
|
import os |
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
import torch |
|
|
|
class GTResDataset(Dataset):
    """Dataset of (input image, ground-truth image) pairs matched by file name.

    Every ``.jpg``/``.png`` file found directly inside ``root_path`` is paired
    with a file of the same base name and a ``.jpg`` extension inside
    ``gt_dir``.

    Args:
        root_path: Directory that holds the input images.
        gt_dir: Directory that holds the ground-truth images. It is required
            as soon as ``root_path`` contains at least one image.
        transform: Optional callable applied to both images of a pair in
            ``__getitem__``.
        transform_train: Stored on the instance as ``self.transform_train``;
            no method of this class uses it.

    Raises:
        TypeError: If ``root_path`` contains images but ``gt_dir`` is ``None``.
    """

    # File-name suffixes (case-sensitive) that mark a directory entry as an image.
    _IMAGE_EXTENSIONS = (".jpg", ".png")

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        super().__init__()

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across machines and runs.
        image_names = sorted(
            name
            for name in os.listdir(root_path)
            if name.endswith(self._IMAGE_EXTENSIONS)
        )

        # Fail with a clear message instead of the opaque error raised from
        # inside os.path.join(None, ...).
        if image_names and gt_dir is None:
            raise TypeError(
                "gt_dir must be a directory path when root_path contains "
                "images, got None"
            )

        # Each entry keeps the original three-slot layout
        # [input_path, ground_truth_path, None].
        self.pairs = []
        for name in image_names:
            # Only the file extension is switched to .jpg; the directory part
            # of the ground-truth path is never rewritten.
            if name.endswith(".png"):
                gt_name = name[: -len(".png")] + ".jpg"
            else:
                gt_name = name
            self.pairs.append(
                [os.path.join(root_path, name), os.path.join(gt_dir, gt_name), None]
            )

        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.pairs)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        # convert() loads the pixel data and returns a new image, so the
        # underlying file handle can be released as soon as it returns.
        with Image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, index):
        """Return ``(from_im, to_im)`` for the pair stored at ``index``.

        Both images are loaded as RGB. If ``self.transform`` is set, it is
        applied to the ground-truth image first and then to the input image;
        that order is kept so transforms that draw random numbers behave as
        before.
        """
        from_path, to_path, _ = self.pairs[index]
        from_im = self._load_rgb(from_path)
        to_im = self._load_rgb(to_path)

        if self.transform:
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)

        return from_im, to_im
|
|