# utils file
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import numpy as np
import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# Turn off OpenCV's own threading and OpenCL use for this process
# (presumably to avoid clashing with DataLoader worker processes -- confirm).
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)


class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose ``transform`` is called albumentations-style.

    The transform is invoked as ``transform(image=...)`` and the result is
    read back from the ``"image"`` key of the returned mapping.
    """

    def __init__(self, root="./data", train=True, download=True, transform=None):
        """Forward all arguments to ``torchvision.datasets.CIFAR10``.

        Only the defaults differ from a bare call: ``root="./data"`` and
        ``download=True``.
        """
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is taken directly from ``self.data`` and the label from
        ``self.targets``. When ``self.transform`` is set, the image is
        replaced by ``self.transform(image=image)["image"]``; otherwise the
        stored image is returned untouched. No target transform is applied.
        """
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label


def augmentation_custom_resnet(
    data,
    mu=(0.49139968, 0.48215827, 0.44653124),
    sigma=(0.24703233, 0.24348505, 0.26158768),
    pad=4,
):
    """Build the albumentations pipeline for the custom ResNet experiments.

    Args:
        data: Split selector. The exact string ``'Train'`` yields the
            augmenting training pipeline; any other value yields the
            evaluation pipeline (normalize + tensor conversion only).
        mu: Mean passed to ``A.Normalize``, expressed as a fraction of the
            maximum pixel value (per channel, or a single scalar).
        sigma: Standard deviation passed to ``A.Normalize``, on the same
            scale as ``mu``.
        pad: Extra size added to 32 for the minimum height/width that
            ``A.PadIfNeeded`` pads to before the random 32x32 crop.

    Returns:
        An ``A.Compose`` ending in ``ToTensorV2()``.
    """
    # A.Normalize is given its documented default scale explicitly so the pad
    # fill below is guaranteed to use the same one.
    max_pixel_value = 255.0

    if data == 'Train':
        # Padding happens BEFORE A.Normalize, i.e. on the un-normalized image,
        # so the fill colour must be on the raw pixel scale rather than the
        # 0-1 scale of ``mu``. Filling with mu * max_pixel_value makes the
        # border normalize to ~0 (the dataset mean); the previous
        # ``np.mean(mu)`` (~0.47) was effectively a black border.
        # ``tolist()`` yields a per-channel list for a sequence ``mu`` and a
        # plain float for a scalar ``mu``.
        # Assumes inputs are raw 0-255 images, as A.Normalize above does.
        pad_fill = (np.asarray(mu, dtype=np.float64) * max_pixel_value).tolist()

        transform = A.Compose([
            A.PadIfNeeded(
                min_height=32 + pad,
                min_width=32 + pad,
                border_mode=cv2.BORDER_CONSTANT,
                value=pad_fill,
            ),
            A.RandomCrop(32, 32),
            A.HorizontalFlip(p=0.5),
            # NOTE(review): A.Cutout appears to be deprecated in later
            # albumentations releases in favour of CoarseDropout -- confirm
            # against the pinned version before upgrading.
            A.Cutout(max_h_size=8, max_w_size=8),
            A.Normalize(mean=mu, std=sigma, max_pixel_value=max_pixel_value),
            ToTensorV2(),
        ])
    else:
        transform = A.Compose([
            A.Normalize(mean=mu, std=sigma, max_pixel_value=max_pixel_value),
            ToTensorV2(),
        ])

    return transform