File size: 890 Bytes
9b2bdf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
#!/usr/bin/python
# encoding: utf-8
import os
from torch.utils.data import Dataset
from PIL import Image
import torch
class GTResDataset(Dataset):
    """Dataset of (input image, ground-truth image) pairs read from two directories.

    Every file directly inside ``root_path`` whose name ends in ``.jpg`` or
    ``.png`` becomes one sample. Its ground-truth counterpart is looked up in
    ``gt_dir`` under the same file name, except that a trailing ``.png``
    extension is swapped for ``.jpg``.

    Args:
        root_path: Directory that is listed for input images.
        gt_dir: Directory containing the ground-truth images. Keeps its
            ``None`` default for signature compatibility, but a real path
            must be supplied.
        transform: Optional callable applied to both images of a pair.
        transform_train: Stored on the instance as ``transform_train``; it is
            not used by this class itself.

    Raises:
        TypeError: If ``gt_dir`` is ``None``.
    """

    _IMAGE_EXTENSIONS = (".jpg", ".png")

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        super().__init__()
        if gt_dir is None:
            # Fail early with a clear message instead of the opaque error
            # os.path.join would raise on the first directory entry.
            raise TypeError("gt_dir must be a directory path, not None")

        # Each entry is [input_path, ground_truth_path, None].
        self.pairs = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        for name in sorted(os.listdir(root_path)):
            if not name.endswith(self._IMAGE_EXTENSIONS):
                continue
            # Only the file extension is rewritten, so a ".png" substring in
            # the gt_dir path itself is left alone.
            if name.endswith(".png"):
                gt_name = name[: -len(".png")] + ".jpg"
            else:
                gt_name = name
            self.pairs.append(
                [os.path.join(root_path, name), os.path.join(gt_dir, gt_name), None]
            )

        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        """Return the number of image pairs found at construction time."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Load the pair at ``index`` and return ``(from_im, to_im)``.

        Both images are converted to RGB. When ``transform`` is set it is
        applied to the ground-truth image first and to the input image second.
        """
        from_path, to_path, _ = self.pairs[index]
        from_im = self._load_rgb(from_path)
        to_im = self._load_rgb(to_path)
        if self.transform:
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)
        return from_im, to_im

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy, and close the underlying file."""
        with Image.open(path) as im:
            return im.convert('RGB')
|