Spaces:
Runtime error
Runtime error
File size: 1,578 Bytes
f533343 fbf4d50 f533343 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
# utils file
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
# Process-wide OpenCV settings applied at import time.
# setNumThreads(0) turns off OpenCV's internal threading so its functions run
# sequentially; setUseOpenCL(False) disables the OpenCL (T-API) backend.
# NOTE(review): presumably done to avoid thread oversubscription / deadlocks
# when Albumentations (OpenCV-backed) runs inside DataLoader worker
# processes — confirm.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose ``transform`` follows the Albumentations calling convention.

    ``__getitem__`` is overridden so the transform is called as
    ``transform(image=...)`` and the ``"image"`` entry of its result is used,
    which is how the ``A.Compose`` pipelines built in this module are invoked.
    """

    def __init__(self, root="./data", train=True, download=True, transform=None,
                 target_transform=None):
        """Create the dataset.

        Args:
            root: Directory the dataset is stored in / downloaded to.
            train: Selects the training split when True, the test split otherwise.
            download: Forwarded to the parent class to fetch the data if needed.
            transform: Callable invoked as ``transform(image=img)`` that returns a
                mapping with an ``"image"`` key (e.g. ``A.Compose``), or None.
            target_transform: Optional callable applied to each label. The
                overridden ``__getitem__`` applies it explicitly, so it is no
                longer silently ignored.
        """
        super().__init__(root=root, train=train, download=download,
                         transform=transform, target_transform=target_transform)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The image is taken from ``self.data`` as stored and, when
        ``self.transform`` is set, replaced by ``transform(image=image)["image"]``.
        The label comes from ``self.targets`` and is passed through
        ``self.target_transform`` when one is set.
        """
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Albumentations-style call: keyword argument in, dict out.
            image = self.transform(image=image)["image"]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
def augmentation_custom_resnet(data,
                               mu=(0.49139968, 0.48215827, 0.44653124),
                               sigma=(0.24703233, 0.24348505, 0.26158768),
                               pad=4):
    """Build the Albumentations pipeline for the custom-ResNet CIFAR-10 setup.

    Args:
        data: Split selector. ``'Train'`` (compared case-insensitively) returns
            the training pipeline. Any other value returns the evaluation
            pipeline (normalize + tensor conversion only).
        mu: Per-channel mean for ``A.Normalize``, on the 0-1 scale.
        sigma: Per-channel standard deviation for ``A.Normalize``, on the 0-1 scale.
        pad: Amount added to the 32-pixel height and width before the random
            32x32 crop. ``min_height``/``min_width`` are ``32 + pad``, so this
            is the total growth per dimension, not the padding per side.
            NOTE(review): torchvision-style ``RandomCrop(32, padding=4)`` pads
            4 px on every side (40x40). Confirm which one is intended here.

    Returns:
        An ``A.Compose`` pipeline, to be called as ``transform(image=img)``.
    """
    if isinstance(data, str) and data.lower() == "train":
        # Padding runs before A.Normalize, i.e. on raw pixel values, while
        # ``mu`` is on the 0-1 scale (A.Normalize multiplies it by its default
        # max_pixel_value of 255). Scale the mean into pixel units so the
        # border takes the per-channel mean colour and normalizes to ~0.
        # The previous ``np.mean(mu)`` (~0.47) gave an almost-black border on
        # 0-255 images.
        pad_fill = tuple(255.0 * m for m in mu)
        return A.Compose([
            A.PadIfNeeded(min_height=32 + pad,
                          min_width=32 + pad,
                          border_mode=cv2.BORDER_CONSTANT,
                          value=pad_fill),
            A.RandomCrop(32, 32),
            A.HorizontalFlip(p=0.5),
            # NOTE(review): A.Cutout is deprecated/removed in newer
            # Albumentations releases (A.CoarseDropout replaces it). Check
            # the pinned version. Holes are filled before Normalize.
            A.Cutout(max_h_size=8, max_w_size=8),
            A.Normalize(mean=mu, std=sigma),
            ToTensorV2(),
        ])
    # Evaluation pipeline: no augmentation, only normalization + tensor conversion.
    return A.Compose([A.Normalize(mean=mu, std=sigma),
                      ToTensorV2()])
|