VideoMatting / dataset /images.py
Fazhong Liu
init
854728f
raw
history blame contribute delete
716 Bytes
import os
import glob
from torch.utils.data import Dataset
from PIL import Image
class ImagesDataset(Dataset):
    """Dataset over the ``*.jpg`` and ``*.png`` files found under a directory.

    The file list is built once, at construction time, by recursively
    globbing ``root`` and sorting the resulting paths. Items are loaded
    lazily in ``__getitem__``.

    Args:
        root: Directory to search. If it names an existing directory it is
            taken literally, even when the name contains glob
            metacharacters such as ``[``, ``]``, ``*`` or ``?``. Any other
            value is inserted into the glob pattern unchanged, so a
            wildcard root keeps working.
        mode: Mode string passed to ``Image.convert`` for every loaded
            image (default ``'RGB'``).
        transforms: Optional callable applied to the converted image. When
            it is falsy (e.g. ``None``) the converted image is returned
            as-is.

    Attributes set by ``__init__``: ``filenames`` (sorted list of matched
    paths), ``mode`` and ``transforms``.
    """

    def __init__(self, root, mode='RGB', transforms=None):
        super().__init__()
        self.transforms = transforms
        self.mode = mode

        # Without escaping, a real directory such as "clips[1]" would be read
        # as a character-class pattern and silently match nothing. Only
        # existing directories are escaped; anything else is left untouched
        # so callers that pass a wildcard root are not broken.
        root = os.fspath(root)
        base = glob.escape(root) if os.path.isdir(root) else root

        self.filenames = sorted([
            *glob.glob(os.path.join(base, '**', '*.jpg'), recursive=True),
            *glob.glob(os.path.join(base, '**', '*.png'), recursive=True),
        ])

    def __len__(self):
        """Return the number of image files collected at construction."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` of the sorted file list.

        The file is opened, converted to ``self.mode``, and the opened file
        is released when the ``with`` block exits. If ``self.transforms`` is
        truthy, its result on the converted image is returned instead.
        """
        with Image.open(self.filenames[idx]) as img:
            img = img.convert(self.mode)

        if self.transforms:
            img = self.transforms(img)

        return img