# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from typing import Any, Dict, Optional, Tuple


def load_and_preprocess_images(
    folder_path: str, image_size: int = 224, mode: str = "bilinear"
) -> Tuple[torch.Tensor, Dict[str, Any]]:
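    """Load every image in a folder, center-crop it to a square, and resize.

    Returns an (N, 3, image_size, image_size) float tensor with values in
    [0, 1], plus an image_info dict holding the pre-resize crop size, the
    crop boxes in xyxy format, and the scales used for resizing.
    """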
image_paths = [
os.path.join(folder_path, file)
for file in os.listdir(folder_path)
if file.lower().endswith((".png", ".jpg", ".jpeg"))
]
    image_paths.sort()
    if not image_paths:
        raise ValueError(f"no images found in {folder_path}")
images = []
bboxes_xyxy = []
scales = []
for path in image_paths:
image = _load_image(path)
image, bbox_xyxy, min_hw = _center_crop_square(image)
minscale = image_size / min_hw
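        # add a batch dim for F.interpolate, resize to a square of side
        # image_size, then drop the batch dim again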
imre = F.interpolate(
torch.from_numpy(image)[None],
size=(image_size, image_size),
mode=mode,
align_corners=False if mode == "bilinear" else None,
)[0]
images.append(imre.numpy())
bboxes_xyxy.append(bbox_xyxy.numpy())
scales.append(minscale)
images_tensor = torch.from_numpy(np.stack(images))
    # GGS (geometry-guided sampling) assumes all images share the same shape,
    # so the crop size from the last loop iteration stands in for all images
image_info = {
"size": (min_hw, min_hw),
"bboxes_xyxy": np.stack(bboxes_xyxy),
"resized_scales": np.stack(scales),
}
return images_tensor, image_info
# helper functions
def _load_image(path: str) -> np.ndarray:
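    """Read an image as a float32 CHW array with values scaled to [0, 1]."""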
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im
def _center_crop_square(
    image: np.ndarray,
) -> Tuple[np.ndarray, torch.Tensor, int]:
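    """Crop the largest centered square from a CHW image.

    Returns the cropped image, the crop box as an xyxy LongTensor, and the
    square's side length (the smaller of the image's height and width).
    """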
h, w = image.shape[1:]
min_dim = min(h, w)
top = (h - min_dim) // 2
left = (w - min_dim) // 2
cropped_image = image[:, top : top + min_dim, left : left + min_dim]
    # the cropped region as an xywh box: (left, top, width, height)
    bbox_xywh = torch.tensor([left, top, min_dim, min_dim])
    # convert xywh to xyxy (expansion is a no-op with box_crop_context=0.0)
    # and clamp the box to the image bounds
bbox_xyxy = _clamp_box_to_image_bounds_and_round(
_get_clamp_bbox(
bbox_xywh,
box_crop_context=0.0,
),
image_size_hw=(h, w),
)
return cropped_image, bbox_xyxy, min_dim
def _get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
) -> torch.Tensor:
    """Expand an xywh box by the rate box_crop_context and return the
    possibly expanded box in float xyxy format."""
    bbox = bbox.clone()  # do not edit bbox in place
# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.float()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c
    if (bbox[2:] <= 1.0).any():
        raise ValueError("squashed image: the bounding box contains no pixels")
    # enforce a minimum box height and width of 2 pixels
    bbox[2:] = torch.clamp(bbox[2:], min=2)
bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
return bbox_xyxy
def _bbox_xywh_to_xyxy(
xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
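    """Convert an xywh box to xyxy, optionally enforcing a minimum box size."""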
xyxy = xywh.clone()
if clamp_size is not None:
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
xyxy[2:] += xyxy[:2]
return xyxy
def _clamp_box_to_image_bounds_and_round(
bbox_xyxy: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
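    """Clamp an xyxy box to the image bounds and round to integer pixels."""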
bbox_xyxy = bbox_xyxy.clone()
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if bbox_xyxy.dtype != torch.long:
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]
if __name__ == "__main__":
    # Example usage:
    folder_path = "path/to/your/folder"
    image_size = 224
    images_tensor, image_info = load_and_preprocess_images(folder_path, image_size)
    print(images_tensor.shape)
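
    # A minimal sketch (an assumption, not part of the original pipeline) of
    # how image_info can map a point predicted in the resized image_size
    # frame back to original image coordinates: divide by the resize scale,
    # then offset by the crop's top-left corner.
    point_resized = np.array([112.0, 112.0])  # hypothetical predicted point
    scale = image_info["resized_scales"][0]
    left, top = image_info["bboxes_xyxy"][0][:2]
    point_original = point_resized / scale + np.array([left, top])
    print(point_original)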