# NOTE(review): removed terminal/notebook residue ("Spaces: / Running / Running")
# that was accidentally captured at the top of this file and is not valid Python.
from typing import Dict

import numpy as np
import torch

import kornia.augmentation as K
from kornia.geometry.transform import warp_perspective

# Adapted from Kornia
class GeometricSequential:
    """Compose random geometric augmentations and track the combined homography.

    Each transform in ``transforms`` is expected to be a Kornia-style
    augmentation exposing ``p`` (apply probability), ``generate_parameters``
    and ``compute_transformation``. Calling the sequence returns both the
    warped batch and the accumulated 3x3 perspective matrix, so the exact same
    warp can be replayed on other tensors via :meth:`apply_transform`.
    """

    def __init__(self, *transforms, align_corners=True) -> None:
        """Store the transforms and the ``align_corners`` flag used for warping."""
        self.transforms = transforms
        self.align_corners = align_corners

    def __call__(self, x, mode="bilinear"):
        """Randomly apply the transform sequence to a batch.

        Args:
            x: image batch of shape (B, C, H, W).
            mode: interpolation mode forwarded to ``warp_perspective``.

        Returns:
            Tuple of (warped batch, accumulated (B, 3, 3) homography ``M``).
        """
        b, c, h, w = x.shape
        # Start from the identity homography. Bug fix: match x's dtype as well
        # as its device — the original only matched the device, so a half or
        # double input would be multiplied against a float32 identity.
        M = torch.eye(3, device=x.device, dtype=x.dtype)[None].expand(b, 3, 3)
        for t in self.transforms:
            # Each transform fires independently with probability t.p for the
            # whole batch at once (not per sample).
            if np.random.rand() < t.p:
                M = M.matmul(
                    t.compute_transformation(x, t.generate_parameters((b, c, h, w)))
                )
        return (
            warp_perspective(
                x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
            ),
            M,
        )

    def apply_transform(self, x, M, mode="bilinear"):
        """Warp ``x`` with a previously accumulated homography ``M``.

        Useful for replaying the exact transformation returned by ``__call__``
        on a second tensor (e.g. masks or keypoint heatmaps).
        """
        b, c, h, w = x.shape
        return warp_perspective(
            x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
        )
class RandomPerspective(K.RandomPerspective):
    """Kornia ``RandomPerspective`` with a self-contained parameter generator.

    Overrides ``generate_parameters`` so the perspective control points are
    produced by the local :meth:`random_perspective_generator` instead of the
    generator bundled with the installed Kornia release.
    """

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        """Sample perspective parameters for a batch of shape (B, C, H, W)."""
        distortion_scale = torch.as_tensor(
            self.distortion_scale, device=self._device, dtype=self._dtype
        )
        return self.random_perspective_generator(
            batch_shape[0],
            batch_shape[-2],
            batch_shape[-1],
            distortion_scale,
            self.same_on_batch,
            self.device,
            self.dtype,
        )

    def random_perspective_generator(
        self,
        batch_size: int,
        height: int,
        width: int,
        distortion_scale: torch.Tensor,
        same_on_batch: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> Dict[str, torch.Tensor]:
        r"""Get parameters for ``perspective`` for a random perspective transform.

        Args:
            batch_size (int): the tensor batch size.
            height (int): height of the image.
            width (int): width of the image.
            distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
            same_on_batch (bool): apply the same transformation across the batch. Default: False.
            device (torch.device): the device on which the random numbers will be generated. Default: cpu.
            dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

        Returns:
            params Dict[str, torch.Tensor]: parameters to be passed for transformation.
                - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
                - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).

        Note:
            The generated random numbers are not reproducible across different devices and dtypes.
        """
        if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
            raise AssertionError(
                f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
            )
        if not (
            type(height) is int and height > 0 and type(width) is int and width > 0
        ):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )
        # Image corners ordered TL, TR, BR, BL; shared (via expand) across the batch.
        start_points: torch.Tensor = torch.tensor(
            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
            device=distortion_scale.device,
            dtype=distortion_scale.dtype,
        ).expand(batch_size, -1, -1)
        # NOTE(review): the random draws inherit distortion_scale's
        # device/dtype; the `device`/`dtype` arguments are not applied here —
        # confirm whether any caller relies on them.
        # Maximum corner displacement: half the image extent, scaled by the
        # distortion factor.
        fx = distortion_scale * width / 2
        fy = distortion_scale * height / 2
        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
        # Uniform offsets in [-1, 1). Bug fix: honour `same_on_batch`, which
        # was previously accepted but silently ignored — every sample always
        # received independent offsets. When set, one offset pattern is drawn
        # and broadcast over the batch.
        if same_on_batch:
            offset = ((torch.rand_like(start_points[:1]) - 0.5) * 2).expand(
                batch_size, -1, -1
            )
        else:
            offset = (torch.rand_like(start_points) - 0.5) * 2
        end_points = start_points + factor * offset
        return dict(start_points=start_points, end_points=end_points)