|
from PIL import Image |
|
import matplotlib |
|
import numpy as np |
|
|
|
from PIL import Image |
|
|
|
import torch |
|
from torchvision.transforms import InterpolationMode |
|
from torchvision.transforms.functional import resize |
|
|
|
def concatenate_images(*image_lists):
    """Stitch rows of PIL images into a single RGB image.

    Each positional argument is one row. Images within a row are placed left
    to right, and rows are stacked top to bottom, left-aligned. The canvas is
    as wide as the widest row and as tall as the sum of the row heights. Any
    area not covered by an image stays black.

    Args:
        *image_lists: Sequences of ``PIL.Image.Image`` objects, one sequence
            per row. Empty rows are skipped.

    Returns:
        PIL.Image.Image: A new ``'RGB'`` image containing all rows.

    Raises:
        ValueError: If no rows are given or every row is empty.
    """
    # Drop empty rows up front so that per-row sizes stay aligned with the
    # rows actually pasted.
    rows = [row for row in image_lists if row]
    if not rows:
        raise ValueError("At least one non-empty image list must be provided")

    # A row is as tall as its tallest image, so no image gets cropped or
    # overwritten by the next row.
    row_heights = [max(img.height for img in row) for row in rows]
    max_width = max(sum(img.width for img in row) for row in rows)

    new_image = Image.new('RGB', (max_width, sum(row_heights)))

    y_offset = 0
    for row, row_height in zip(rows, row_heights):
        x_offset = 0
        for img in row:
            new_image.paste(img, (x_offset, y_offset))
            x_offset += img.width
        y_offset += row_height

    return new_image
|
|
|
|
|
def colorize_depth_map(depth, mask=None):
    """Render a depth map as an RGB image using matplotlib's "Spectral" colormap.

    Args:
        depth: 2-D array-like of depth values (anything ``np.asarray``
            accepts, e.g. a numpy array or a CPU tensor). Values are min-max
            normalized to [0, 1] before color mapping.
        mask: Optional array-like of the same height/width as ``depth`` (numpy
            array or CPU tensor), interpreted as booleans. Pixels where it is
            False are set to black.

    Returns:
        PIL.Image.Image: A ``uint8`` RGB image with the same height/width as
        ``depth``.
    """
    cm = matplotlib.colormaps["Spectral"]

    depth = np.asarray(depth)
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        normalized = (depth - depth_min) / depth_range
    else:
        # Constant depth: 0/0 would yield NaN, which the colormap renders as
        # black. Map every pixel to the low end of the colormap instead.
        normalized = np.zeros(depth.shape, dtype=np.float64)

    # The colormap returns RGBA floats in [0, 1]; keep RGB and scale to bytes.
    depth_colored = (cm(normalized, bytes=False)[..., :3] * 255).astype(np.uint8)

    if mask is not None:
        # np.asarray accepts both numpy masks and CPU tensors; casting to bool
        # guarantees boolean (not integer fancy) indexing.
        mask = np.asarray(mask, dtype=bool)
        depth_colored[~mask] = 0

    return Image.fromarray(depth_colored)
|
|
|
|
|
def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image so that its longer edge equals ``max_edge_resolution``, keeping aspect ratio.

    The image is always rescaled, so inputs smaller than
    ``max_edge_resolution`` are upscaled.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution (`int`):
            Target length (pixels) of the longer edge. Must be positive.
        resample_method (`torchvision.transforms.InterpolationMode`):
            Interpolation mode passed to
            ``torchvision.transforms.functional.resize``.

    Returns:
        `torch.Tensor`: Resized image of shape [B, C, new_H, new_W].

    Raises:
        AssertionError: If ``img`` is not 4-dimensional.
        ValueError: If ``max_edge_resolution`` is not positive.
    """
    assert 4 == img.dim(), f"Invalid input shape {img.shape}"
    if max_edge_resolution <= 0:
        raise ValueError(
            f"max_edge_resolution must be positive, got {max_edge_resolution}"
        )

    original_height, original_width = img.shape[-2:]
    max_edge = int(max_edge_resolution)

    # Integer arithmetic instead of int(w * (m / w)): the float round-trip can
    # land one pixel short (e.g. 1 / 49 * 49 == 0.9999999999999999).
    if original_width >= original_height:
        new_width = max_edge
        new_height = original_height * max_edge // original_width
    else:
        new_height = max_edge
        new_width = original_width * max_edge // original_height

    # Very elongated inputs could otherwise truncate an edge to zero pixels.
    new_height = max(new_height, 1)
    new_width = max(new_width, 1)

    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img
|
|
|
|
|
def get_tv_resample_method(method_str: str) -> InterpolationMode:
    """Map a resampling-method name to a torchvision ``InterpolationMode``.

    Args:
        method_str: One of ``"bilinear"``, ``"bicubic"``, ``"nearest"`` or
            ``"nearest-exact"``. Both nearest-neighbour names map to
            ``InterpolationMode.NEAREST_EXACT``.

    Returns:
        The matching ``InterpolationMode``.

    Raises:
        ValueError: If ``method_str`` is not one of the names above.
    """
    resample_method_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
        "nearest-exact": InterpolationMode.NEAREST_EXACT,
    }
    try:
        return resample_method_dict[method_str]
    except KeyError:
        # Report the caller's input rather than the (always None) lookup result.
        raise ValueError(f"Unknown resampling method: {method_str}") from None
|
|