import random import torch import numpy as np import math from torchvision import transforms as T from torchvision.transforms import functional as F from PIL import Image, ImageFilter """ Pair transforms are MODs of regular transforms so that it takes in multiple images and apply exact transforms on all images. This is especially useful when we want the transforms on a pair of images. Example: img1, img2, ..., imgN = transforms(img1, img2, ..., imgN) """ class PairCompose(T.Compose): def __call__(self, *x): for transform in self.transforms: x = transform(*x) return x class PairApply: def __init__(self, transforms): self.transforms = transforms def __call__(self, *x): return [self.transforms(xi) for xi in x] class PairApplyOnlyAtIndices: def __init__(self, indices, transforms): self.indices = indices self.transforms = transforms def __call__(self, *x): return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)] class PairRandomAffine(T.RandomAffine): def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0): super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor) self.resamples = resamples def __call__(self, *x): if not len(x): return [] param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size) resamples = self.resamples or [self.resample] * len(x) return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)] class PairRandomHorizontalFlip(T.RandomHorizontalFlip): def __call__(self, *x): if torch.rand(1) < self.p: x = [F.hflip(xi) for xi in x] return x class RandomBoxBlur: def __init__(self, prob, max_radius): self.prob = prob self.max_radius = max_radius def __call__(self, img): if torch.rand(1) < self.prob: fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1))) img = img.filter(fil) return img class PairRandomBoxBlur(RandomBoxBlur): def __call__(self, *x): if torch.rand(1) < self.prob: fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1))) x = [xi.filter(fil) for xi in x] return x class RandomSharpen: def __init__(self, prob): self.prob = prob self.filter = ImageFilter.SHARPEN def __call__(self, img): if torch.rand(1) < self.prob: img = img.filter(self.filter) return img class PairRandomSharpen(RandomSharpen): def __call__(self, *x): if torch.rand(1) < self.prob: x = [xi.filter(self.filter) for xi in x] return x class PairRandomAffineAndResize: def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0): self.size = size self.degrees = degrees self.translate = translate self.scale = scale self.shear = shear self.ratio = ratio self.resample = resample self.fillcolor = fillcolor def __call__(self, *x): if not len(x): return [] w, h = x[0].size scale_factor = max(self.size[1] / w, self.size[0] / h) w_padded = max(w, self.size[1]) h_padded = max(h, self.size[0]) pad_h = int(math.ceil((h_padded - h) / 2)) pad_w = int(math.ceil((w_padded - w) / 2)) scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h)) def transform(img): if pad_h > 0 or pad_w > 0: img = F.pad(img, (pad_w, pad_h)) img = F.affine(img, *affine_params, self.resample, self.fillcolor) img = F.center_crop(img, self.size) return img return [transform(xi) for xi in x] class RandomAffineAndResize(PairRandomAffineAndResize): def __call__(self, img): return super().__call__(img)[0]