Spaces:
Sleeping
Sleeping
File size: 694 Bytes
854728f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from torch.utils.data import Dataset
from typing import List
class ZipDataset(Dataset):
    """Combine several indexable datasets into one dataset of tuples.

    Item ``idx`` is ``tuple(d[idx % len(d)] for d in datasets)``. The zipped
    length is the length of the *longest* dataset, and shorter datasets wrap
    around (cycle) to fill it. A negative index is reduced modulo each
    dataset's own length, so ``-1`` picks the last element of every dataset.
    If ``transforms`` is truthy, the tuple is unpacked into it as positional
    arguments, and its return value becomes the item.

    Args:
        datasets: Datasets to zip. Each must support ``len()`` and integer
            indexing. A zero-length member makes indexing fail, because the
            wrap-around computes a modulo by zero.
        transforms: Optional callable taking one argument per dataset.
        assert_equal_length: If true, require every dataset to have the
            same length.

    Raises:
        AssertionError: If ``assert_equal_length`` is true and the dataset
            lengths differ.
    """

    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
        self.datasets = datasets
        self.transforms = transforms
        # Raise explicitly rather than use ``assert``, so the check survives
        # ``python -O``. AssertionError is kept for callers that already
        # catch it.
        if assert_equal_length and len({len(d) for d in datasets}) > 1:
            raise AssertionError('Datasets are not equal in length.')

    def __len__(self) -> int:
        """Return the length of the longest dataset, or 0 if there are none."""
        # ``default=0`` avoids max() raising on an empty ``datasets`` list.
        return max((len(d) for d in self.datasets), default=0)

    def __getitem__(self, idx):
        """Return the zipped, optionally transformed, item at ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        # Without this check, the modulo below accepts any index. Python's
        # fallback iteration (``for``/``list()`` on a class without
        # ``__iter__``) would then never see IndexError and would loop forever.
        if not -n <= idx < n:
            raise IndexError(f'index {idx} is out of range for ZipDataset of length {n}')
        x = tuple(d[idx % len(d)] for d in self.datasets)
        if self.transforms:
            x = self.transforms(*x)
        return x
|