Realcat
add: GIM (https://github.com/xuelunshen/gim)
4d4dd90
raw
history blame
4.12 kB
from typing import Dict
import numpy as np
import torch
import kornia.augmentation as K
from kornia.geometry.transform import warp_perspective
# Adapted from Kornia
class GeometricSequential:
    """Compose several geometric augmentations into one homography.

    Each transform ``t`` is included with probability ``t.p``. The chosen
    transformation matrices are multiplied onto an identity matrix, so the
    input is resampled only once by ``warp_perspective``.
    """

    def __init__(self, *transforms, align_corners=True) -> None:
        # Stored exactly as given (a tuple of transform objects).
        self.transforms = transforms
        self.align_corners = align_corners

    def __call__(self, x, mode="bilinear"):
        """Randomly compose the transforms and warp ``x`` with the result.

        Args:
            x: input batch, unpacked as ``(b, c, h, w)``.
            mode: interpolation mode forwarded to ``warp_perspective``.

        Returns:
            Tuple ``(warped, M)``: the warped batch and the composed
            ``(b, 3, 3)`` matrix that produced it.
        """
        batch, channels, height, width = x.shape
        homography = torch.eye(3, device=x.device).unsqueeze(0).expand(batch, 3, 3)
        for transform in self.transforms:
            # One draw from NumPy's global RNG per transform decides inclusion.
            if not np.random.rand() < transform.p:
                continue
            params = transform.generate_parameters((batch, channels, height, width))
            homography = homography.matmul(transform.compute_transformation(x, params))
        warped = warp_perspective(
            x,
            homography,
            dsize=(height, width),
            mode=mode,
            align_corners=self.align_corners,
        )
        return warped, homography

    def apply_transform(self, x, M, mode="bilinear"):
        """Warp ``x`` with a precomputed matrix ``M``, keeping its spatial size."""
        _, _, height, width = x.shape
        return warp_perspective(
            x,
            M,
            dsize=(height, width),
            align_corners=self.align_corners,
            mode=mode,
        )
class RandomPerspective(K.RandomPerspective):
    """Random perspective augmentation with symmetric corner offsets.

    Overrides parameter generation so that each image corner is displaced by a
    uniform offset in ``[-1, 1)`` multiplied by ``distortion_scale`` times half
    the image width/height. Corners can therefore move outward as well as
    inward.
    """

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        """Sample perspective parameters for a batch.

        Args:
            batch_shape: shape of the input batch. Only ``batch_shape[0]``
                (batch size) and the last two entries (height, width) are read.

        Returns:
            Dict with ``start_points`` and ``end_points`` as produced by
            :meth:`random_perspective_generator`.
        """
        # NOTE(review): `_device`/`_dtype` (used for the distortion scale) and
        # `device`/`dtype` (passed as the RNG device/dtype) are both inherited
        # from Kornia and are presumably distinct on purpose. Confirm against
        # the installed Kornia version before unifying them.
        distortion_scale = torch.as_tensor(
            self.distortion_scale, device=self._device, dtype=self._dtype
        )
        return self.random_perspective_generator(
            batch_shape[0],
            batch_shape[-2],
            batch_shape[-1],
            distortion_scale,
            self.same_on_batch,
            self.device,
            self.dtype,
        )

    def random_perspective_generator(
        self,
        batch_size: int,
        height: int,
        width: int,
        distortion_scale: torch.Tensor,
        same_on_batch: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> Dict[str, torch.Tensor]:
        r"""Get parameters for ``perspective`` for a random perspective transform.

        Args:
            batch_size (int): the tensor batch size.
            height (int): height of the image.
            width (int): width of the image.
            distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
            same_on_batch (bool): apply the same transformation across the batch. Default: False.
            device (torch.device): the device on which the random numbers will be generated. Default: cpu.
            dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

        Returns:
            params Dict[str, torch.Tensor]: parameters to be passed for transformation.
                - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
                - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).

        Raises:
            AssertionError: if ``distortion_scale`` is not a scalar in [0, 1], or
                ``height``/``width`` are not positive ints.

        Note:
            The generated random numbers are not reproducible across different devices and dtypes.
        """
        if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
            raise AssertionError(
                f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
            )
        if not (
            type(height) is int and height > 0 and type(width) is int and width > 0
        ):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )

        # Image corners (top-left, top-right, bottom-right, bottom-left) in
        # pixel coordinates, shared by every batch element.
        start_points: torch.Tensor = torch.tensor(
            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
            device=distortion_scale.device,
            dtype=distortion_scale.dtype,
        ).expand(batch_size, -1, -1)

        # generate random offset not larger than half of the image
        fx = distortion_scale * width / 2
        fy = distortion_scale * height / 2
        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)

        # Draw the noise on the requested RNG device/dtype, as documented, then
        # move it next to the corner coordinates. With `same_on_batch`, a
        # single draw is broadcast over the whole batch.
        noise_batch = 1 if same_on_batch else batch_size
        rand_val = torch.rand((noise_batch, 4, 2), device=device, dtype=dtype).to(
            device=distortion_scale.device, dtype=distortion_scale.dtype
        )
        offset = (rand_val - 0.5) * 2  # uniform in [-1, 1)
        end_points = start_points + factor * offset
        return dict(start_points=start_points, end_points=end_points)