zjowowen's picture
init space
079c32c
raw
history blame
No virus
4.66 kB
import math
import os
import re
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
class ToNumpy:
    """Turn a PIL image (or array-like) into a ``uint8`` numpy array, channels first.

    Assumes a 2-D (single channel) or 3-D (height, width, channels) input; a
    2-D input gets a trailing channel axis before the axes are reordered.
    """

    def __call__(self, pil_img):
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            pixels = pixels[..., np.newaxis]
        # HWC -> CHW; moveaxis returns a view of ``pixels`` (no extra copy).
        return np.moveaxis(pixels, 2, 0)
def _pil_interp(method):
    """Translate an interpolation name into a PIL resampling constant.

    'bicubic', 'lanczos' and 'hamming' map to the PIL constants of the same
    name; every other value (including 'bilinear') maps to ``Image.BILINEAR``.
    """
    known = (
        ('bicubic', Image.BICUBIC),
        ('lanczos', Image.LANCZOS),
        ('hamming', Image.HAMMING),
    )
    for name, resample in known:
        if method == name:
            return resample
    # Fallback is bilinear; nearest-neighbour is deliberately not exposed.
    return Image.BILINEAR
def natural_key(string_):
    """Return a key that sorts strings with embedded numbers numerically.

    The string is lower-cased and split on runs of decimal digits. Because the
    split pattern has a capturing group, the pieces alternate between text
    (even indices) and digit runs (odd indices). The digit runs become ``int``s,
    so ``"img2"`` sorts before ``"img10"``.

    Args:
        string_: The string to build a key for.

    Returns:
        A list alternating ``str`` and ``int`` items, usable as a ``sorted`` key.
    """
    pieces = re.split(r'(\d+)', string_.lower())
    # Pick the numeric pieces by position, not with str.isdigit(). isdigit() is
    # also true for characters such as '²', which \d does not match and int()
    # rejects, so a name like "x1²" would otherwise raise ValueError.
    return [int(piece) if i % 2 else piece for i, piece in enumerate(pieces)]
def find_images_and_targets(folder, types=('.png', '.jpg', '.jpeg'), class_to_idx=None, leaf_name_only=True, sort=True):
    """Collect image files under ``folder`` and pair each with a class index.

    An image's label comes from the directory holding it, relative to
    ``folder``. With ``leaf_name_only`` it is the last path component; otherwise
    it is the whole relative path with separators replaced by '_'. Files directly
    inside ``folder`` get the label ''.

    Args:
        folder: Root directory to walk (symlinked directories are followed).
        types: Accepted extensions. A file is kept if the lower-cased
            extension is ``in`` this container.
        class_to_idx: Optional mapping from label to index. If given, files whose
            label is not a key are dropped. If ``None``, a mapping is built from
            the discovered labels in natural sort order.
        leaf_name_only: Selects how labels are derived (see above).
        sort: If true, sort the result by file path in natural order.

    Returns:
        ``(images_and_targets, class_to_idx)``. The first item is a list of
        ``(path, index)`` tuples.
    """
    found = []  # (path, label) pairs in os.walk order
    for root, _dirs, names in os.walk(folder, topdown=False, followlinks=True):
        if root == folder:
            rel_path = ''
        else:
            rel_path = os.path.relpath(root, folder)
        if leaf_name_only:
            label = os.path.basename(rel_path)
        else:
            label = rel_path.replace(os.path.sep, '_')
        for name in names:
            extension = os.path.splitext(name)[1]
            if extension.lower() in types:
                found.append((os.path.join(root, name), label))

    if class_to_idx is None:
        # Number classes in natural order so that e.g. 'class2' precedes 'class10'.
        ordered = sorted(set(label for _, label in found), key=natural_key)
        class_to_idx = {label: idx for idx, label in enumerate(ordered)}

    images_and_targets = [
        (path, class_to_idx[label]) for path, label in found if label in class_to_idx
    ]
    if sort:
        images_and_targets.sort(key=lambda item: natural_key(item[0]))
    return images_and_targets, class_to_idx
