File size: 1,578 Bytes
f533343
 
 
 
 
 
 
 
fbf4d50
 
 
f533343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
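# Utilities: a CIFAR-10 dataset wrapper for Albumentations transforms and the
# train/test augmentation pipelines built on top of it.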

import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import numpy as np 
import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


# Run OpenCV single-threaded (0 disables its internal threading) and switch
# off its OpenCL backend.
# NOTE(review): presumably done so OpenCV does not oversubscribe CPUs or hang
# inside PyTorch DataLoader worker processes — confirm.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is called Albumentations-style.

    The transform is invoked as ``transform(image=...)`` and the result is read
    from the ``"image"`` key of the returned mapping, instead of being called
    positionally as the stock torchvision dataset does.

    Args:
        root: Directory where the dataset is stored or downloaded to.
        train: Select the training split when true, the test split otherwise.
        download: Download the dataset when it is not already under ``root``.
        transform: Optional callable accepting ``image=<array>`` and returning
            a mapping with an ``"image"`` entry (e.g. an ``A.Compose``).
    """

    def __init__(self, root="./data", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The image is taken directly from ``self.data`` and passed through the
        transform when one is configured. ``target_transform`` is not applied
        here.
        """
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            image = self.transform(image=image)["image"]

        # Return unconditionally: previously the return sat inside the ``if``
        # above, so a dataset built without a transform yielded ``None``.
        return image, label
        
def augmentation_custom_resnet(data, mu=(0.49139968, 0.48215827, 0.44653124), sigma=(0.24703233, 0.24348505, 0.26158768), pad=4):
    """Build the Albumentations pipeline for the given dataset split.

    Args:
        data: Split selector. Exactly ``'Train'`` yields the augmenting
            pipeline; any other value yields normalisation + tensor conversion
            only.
        mu: Per-channel mean on a 0-1 scale, passed to ``A.Normalize``.
        sigma: Per-channel standard deviation on a 0-1 scale, passed to
            ``A.Normalize``.
        pad: Number of pixels added to the 32-pixel height and width before
            the random 32x32 crop (the padded canvas is ``32 + pad`` square).

    Returns:
        An ``A.Compose`` transform to be called as ``transform(image=...)``.
    """
    if data != 'Train':
        return A.Compose([A.Normalize(mean=mu, std=sigma),
                          ToTensorV2()])

    # Padding is applied BEFORE Normalize, i.e. on pixel values in the 0-255
    # range (Normalize's default max_pixel_value is 255). The mean therefore
    # has to be scaled to that range; passing the raw 0-1 mean produced an
    # effectively black border instead of a mean-coloured one.
    pad_fill = 255.0 * float(np.mean(mu))
    padded_size = 32 + pad

    return A.Compose([A.PadIfNeeded(min_height=padded_size,
                                    min_width=padded_size,
                                    border_mode=cv2.BORDER_CONSTANT,
                                    value=pad_fill),
                      A.RandomCrop(32, 32),
                      A.HorizontalFlip(p=0.5),
                      A.Cutout(max_h_size=8, max_w_size=8),
                      A.Normalize(mean=mu, std=sigma),
                      ToTensorV2()])