File size: 413 Bytes
854728f
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch.utils.data import Dataset


class SampleDataset(Dataset):
    """Read-only view over ``dataset`` exposing ``samples`` evenly spaced items.

    The selected positions are ``floor(i * len(dataset) / samples)`` for
    ``i in range(samples)``, which spreads the picks across the whole
    dataset rather than only its prefix. If ``samples`` exceeds
    ``len(dataset)`` every item is kept; a non-positive ``samples`` yields
    an empty view.
    """

    def __init__(self, dataset, samples):
        """Build the index list for the sampled view.

        Args:
            dataset: Object supporting ``len()`` and integer ``__getitem__``.
            samples: Requested number of items to expose; clamped to
                ``[0, len(dataset)]``.
        """
        total = len(dataset)
        count = max(0, min(samples, total))
        self.dataset = dataset
        # Exact integer arithmetic (no float rounding for huge lengths).
        # Scaling i by total before dividing distributes the remainder over
        # the whole range; a floored fixed stride would leave the tail of the
        # dataset unsampled whenever total is not a multiple of count.
        # Because count <= total, consecutive indices differ by >= 1, so all
        # indices are distinct. count == 0 means no division is evaluated.
        self.indices = [i * total // count for i in range(count)]

    def __len__(self):
        """Return the number of sampled items."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th sampled item from the wrapped dataset."""
        return self.dataset[self.indices[idx]]