Spaces:
Runtime error
Runtime error
# utils file
import matplotlib.pyplot as plt | |
import torch | |
from torchvision import transforms | |
import torchvision | |
import numpy as np | |
import cv2 | |
import albumentations as A | |
from albumentations.pytorch.transforms import ToTensorV2 | |
# Module-level OpenCV configuration, applied once at import time.
# Presumably here so OpenCV's own thread pool / OpenCL runtime does not clash
# with multi-process data loading -- TODO confirm against the training setup.
cv2.setNumThreads(0)  # 0 turns off OpenCV's internal multithreading
cv2.ocl.setUseOpenCL(False)  # do not use the OpenCL code paths
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is invoked with an ``image=`` keyword.

    ``__getitem__`` calls ``self.transform(image=...)`` and reads the
    ``"image"`` entry of the result, which is the calling convention of the
    Albumentations pipelines built elsewhere in this file.
    """

    def __init__(self, root="./data", train=True, download=True, transform=None):
        """Forward all arguments unchanged to ``torchvision.datasets.CIFAR10``.

        Args:
            root: Dataset directory passed to the parent class.
            train: Passed through to the parent class to select the split.
            download: Passed through to the parent class.
            transform: Optional callable used by ``__getitem__``; it must
                accept ``image=`` and return a mapping with an ``"image"`` key.
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The sample is read straight from ``self.data`` / ``self.targets``
        (attributes provided by the parent class). When no transform is set
        the stored image is returned untouched.
        """
        image = self.data[index]
        label = self.targets[index]

        if self.transform is None:
            return image, label

        # Keyword-call the transform and unwrap its dict-style result.
        return self.transform(image=image)["image"], label
def augmentation_custom_resnet(data, mu=(0.49139968, 0.48215827, 0.44653124), sigma=(0.24703233, 0.24348505, 0.26158768), pad=4):
    """Build the Albumentations pipeline for the requested split.

    Args:
        data: Split selector. Only the exact string ``'Train'`` enables the
            augmentation steps; every other value yields the
            normalize-and-convert pipeline.
        mu: Per-channel mean handed to ``A.Normalize``.
        sigma: Per-channel standard deviation handed to ``A.Normalize``.
        pad: Added to 32 to form the minimum height/width requested from
            ``A.PadIfNeeded`` (training pipeline only).

    Returns:
        An ``A.Compose`` instance. For ``'Train'`` it contains, in order:
        PadIfNeeded, RandomCrop(32, 32), HorizontalFlip(p=0.5),
        Cutout(8x8 max hole size), Normalize, ToTensorV2. Otherwise it
        contains only Normalize and ToTensorV2.
    """
    steps = []

    if data == 'Train':
        padded_side = 32 + pad
        # NOTE(review): the border fill is the scalar mean of ``mu``; with the
        # default ``mu`` that is a value below 1 -- confirm this is the
        # intended fill for the pixel range of the incoming images.
        steps.append(
            A.PadIfNeeded(
                min_height=padded_side,
                min_width=padded_side,
                border_mode=cv2.BORDER_CONSTANT,
                value=np.mean(mu),
            )
        )
        steps.append(A.RandomCrop(32, 32))
        steps.append(A.HorizontalFlip(p=0.5))
        # NOTE(review): ``A.Cutout`` availability depends on the installed
        # Albumentations version -- verify against the pinned release.
        steps.append(A.Cutout(max_h_size=8, max_w_size=8))

    # Both pipelines finish with normalization and tensor conversion.
    steps.append(A.Normalize(mean=mu, std=sigma))
    steps.append(ToTensorV2())

    return A.Compose(steps)