Spaces:
Running
on
T4
Running
on
T4
File size: 1,208 Bytes
a277bb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from __future__ import print_function
import torch
import torchvision.datasets as datasets
from torch.utils.data import Dataset
from PIL import Image
from .tsv_io import TSVFile
import numpy as np
import base64
import io
class TSVDataset(Dataset):
    """Map-style dataset backed by a TSV file (used for ImageNet 1K training).

    Every row is fetched through ``TSVFile.seek``. The last column of a row
    is treated as a base64-encoded image and the second column (index 1) is
    parsed as the integer class index.
    """

    def __init__(self, tsv_file, transform=None, target_transform=None):
        """Wrap ``tsv_file`` in a ``TSVFile`` and store the optional transforms.

        Args:
            tsv_file: Value handed straight to ``TSVFile``.
            transform: Optional callable applied to the decoded RGB image.
            target_transform: Optional callable applied to the integer label.
        """
        self.tsv = TSVFile(tsv_file)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at row ``index``.

        Args:
            index (int): Row index passed to ``TSVFile.seek``.

        Returns:
            tuple: ``(image, target)``. ``image`` is the RGB-converted PIL
            image, or the result of ``transform`` when one was given;
            ``target`` is the class index, or the result of
            ``target_transform`` when one was given.
        """
        record = self.tsv.seek(index)

        # The image payload lives in the last column as base64 text.
        raw_bytes = base64.b64decode(record[-1])
        picture = Image.open(io.BytesIO(raw_bytes)).convert('RGB')

        # The class index lives in the second column.
        label = int(record[1])

        sample = picture if self.transform is None else self.transform(picture)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Return the number of rows reported by the underlying ``TSVFile``."""
        return self.tsv.num_rows()
|