Spaces:
Runtime error
Runtime error
File size: 4,546 Bytes
a6dac9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
""" Quick n Simple Image Folder, Tarfile based DataSet
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.utils.data as data
import os
import torch
import logging
from PIL import Image
from .parsers import create_parser
# Number of back-to-back sample load failures ImageDataset.__getitem__ tolerates;
# once the running count of consecutive failures reaches this value, the error is re-raised.
_ERROR_RETRY = 50

# Logger for this module (used to report skipped samples).
_logger = logging.getLogger(name=__name__)
class ImageDataset(data.Dataset):
    """Map-style dataset that reads ``(image, target)`` samples through a parser.

    ``parser[index]`` must return ``(source, target)``. By default the source is decoded with
    ``Image.open(source).convert('RGB')``; with ``load_bytes=True`` it is read raw via
    ``source.read()`` (presumably the parser yields a path or file-like object -- depends on
    the parser). A sample that fails to load is logged and skipped in favour of the next index
    (wrapping around to 0); after ``_ERROR_RETRY`` consecutive failures the last error is
    re-raised. A ``None`` target is replaced with ``torch.tensor(-1, dtype=torch.long)``.
    """

    def __init__(
            self,
            root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        """
        Args:
            root: Dataset root, forwarded to ``create_parser`` when the parser is built here.
            parser: Parser instance, parser name string, or None. A name or None is turned into
                a parser via ``create_parser`` (None is passed as the empty name ``''``).
            class_map: Forwarded to ``create_parser``; unused when a parser instance is given.
            load_bytes: If True, return each sample's raw bytes instead of a decoded RGB image.
            transform: Optional callable applied to every loaded image.
        """
        if parser is None or isinstance(parser, str):
            parser = create_parser(parser or '', root=root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
        # Load failures seen in a row; reset to 0 by every successful load.
        self._consecutive_errors = 0

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``, skipping forward past unreadable samples.

        Raises:
            Exception: whatever the most recent load raised, once ``_ERROR_RETRY`` consecutive
                loads have failed.
        """
        # Retry iteratively rather than recursively: a recursive retry adds a stack frame per bad
        # sample and, since each retry would start inside an ``except`` block, chains up to
        # ``_ERROR_RETRY`` exceptions (via ``__context__``) into the final traceback.
        while True:
            img, target = self.parser[index]
            try:
                img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
            except Exception as e:
                _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
                self._consecutive_errors += 1
                if self._consecutive_errors >= _ERROR_RETRY:
                    raise
                index = (index + 1) % len(self.parser)
            else:
                break
        self._consecutive_errors = 0

        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            # Parser supplied no target: substitute a -1 long tensor.
            target = torch.tensor(-1, dtype=torch.long)
        return img, target

    def __len__(self):
        """Number of samples, as reported by the parser."""
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        """Delegate to ``parser.filename`` for the sample at ``index``."""
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Delegate to ``parser.filenames`` for all samples."""
        return self.parser.filenames(basename, absolute)
class IterableImageDataset(data.IterableDataset):
    """Iterable-style dataset that streams ``(image, target)`` samples from a parser.

    Iterating yields each ``(img, target)`` produced by the parser, with ``transform`` applied
    to ``img`` (if set) and a ``None`` target replaced by ``torch.tensor(-1, dtype=torch.long)``.
    """

    def __init__(
            self,
            root,
            parser=None,
            split='train',
            is_training=False,
            batch_size=None,
            class_map='',
            load_bytes=False,
            repeats=0,
            transform=None,
    ):
        """
        Args:
            root: Forwarded to ``create_parser`` when ``parser`` is a name.
            parser: Parser instance or parser name string (required).
            split, is_training, batch_size, repeats: Forwarded to ``create_parser`` when
                ``parser`` is a name; unused when a parser instance is given.
            class_map, load_bytes: Accepted for signature parity with ``ImageDataset``;
                not used by this class.
            transform: Optional callable applied to every image.

        Raises:
            AssertionError: if ``parser`` is None.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so callers see the same exception type as before.
        if parser is None:
            raise AssertionError('IterableImageDataset requires a parser instance or parser name.')
        if isinstance(parser, str):
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
        else:
            self.parser = parser
        self.transform = transform
        # Not read by this class' methods; kept for attribute parity with ImageDataset.
        self._consecutive_errors = 0

    def __iter__(self):
        """Yield ``(img, target)`` pairs from the parser, transformed and with None targets filled."""
        for img, target in self.parser:
            if self.transform is not None:
                img = self.transform(img)
            if target is None:
                target = torch.tensor(-1, dtype=torch.long)
            yield img, target

    def __len__(self):
        """Parser length if the parser defines ``__len__``, otherwise 0."""
        if hasattr(self.parser, '__len__'):
            return len(self.parser)
        else:
            return 0

    def filename(self, index, basename=False, absolute=False):
        """Not supported for streaming datasets.

        Raises:
            AssertionError: always; use ``filenames()`` instead.
        """
        # `assert False` would be stripped under `python -O` and silently return None.
        raise AssertionError('Filename lookup by index not supported, use filenames().')

    def filenames(self, basename=False, absolute=False):
        """Delegate to ``parser.filenames``."""
        return self.parser.filenames(basename, absolute)
class AugMixDataset(data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes.

    The wrapped dataset must expose a ``transform`` attribute. Transforms are supplied as a
    3-item list/tuple ``(base, augmentation, normalize)``: ``base`` is installed as the wrapped
    dataset's transform, ``augmentation`` is applied to the base output for every split after
    the first, and ``normalize`` (skipped when None) is applied to every split. Items are
    returned as ``(tuple of num_splits views, target)`` with the un-augmented view first.
    """

    def __init__(self, dataset, num_splits=2):
        """
        Args:
            dataset: Dataset with a ``transform`` attribute. A non-None transform is unpacked
                immediately as the 3-item ``(base, augmentation, normalize)`` sequence.
            num_splits: Number of views returned per item (1 clean + ``num_splits - 1``
                augmented).
        """
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        # NOTE(review): if no transforms are ever set, augmentation stays None and
        # __getitem__ fails for num_splits > 1.
        self.num_splits = num_splits

    def _set_transforms(self, x):
        """Split ``x`` into base / augmentation / normalize transforms.

        Raises:
            AssertionError: if ``x`` is not a list/tuple of exactly 3 items.
        """
        # Explicit raise instead of `assert` so validation survives `python -O`;
        # AssertionError is kept so callers see the same exception type as before.
        if not (isinstance(x, (list, tuple)) and len(x) == 3):
            raise AssertionError('Expecting a tuple/list of 3 transforms')
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    @property
    def transform(self):
        """The base transform currently installed on the wrapped dataset."""
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        """Apply ``normalize`` if one was provided, else return ``x`` unchanged."""
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        """Length of the wrapped dataset."""
        return len(self.dataset)
|