S12-ERA-Phase-I / utils.py
LN1996's picture
Update utils.py
fbf4d50
raw
history blame
1.58 kB
# utils file
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
# Turn off OpenCV's internal thread pool and its OpenCL backend for this process.
# NOTE(review): presumably done so OpenCV does not compete with PyTorch
# DataLoader worker processes for CPU threads — confirm against the training setup.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose ``transform`` is invoked albumentations-style.

    The transform is called as ``transform(image=...)`` and the ``"image"``
    entry of the mapping it returns is used as the sample, instead of calling
    the transform positionally on the image.
    """

    def __init__(self, root="./data", train=True, download=True, transform=None):
        """Forward all arguments to ``torchvision.datasets.CIFAR10``.

        Args:
            root: Directory passed to the parent as the dataset location.
            train: Passed through to the parent to select the split.
            download: Passed through to the parent; defaults to ``True`` here.
            transform: Optional callable accepting ``image=<sample>`` and
                returning a mapping that contains an ``"image"`` key.
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The raw entry of ``self.data`` is returned unchanged when no transform
        is configured; otherwise the transform's ``"image"`` output is returned.
        """
        sample = self.data[index]
        target = self.targets[index]

        # Without a transform the stored sample is handed back as-is.
        if self.transform is None:
            return sample, target

        # Keyword call + dict result: the albumentations calling convention.
        augmented = self.transform(image=sample)
        return augmented["image"], target
def augmentation_custom_resnet(
    data,
    mu=(0.49139968, 0.48215827, 0.44653124),
    sigma=(0.24703233, 0.24348505, 0.26158768),
    pad=4,
):
    """Build the albumentations pipeline for the training or evaluation split.

    Args:
        data: Split selector. Exactly ``'Train'`` (case-sensitive) selects the
            augmenting pipeline; every other value selects the evaluation
            pipeline (normalize + tensor conversion only).
        mu: Mean passed to ``A.Normalize`` (expressed on the [0, 1] scale).
        sigma: Standard deviation passed to ``A.Normalize``.
        pad: Number of pixels added to 32 to obtain the minimum height/width
            the image is padded up to before the random 32x32 crop.

    Returns:
        The ``A.Compose`` pipeline for the requested split.
    """
    if data == 'Train':
        # Padding runs BEFORE A.Normalize, i.e. on un-normalized pixel data,
        # and A.Normalize scales ``mean`` by its default max_pixel_value of
        # 255. The border colour therefore has to be given on the 0-255 scale:
        # filling with ``mu * 255`` makes the padded pixels normalize to ~0,
        # whereas the raw [0, 1]-scale mean would be an almost-black border.
        # ``tolist()`` yields a float for scalar ``mu`` and a list otherwise.
        pad_fill = (np.asarray(mu, dtype=float) * 255.0).tolist()
        padded_side = 32 + pad

        transform = A.Compose([
            A.PadIfNeeded(
                min_height=padded_side,
                min_width=padded_side,
                border_mode=cv2.BORDER_CONSTANT,
                value=pad_fill,
            ),
            A.RandomCrop(32, 32),
            A.HorizontalFlip(p=0.5),
            # NOTE(review): A.Cutout is deprecated in newer albumentations
            # releases in favour of A.CoarseDropout — confirm the pinned
            # version still provides it.
            A.Cutout(max_h_size=8, max_w_size=8),
            A.Normalize(mean=mu, std=sigma),
            ToTensorV2(),
        ])
    else:
        transform = A.Compose([
            A.Normalize(mean=mu, std=sigma),
            ToTensorV2(),
        ])
    return transform