Spaces:
Sleeping
Sleeping
import os | |
import glob | |
from torch.utils.data import Dataset | |
from PIL import Image | |
class ImagesDataset(Dataset):
    """Map-style dataset over image files found recursively under ``root``.

    The file list is built once at construction time and sorted by full path,
    so a given index always refers to the same file for an unchanged tree.

    Args:
        root: Directory searched recursively as ``root/**/*.<ext>``. It is
            passed to ``glob`` unescaped, so wildcard characters in it are
            expanded as patterns.
        mode: PIL mode every image is converted to (e.g. ``'RGB'``, ``'L'``).
        transforms: Optional callable applied to the converted PIL image.
        extensions: File extensions to collect, with or without a leading
            dot. Matching follows ``glob`` semantics, so it is case-sensitive
            wherever ``glob`` is (e.g. POSIX); pass something like
            ``('jpg', 'JPG', 'jpeg', 'png')`` to widen it. The default
            ``('jpg', 'png')`` reproduces the historical behaviour.
    """

    def __init__(self, root, mode='RGB', transforms=None, extensions=('jpg', 'png')):
        self.transforms = transforms
        self.mode = mode
        found = set()
        for ext in extensions:
            pattern = os.path.join(root, '**', '*.' + ext.lstrip('.'))
            found.update(glob.glob(pattern, recursive=True))
        # A set removes duplicates that arise when two requested patterns hit
        # the same file (e.g. 'jpg' and 'JPG' where glob is case-insensitive).
        self.filenames = sorted(found)

    def __len__(self):
        """Return the number of image files discovered under ``root``."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load the image at ``idx``, convert it to ``self.mode``, and transform it.

        Returns the converted PIL image, or whatever ``self.transforms``
        returns for it when a transform is set.
        """
        with Image.open(self.filenames[idx]) as img:
            # Image.open is lazy; converting inside the ``with`` reads the
            # pixel data before the underlying file handle is closed.
            img = img.convert(self.mode)
        # Compare against None rather than relying on truthiness, so a
        # transform object that defines __len__/__bool__ is never skipped.
        if self.transforms is not None:
            img = self.transforms(img)
        return img