import collections.abc as collections
from pathlib import Path
from typing import Optional, Tuple

import cv2
import kornia
import numpy as np
import torch
from omegaconf import OmegaConf


class ImagePreprocessor:
    default_conf = {
        "resize": None,  # target edge length, None for no resizing
        "edge_divisible_by": None,
        "side": "long",
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
        "square_pad": False,
        "add_padding_mask": False,
    }

    def __init__(self, conf) -> None:
        super().__init__()
        default_conf = OmegaConf.create(self.default_conf)
        OmegaConf.set_struct(default_conf, True)
        self.conf = OmegaConf.merge(default_conf, conf)

    def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
        """Resize and preprocess an image, return image and resize scale"""
        h, w = img.shape[-2:]
        size = h, w
        if self.conf.resize is not None:
            if interpolation is None:
                interpolation = self.conf.interpolation
            size = self.get_new_image_size(h, w)
            img = kornia.geometry.transform.resize(
                img,
                size,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
                interpolation=interpolation,
            )
        # per-axis scale factors (x, y) from the original size to the resized size
        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        T = np.diag([scale[0], scale[1], 1])

        data = {
            "scales": scale,
            "image_size": np.array(size[::-1]),
            "transform": T,
            "original_image_size": np.array([w, h]),
        }
        if self.conf.square_pad:
            # zero-pad the (possibly resized) image to a square of side length sl
            sl = max(img.shape[-2:])
            data["image"] = torch.zeros(
                *img.shape[:-2], sl, sl, device=img.device, dtype=img.dtype
            )
            data["image"][:, : img.shape[-2], : img.shape[-1]] = img
            if self.conf.add_padding_mask:
                # boolean mask marking the valid (non-padded) pixels
                data["padding_mask"] = torch.zeros(
                    *img.shape[:-3], 1, sl, sl, device=img.device, dtype=torch.bool
                )
                data["padding_mask"][:, : img.shape[-2], : img.shape[-1]] = True
        else:
            data["image"] = img
        return data

    def load_image(self, image_path: Path) -> dict:
        return self(load_image(image_path))

    def get_new_image_size(
        self,
        h: int,
        w: int,
    ) -> Tuple[int, int]:
        side = self.conf.side
        # an explicit (height, width) pair is used as-is
        if isinstance(self.conf.resize, collections.Iterable):
            assert len(self.conf.resize) == 2
            return tuple(self.conf.resize)
        side_size = self.conf.resize
        aspect_ratio = w / h
        if side not in ("short", "long", "vert", "horz"):
            raise ValueError(
                f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
            )
        # resize the requested side to side_size while preserving the aspect ratio
        if side == "vert":
            size = side_size, int(side_size * aspect_ratio)
        elif side == "horz":
            size = int(side_size / aspect_ratio), side_size
        elif (side == "short") ^ (aspect_ratio < 1.0):
            size = side_size, int(side_size * aspect_ratio)
        else:
            size = int(side_size / aspect_ratio), side_size

        if self.conf.edge_divisible_by is not None:
            # round both edges down to a multiple of edge_divisible_by
            df = self.conf.edge_divisible_by
            size = list(map(lambda x: int(x // df * df), size))
        return size


def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Read an image from path as RGB or grayscale"""
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise IOError(f"Could not read image at {path}.")
    if not grayscale:
        image = image[..., ::-1]  # BGR (OpenCV default) to RGB
    return image


def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Normalize the image tensor and reorder the dimensions."""
    if image.ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f"Not an image: {image.shape}")
    return torch.tensor(image / 255.0, dtype=torch.float)


def load_image(path: Path, grayscale: bool = False) -> torch.Tensor:
    """Read an image from disk and convert it to a normalized float tensor."""
    image = read_image(path, grayscale=grayscale)
    return numpy_image_to_torch(image)
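

# Illustrative usage sketch (not part of the original module): it shows how the
# pieces above are typically combined -- load an image as a CxHxW float tensor,
# resize its long side to a fixed length, and read back the scale factors needed
# to map coordinates from the resized image to the original one. The path
# "assets/example.jpg" and the resize value 1024 are placeholders, not values
# taken from the source.
if __name__ == "__main__":
    preprocessor = ImagePreprocessor({"resize": 1024, "side": "long"})
    image = load_image(Path("assets/example.jpg"))  # hypothetical image path
    data = preprocessor(image)
    # data["image"] is the resized tensor; dividing keypoint coordinates detected
    # in it by data["scales"] maps them back to the original image resolution.
    print(data["image"].shape, data["scales"], data["original_image_size"])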