from typing import Tuple, List
import random

import torch
import torch.nn as nn


class Intensity(nn.Module):
    """
    Overview:
        Intensity transformation for data augmentation. Scale the image intensity by a random factor.
        Each sample in the batch gets its own factor ``1 + scale * clamp(N(0, 1), -2, 2)``, i.e. a factor in
        ``[1 - 2 * scale, 1 + 2 * scale]``.
    """

    def __init__(self, scale: float) -> None:
        """
        Arguments:
            - scale (:obj:`float`): The scale factor for intensity transformation.
        """
        super().__init__()
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            - x (:obj:`torch.Tensor`): The input image tensor with shape (B, C, H, W). More generally, any tensor \
                whose first dimension is the batch dimension.
            - output (:obj:`torch.Tensor`): The scaled tensor, with the same shape as ``x``.
        """
        # One random factor per sample, broadcast over every remaining dimension.
        # For 4D input this is (B, 1, 1, 1); other ranks now broadcast per sample instead of adding dimensions.
        noise_shape = (x.size(0), ) + (1, ) * (x.dim() - 1)
        r = torch.randn(noise_shape, device=x.device)
        # Clamp the normal draw to +-2 so the factor stays inside [1 - 2 * scale, 1 + 2 * scale].
        noise = 1.0 + (self.scale * r.clamp(-2.0, 2.0))
        return x * noise


class RandomCrop(nn.Module):
    """
    Overview:
        Random crop the image to the given size. A single crop offset is sampled per call (using Python's \
        ``random`` module) and shared by every image in the batch.
    """

    def __init__(self, image_shape: Tuple[int, int]) -> None:
        """
        Arguments:
            - image_shape (:obj:`Tuple[int, int]`): The target (height, width) of the cropped image.
        """
        super().__init__()
        self.image_shape = image_shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            - x (:obj:`torch.Tensor`): The input image tensor with shape (B, C, H, W), where H and W are \
                the original image shape. Any number of leading dimensions is accepted; the last two are \
                treated as (H, W).
            - output (:obj:`torch.Tensor`): The output image tensor with shape (B, C, H_, W_), where H_ and W_ are \
                the target image shape indicated by `image_shape`.
        Raises:
            - ValueError: If the target shape is larger than the input's spatial shape.
        """
        # Read the spatial size from the last two dims, consistent with the `...` slice below.
        H, W = x.shape[-2:]
        H_, W_ = self.image_shape
        dh, dw = H - H_, W - W_
        if dh < 0 or dw < 0:
            raise ValueError(
                "RandomCrop target shape {} is larger than the input spatial shape {}".format(
                    (H_, W_), (int(H), int(W))
                )
            )
        h, w = random.randint(0, dh), random.randint(0, dw)
        return x[..., h:h + H_, w:w + W_]


class ImageTransforms:
    """
    Overview:
        Image transformation for data augmentation. Including image normalization (divide 255), random crop \
        and intensity transformation.
    """

    def __init__(self, augmentation: List[str], shift_delta: int = 4, image_shape: Tuple[int, int] = (96, 96)) -> None:
        """
        Arguments:
            - augmentation (:obj:`List[str]`): The list of augmentation types, applied in the given order. \
                Now support "shift" and "intensity".
            - shift_delta (:obj:`int`): The delta value for random shift padding before crop. Use ReplicationPad2d \
                to pad the image without the loss of information.
            - image_shape (:obj:`Tuple[int, int]`): The target shape of the image to be cropped.
        Raises:
            - NotImplementedError: If an unsupported augmentation type is given.
        """
        self.augmentation = augmentation
        self.image_transforms = []
        for aug in self.augmentation:
            if aug == "shift":
                # Pad by `shift_delta` on every side, then crop back to `image_shape`: a random shift of up to
                # +-shift_delta pixels when `image_shape` equals the input size.
                # TODO validate the effectiveness of ReflectionPad2d
                transformation = nn.Sequential(nn.ReplicationPad2d(shift_delta), RandomCrop(image_shape))
            elif aug == "intensity":
                transformation = Intensity(scale=0.05)
            else:
                raise NotImplementedError("not support augmentation type: {}".format(aug))
            self.image_transforms.append(transformation)

    @torch.no_grad()
    def transform(self, images: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            - images (:obj:`torch.Tensor`): The input image tensor with shape (..., C, H, W), where H and W are \
                the original image shape. uint8 input is converted to float and divided by 255; other \
                dtypes are used as-is.
            - output (:obj:`torch.Tensor`): The output image tensor with shape (..., C, H_, W_). If "shift" is \
                enabled, H_ and W_ are the target image shape indicated by `image_shape`; otherwise they equal H and W.

        .. note::
            Use torch.no_grad() to save cuda memory. Transformations are not trainable.
        """
        images = images.float() / 255. if images.dtype == torch.uint8 else images
        # Fold all leading dims into one batch dim so every transform sees a 4D (N, C, H, W) tensor.
        processed_images = images.reshape(-1, *images.shape[-3:])
        for transform in self.image_transforms:
            processed_images = transform(processed_images)
        # Cropping yields a non-contiguous tensor. `reshape` restores the leading dims and returns a view
        # whenever the strides allow it, without relying on view-compatibility.
        return processed_images.reshape(*images.shape[:-3], *processed_images.shape[1:])