# Source snapshot: commit 854728f, file size 4,385 bytes.
import random
import torch
import numpy as np
import math
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image, ImageFilter
"""
Pair transforms are modified versions of regular transforms that take in multiple
images and apply exactly the same transform to all of them. This is especially
useful when a pair of images must receive identical (random) transforms.
Example:
img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
"""
class PairCompose(T.Compose):
    """Chain pair transforms, feeding every image through each stage in order.

    Each stage is called with all images as positional arguments and must
    return an iterable of images, which is unpacked into the next stage. With
    no stages, the inputs are returned unchanged (as the tuple received).
    """

    def __call__(self, *x):
        images = x
        for stage in self.transforms:
            images = stage(*images)
        return images
class PairApply:
    """Apply one single-image transform to every input independently.

    ``transforms`` is called once per image, so any randomness it contains is
    not shared between the images. Always returns a list.
    """

    def __init__(self, transforms):
        # Callable taking one image and returning the transformed image.
        self.transforms = transforms

    def __call__(self, *x):
        apply = self.transforms
        return list(map(apply, x))
class PairApplyOnlyAtIndices:
    """Apply a single-image transform only to the inputs at selected positions.

    Inputs whose 0-based position is not in ``indices`` pass through untouched.
    Always returns a list with one entry per input.
    """

    def __init__(self, indices, transforms):
        # Container of 0-based positions to transform (checked with ``in``).
        self.indices = indices
        # Callable taking one image and returning the transformed image.
        self.transforms = transforms

    def __call__(self, *x):
        out = []
        for position, item in enumerate(x):
            if position in self.indices:
                item = self.transforms(item)
            out.append(item)
        return out
class PairRandomAffine(T.RandomAffine):
    """Random affine transform applied identically to every input image.

    One set of affine parameters (angle, translation, scale, shear) is drawn per
    call via ``get_params`` using the first image's ``size``, and the same
    parameters are applied to all images.

    Args:
        degrees, translate, scale, shear: Forwarded to ``T.RandomAffine``.
        resamples: Optional sequence giving one resampling filter per input
            image. If ``None`` or empty, every image uses the base class's
            filter, which this constructor sets to ``Image.NEAREST``.
        fillcolor: Fill value passed to ``F.affine`` for every image.
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
        super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
        self.resamples = resamples

    @staticmethod
    def _first_attr(obj, *names):
        """Return the value of the first attribute in ``names`` that ``obj`` has."""
        for name in names:
            if hasattr(obj, name):
                return getattr(obj, name)
        raise AttributeError(f"{type(obj).__name__} has none of the attributes {names}")

    def __call__(self, *x):
        if not x:
            return []
        param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
        # torchvision renamed ``resample``/``fillcolor`` to ``interpolation``/``fill``;
        # newer releases no longer set the old attribute names, so accept either.
        default_resample = self._first_attr(self, "interpolation", "resample")
        fill = self._first_attr(self, "fill", "fillcolor")
        resamples = self.resamples or [default_resample] * len(x)
        return [F.affine(xi, *param, resamples[i], fill) for i, xi in enumerate(x)]
class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontally flip all inputs together with probability ``self.p``.

    One random draw decides for the whole group. Either every image is flipped
    (returned as a list) or none is (the inputs are returned as received).
    """

    def __call__(self, *x):
        flip = torch.rand(1) < self.p
        if not flip:
            return x
        return [F.hflip(image) for image in x]
class RandomBoxBlur:
    """Box-blur a PIL-style image with probability ``prob``.

    When applied, the blur radius is drawn uniformly from the integers
    ``0..max_radius`` (inclusive) and passed to ``ImageFilter.BoxBlur``.
    """

    def __init__(self, prob, max_radius):
        # Probability of applying the blur on each call.
        self.prob = prob
        # Largest integer radius that can be drawn.
        self.max_radius = max_radius

    def __call__(self, img):
        if not (torch.rand(1) < self.prob):
            return img
        radius = random.choice(range(self.max_radius + 1))
        return img.filter(ImageFilter.BoxBlur(radius))
class PairRandomBoxBlur(RandomBoxBlur):
    """Box-blur every input with one shared random radius, with probability ``prob``.

    A single coin flip and a single radius are drawn per call. Either all images
    get the identical blur (returned as a list) or they are returned as received.
    """

    def __call__(self, *x):
        if not (torch.rand(1) < self.prob):
            return x
        box = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
        return [image.filter(box) for image in x]
