# Spaces: Running on T4
from __future__ import print_function | |
import torch | |
import torchvision.datasets as datasets | |
from torch.utils.data import Dataset | |
from PIL import Image | |
from .tsv_io import TSVFile | |
import numpy as np | |
import base64 | |
import io | |
class TSVDataset(Dataset):
    """TSV-backed dataset for ImageNet 1K training.

    Rows are fetched from a ``TSVFile`` by index. As used here, the last
    column of a row holds the base64-encoded image bytes and column 1 holds
    the class index (parsed with ``int``).
    """

    def __init__(self, tsv_file, transform=None, target_transform=None):
        """Open the TSV file and remember the optional transforms.

        Args:
            tsv_file: Value passed straight to ``TSVFile`` (presumably a
                path to the ``.tsv`` file -- confirm against ``tsv_io``).
            transform: Optional callable applied to the decoded RGB image.
            target_transform: Optional callable applied to the integer
                class index.
        """
        self.tsv = TSVFile(tsv_file)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the sample stored at ``index``.

        Args:
            index (int): Row index handed to ``TSVFile.seek``.

        Returns:
            tuple: ``(image, target)`` where ``image`` is the RGB image
            (after ``transform`` when one is set) and ``target`` is the
            class index (after ``target_transform`` when one is set).
        """
        row = self.tsv.seek(index)
        image_data = base64.b64decode(row[-1])

        # Image.open only reads the header lazily; convert() forces the
        # decode and returns a new image, so the source image can be
        # closed as soon as the RGB copy exists.
        with Image.open(io.BytesIO(image_data)) as source:
            image = source.convert('RGB')

        target = int(row[1])

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        """Return the number of rows reported by the underlying TSV file."""
        return self.tsv.num_rows()