from __future__ import print_function import torch import torchvision.datasets as datasets from torch.utils.data import Dataset from PIL import Image from .tsv_io import TSVFile import numpy as np import base64 import io class TSVDataset(Dataset): """ TSV dataset for ImageNet 1K training """ def __init__(self, tsv_file, transform=None, target_transform=None): self.tsv = TSVFile(tsv_file) self.transform = transform self.target_transform = target_transform def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is class_index of the target class. """ row = self.tsv.seek(index) image_data = base64.b64decode(row[-1]) image = Image.open(io.BytesIO(image_data)) image = image.convert('RGB') target = int(row[1]) if self.transform is not None: img = self.transform(image) else: img = image if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): return self.tsv.num_rows()