class RandomSharpen:
    """Apply ``ImageFilter.SHARPEN`` to an image with probability ``prob``."""

    def __init__(self, prob):
        # Probability of sharpening on each call.
        self.prob = prob
        # Filter object handed to ``img.filter``.
        self.filter = ImageFilter.SHARPEN

    def __call__(self, img):
        apply_it = torch.rand(1) < self.prob
        return img.filter(self.filter) if apply_it else img
class PairRandomSharpen(RandomSharpen):
    """Sharpen all inputs together with probability ``prob``.

    One random draw decides for the whole group. Either every image is
    sharpened (returned as a list) or the inputs are returned as received.
    """

    def __call__(self, *x):
        if not (torch.rand(1) < self.prob):
            return x
        sharpen = self.filter
        return [image.filter(sharpen) for image in x]
class PairRandomAffineAndResize:
    """Random affine + scale-to-fit + center crop, shared across all inputs.

    One set of affine parameters is drawn per call (using the first image's
    size) and applied to every input, so all outputs get the same geometry.
    For each image:

    1. If it is smaller than ``size`` in either dimension, it is padded on each
       side by half of the shortfall (rounded up) using ``F.pad``'s default fill.
    2. An affine transform is applied. Its scale range is ``scale`` multiplied
       by ``scale_factor`` -- the factor that makes the original image cover
       ``size`` in both dimensions -- and its translate fractions are multiplied
       by the same factor.
    3. The result is center-cropped to ``size``.

    Args:
        size: Output size as ``(height, width)``; a single int gives a square.
        degrees: Rotation range ``(min, max)`` in degrees, or a non-negative
            number ``d`` meaning ``(-d, d)``.
        translate: Maximum translation fractions ``(tx, ty)`` before scaling by
            ``scale_factor``, or ``None`` for no translation.
        scale: Scale range ``(min, max)`` before scaling by ``scale_factor``,
            or ``None`` for ``(1, 1)``.
        shear: Shear range as accepted by ``T.RandomAffine.get_params``, a
            non-negative number ``s`` meaning ``(-s, s)``, or ``None``.
        ratio: Stored on the instance; not read by this class.
        resample: Resampling filter passed to ``F.affine`` for every image.
        fillcolor: Fill value passed to ``F.affine`` for every image.

    Raises:
        ValueError: If ``degrees`` or ``shear`` is a negative single number.
    """

    def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
        self.size = (size, size) if isinstance(size, int) else size
        # ``get_params`` indexes these as ranges, so expand scalars like
        # ``T.RandomAffine`` does for its own constructor arguments.
        self.degrees = self._as_range(degrees, "degrees")
        self.translate = translate
        self.scale = scale
        self.shear = self._as_range(shear, "shear")
        self.ratio = ratio
        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def _as_range(value, name):
        """Expand a single non-negative number ``v`` to ``(-v, v)``; pass anything else through."""
        if not isinstance(value, (int, float)):
            return value
        if value < 0:
            raise ValueError(f"If {name} is a single number, it must be non-negative.")
        return (-value, value)

    def __call__(self, *x):
        if not x:
            return []
        w, h = x[0].size
        out_h, out_w = self.size[0], self.size[1]
        # Scaling by this factor makes the original image cover the output size
        # in both dimensions.
        scale_factor = max(out_w / w, out_h / h)
        # Symmetric padding (rounded up) so the canvas is at least the output size.
        pad_h = int(math.ceil(max(out_h - h, 0) / 2))
        pad_w = int(math.ceil(max(out_w - w, 0) / 2))
        base_scale = (1.0, 1.0) if self.scale is None else self.scale
        scale = base_scale[0] * scale_factor, base_scale[1] * scale_factor
        translate = None
        if self.translate is not None:
            translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
        affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))

        def transform(img):
            if pad_h > 0 or pad_w > 0:
                img = F.pad(img, (pad_w, pad_h))
            img = F.affine(img, *affine_params, self.resample, self.fillcolor)
            return F.center_crop(img, self.size)

        return [transform(xi) for xi in x]
class RandomAffineAndResize(PairRandomAffineAndResize):
    """Single-image variant of ``PairRandomAffineAndResize``.

    Takes one image and returns the single transformed image instead of a list.
    """

    def __call__(self, img):
        return super().__call__(img)[0]