# Spaces:
# Sleeping
# Sleeping
from torch.utils.data import Dataset | |
from typing import List | |
class ZipDataset(Dataset):
    """Combine several datasets into one that yields tuples of their samples.

    Item ``idx`` is built by taking ``d[idx % len(d)]`` from every member
    dataset ``d``. The combined length is the length of the longest member,
    so shorter members wrap around and repeat their samples.

    Args:
        datasets: The datasets to combine. Each must support ``len()`` and
            integer indexing.
        transforms: Optional callable invoked as ``transforms(*sample)``;
            whatever it returns replaces the tuple of samples.
        assert_equal_length: When true, raise ``AssertionError`` at
            construction time unless all datasets have the same length.

    Raises:
        AssertionError: If ``assert_equal_length`` is true and the datasets
            differ in length.
    """

    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False) -> None:
        super().__init__()
        self.datasets = datasets
        self.transforms = transforms

        if assert_equal_length:
            # Raise explicitly instead of using an ``assert`` statement, which
            # is stripped under ``python -O`` and would silently skip the check.
            if len({len(d) for d in datasets}) > 1:
                raise AssertionError('Datasets are not equal in length.')

    def __len__(self) -> int:
        """Return the length of the longest member dataset (0 if there are none)."""
        return max((len(d) for d in self.datasets), default=0)

    def __getitem__(self, idx):
        """Return the tuple of samples at ``idx``, after ``transforms`` if set.

        Negative indices are passed through unchanged, so each member
        resolves them against its own length.

        Raises:
            IndexError: If ``idx`` is greater than or equal to ``len(self)``.
        """
        # The modulo below accepts every index, so without this bound the
        # sequence-iteration protocol (``for x in ds`` / ``list(ds)``) never
        # receives IndexError and would loop forever.
        if idx >= len(self):
            raise IndexError('ZipDataset index out of range')

        sample = tuple(d[idx % len(d)] for d in self.datasets)
        if self.transforms:
            sample = self.transforms(*sample)
        return sample