Spaces:
Sleeping
Sleeping
import random | |
import torch | |
import numpy as np | |
import math | |
from torchvision import transforms as T | |
from torchvision.transforms import functional as F | |
from PIL import Image, ImageFilter | |
""" | |
Pair transforms are modified versions of the regular transforms that take in multiple
images and apply exactly the same transform to all of them. This is especially useful
when the same random transform must be applied to a pair of images.
Example: | |
img1, img2, ..., imgN = transforms(img1, img2, ..., imgN) | |
""" | |
class PairCompose(T.Compose):
    """Chain transforms that each accept and return several images.

    Every transform in ``self.transforms`` is called with the current images
    as positional arguments; its return value is unpacked into the next one.
    """

    def __call__(self, *x):
        images = x
        for step in self.transforms:
            images = step(*images)
        # With no transforms at all, the untouched argument tuple is returned.
        return images
class PairApply:
    """Run one single-image transform independently on every image given."""

    def __init__(self, transforms):
        # Callable invoked once per image with that image as its only argument.
        self.transforms = transforms

    def __call__(self, *x):
        """Return a list with ``self.transforms`` applied to each argument in order."""
        results = []
        for image in x:
            results.append(self.transforms(image))
        return results
class PairApplyOnlyAtIndices:
    """Apply a single-image transform only to the images at chosen positions."""

    def __init__(self, indices, transforms):
        # Container of positional indices whose images should be transformed.
        self.indices = indices
        # Callable invoked with one image at a time.
        self.transforms = transforms

    def __call__(self, *x):
        """Return a list where selected positions are transformed and the rest pass through."""
        results = []
        for position, image in enumerate(x):
            if position in self.indices:
                image = self.transforms(image)
            results.append(image)
        return results
class PairRandomAffine(T.RandomAffine):
    """Random affine transform that uses one set of parameters for all images.

    ``resamples`` may supply one resampling mode per image (looked up by
    position); when it is falsy, ``self.resample`` is used for every image.
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
        # The base class is configured with NEAREST; per-image modes are kept separately.
        super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
        self.resamples = resamples

    def __call__(self, *x):
        """Return a list of the images after the shared random affine transform."""
        if len(x) == 0:
            return []

        # Parameters are drawn once, sized from the first image.
        params = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)

        if self.resamples:
            modes = self.resamples
        else:
            modes = [self.resample] * len(x)

        transformed = []
        for position, image in enumerate(x):
            # Index lookup (not zip) so a too-short `resamples` still raises.
            transformed.append(F.affine(image, *params, modes[position], self.fillcolor))
        return transformed
class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontally flip all images together, decided by a single random draw."""

    def __call__(self, *x):
        """Return a list of flipped images, or the original argument tuple if not flipped."""
        if not (torch.rand(1) < self.p):
            return x
        return [F.hflip(image) for image in x]
class RandomBoxBlur:
    """Randomly box-blur one image.

    When the random draw falls below ``prob``, the image is blurred with a
    radius picked uniformly from ``0`` to ``max_radius`` inclusive.
    """

    def __init__(self, prob, max_radius):
        self.prob = prob
        self.max_radius = max_radius

    def __call__(self, img):
        """Return the blurred image, or ``img`` unchanged when the blur is skipped."""
        if not (torch.rand(1) < self.prob):
            return img
        radius = random.choice(range(self.max_radius + 1))
        return img.filter(ImageFilter.BoxBlur(radius))
class PairRandomBoxBlur(RandomBoxBlur):
    """Box-blur all images together with one random decision and one radius."""

    def __call__(self, *x):
        """Return a list of blurred images, or the original argument tuple if skipped."""
        if not (torch.rand(1) < self.prob):
            return x
        radius = random.choice(range(self.max_radius + 1))
        # One filter instance is reused so every image gets the same radius.
        blur = ImageFilter.BoxBlur(radius)
        return [image.filter(blur) for image in x]
class RandomSharpen:
    """Randomly apply the PIL ``SHARPEN`` filter to one image."""

    def __init__(self, prob):
        self.prob = prob
        self.filter = ImageFilter.SHARPEN

    def __call__(self, img):
        """Return the filtered image, or ``img`` unchanged when the draw is not below ``prob``."""
        if not (torch.rand(1) < self.prob):
            return img
        return img.filter(self.filter)
class PairRandomSharpen(RandomSharpen):
    """Sharpen all images together, decided by a single random draw."""

    def __call__(self, *x):
        """Return a list of filtered images, or the original argument tuple if skipped."""
        if not (torch.rand(1) < self.prob):
            return x
        return [image.filter(self.filter) for image in x]
class PairRandomAffineAndResize:
    """Apply one random affine transform to all images and crop them to ``size``.

    ``size`` is indexed as ``(height, width)``. The configured scale and
    translate ranges are multiplied by the factor needed for the first image
    to cover ``size``. Images smaller than ``size`` are padded before the
    affine step, and the result is center-cropped to ``size``.
    """

    def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
        self.size = size
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        # Stored only; __call__ below does not read it.
        self.ratio = ratio
        self.resample = resample
        self.fillcolor = fillcolor

    def __call__(self, *x):
        """Return a list of the transformed images (empty list for no input)."""
        if len(x) == 0:
            return []

        # All geometry is derived from the first image.
        src_w, src_h = x[0].size
        target_h, target_w = self.size[0], self.size[1]

        # Factor by which the source must grow to cover the target in both axes.
        cover = max(target_w / src_w, target_h / src_h)

        # Padding per side, applied only when the source is smaller than the target.
        pad_h = int(math.ceil((max(src_h, target_h) - src_h) / 2))
        pad_w = int(math.ceil((max(src_w, target_w) - src_w) / 2))

        scaled_scale = (self.scale[0] * cover, self.scale[1] * cover)
        scaled_translate = (self.translate[0] * cover, self.translate[1] * cover)

        # Drawn once so every image receives the identical transform.
        affine_params = T.RandomAffine.get_params(
            self.degrees, scaled_translate, scaled_scale, self.shear, (src_w, src_h)
        )

        needs_padding = pad_h > 0 or pad_w > 0
        results = []
        for image in x:
            if needs_padding:
                image = F.pad(image, (pad_w, pad_h))
            image = F.affine(image, *affine_params, self.resample, self.fillcolor)
            results.append(F.center_crop(image, self.size))
        return results
class RandomAffineAndResize(PairRandomAffineAndResize):
    """Single-image form of ``PairRandomAffineAndResize``."""

    def __call__(self, img):
        """Run the pair transform on ``img`` alone and return its single result."""
        outputs = super().__call__(img)
        return outputs[0]