# Default per-channel normalisation statistics (named for ImageNet). They are
# passed to transforms.Normalize by the transform builders when no
# mean/std is supplied.
IMAGENET_DEFAULT_MEAN: tuple = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD: tuple = (0.229, 0.224, 0.225)
# Default crop fraction for evaluation: the image is resized to
# img_size / crop_pct and then center-cropped to img_size.
DEFAULT_CROP_PCT: float = 0.875
def transforms_noaug_train(
    img_size=224,
    interpolation='bilinear',
    use_prefetcher=False,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
):
    """Build a training transform that applies no random augmentation.

    The pipeline resizes with ``Resize(img_size)`` and then center-crops to
    ``img_size``. If ``use_prefetcher`` is true, the last step is ``ToNumpy``, and
    tensor conversion and normalisation are left to the prefetcher/collate.
    Otherwise the image is converted with ``ToTensor`` and normalised with
    ``mean``/``std``. 'random' interpolation is not supported here and falls
    back to 'bilinear'.
    """
    resample = _pil_interp('bilinear' if interpolation == 'random' else interpolation)
    steps = [
        transforms.Resize(img_size, resample),
        transforms.CenterCrop(img_size),
    ]
    if not use_prefetcher:
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)))
    else:
        # The prefetcher and collate perform tensor conversion and normalisation.
        steps.append(ToNumpy())
    return transforms.Compose(steps)
def transforms_imagenet_eval(
    img_size=224,
    crop_pct=None,
    interpolation='bilinear',
    use_prefetcher=False,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD
):
    """Build the evaluation transform: resize by ``1 / crop_pct``, then center-crop.

    Args:
        img_size: Final crop size, either an int or a 2-element (h, w) tuple/list.
        crop_pct: Fraction of the resized image kept by the crop. A falsy value
            (``None`` or 0) uses ``DEFAULT_CROP_PCT``.
        interpolation: Name passed to ``_pil_interp``.
        use_prefetcher: If true, the pipeline ends with ``ToNumpy`` instead of
            ``ToTensor`` + ``Normalize``.
        mean, std: Normalisation statistics used when ``use_prefetcher`` is false.
    """
    crop_pct = crop_pct or DEFAULT_CROP_PCT
    is_sequence = isinstance(img_size, (tuple, list))
    if is_sequence:
        assert len(img_size) == 2

    if not is_sequence:
        scale_size = int(math.floor(img_size / crop_pct))
    elif img_size[-1] == img_size[-2]:
        # Square target: pass a single int so Resize scales the shortest edge
        # (the older behaviour).
        scale_size = int(math.floor(img_size[0] / crop_pct))
    else:
        scale_size = tuple(int(side / crop_pct) for side in img_size)

    pipeline = [
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
    ]
    # Prefetcher path: conversion and normalisation happen later in collate.
    pipeline += (
        [ToNumpy()]
        if use_prefetcher
        else [transforms.ToTensor(), transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
    )
    return transforms.Compose(pipeline)
class ImageNetDataset(data.Dataset):
    """Image-folder dataset that yields ``(image, class_index)`` pairs.

    Images are discovered under ``root`` by ``find_images_and_targets``, so each
    image's class is derived from the directory that contains it.
    """

    def __init__(self, root: str, is_training: bool, transform: Optional[Callable] = None) -> None:
        """
        Args:
            root: Directory to scan for images.
            is_training: Chooses the default transform when ``transform`` is
                ``None``: ``transforms_noaug_train()`` if true, otherwise
                ``transforms_imagenet_eval()``.
            transform: Callable applied to each RGB ``PIL.Image``.
        """
        self.root = root
        if transform is None:
            transform = transforms_noaug_train() if is_training else transforms_imagenet_eval()
        self.transform = transform
        # ``data`` holds (path, class_index) pairs. The label -> index mapping is
        # kept as well, so callers can map predicted indices back to class names.
        self.data, self.class_to_idx = find_images_and_targets(root)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Any, Union[int, torch.Tensor]]:
        """Load the ``index``-th image, convert it to RGB and apply the transform.

        Returns:
            ``(img, target)``. ``img`` is whatever ``self.transform`` returns (a
            ``PIL.Image`` if the transform is ``None``). ``target`` is the stored
            class index, or ``torch.tensor(-1)`` if the stored target is ``None``.
        """
        path, target = self.data[index]
        # Close the file explicitly. Pillow keeps multi-frame files (e.g.
        # animated PNG) open after loading. convert() returns a new image that
        # no longer needs the underlying file.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target