# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from typing import (
    Any,
    ClassVar,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
def load_and_preprocess_images(
    folder_path: str, image_size: int = 224, mode: str = "bilinear"
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Load, center-crop to square, and resize every image in a folder.

    Args:
        folder_path: Directory containing ``.png``/``.jpg``/``.jpeg`` files.
        image_size: Side length of the square output images.
        mode: Interpolation mode passed to ``torch.nn.functional.interpolate``.

    Returns:
        A tuple ``(images_tensor, image_info)``: ``images_tensor`` has shape
        ``(N, 3, image_size, image_size)`` with values in ``[0, 1]``;
        ``image_info`` records the pre-resize crop size, the crop boxes
        (xyxy, per image) and the resize scales needed to map results back
        to the original images.  (The original annotation claimed a bare
        ``torch.Tensor`` return, which was wrong.)

    Raises:
        ValueError: If the folder contains no supported image files.
    """
    image_paths = [
        os.path.join(folder_path, file)
        for file in os.listdir(folder_path)
        if file.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    image_paths.sort()  # deterministic ordering across filesystems

    if not image_paths:
        # Without this guard, min_hw below would be unbound and np.stack
        # would fail with a confusing error for an empty folder.
        raise ValueError(f"No images found in {folder_path}")

    images = []
    bboxes_xyxy = []
    scales = []
    for path in image_paths:
        image = _load_image(path)
        image, bbox_xyxy, min_hw = _center_crop_square(image)
        minscale = image_size / min_hw
        imre = F.interpolate(
            torch.from_numpy(image)[None],
            size=(image_size, image_size),
            mode=mode,
            # align_corners is only a valid argument for (bi)linear modes
            align_corners=False if mode == "bilinear" else None,
        )[0]
        images.append(imre.numpy())
        bboxes_xyxy.append(bbox_xyxy.numpy())
        scales.append(minscale)

    images_tensor = torch.from_numpy(np.stack(images))
    # NOTE: min_hw here is from the *last* image; all the images are assumed
    # to have the same (cropped) shape for GGS.
    image_info = {
        "size": (min_hw, min_hw),
        "bboxes_xyxy": np.stack(bboxes_xyxy),
        "resized_scales": np.stack(scales),
    }
    return images_tensor, image_info
# helper functions | |
def _load_image(path) -> np.ndarray:
    """Read an image file as a float32 CHW array with values in [0, 1]."""
    with Image.open(path) as pil_im:
        rgb = np.array(pil_im.convert("RGB"))
    # HWC uint8 -> CHW float32 scaled into [0, 1]
    chw = np.transpose(rgb, (2, 0, 1))
    return chw.astype(np.float32) / 255.0
def _center_crop_square(
    image: np.ndarray,
) -> Tuple[np.ndarray, torch.Tensor, int]:
    """Center-crop a CHW image to a square of side ``min(H, W)``.

    Args:
        image: Array of shape ``(C, H, W)``.

    Returns:
        A tuple ``(cropped_image, bbox_xyxy, min_dim)``: the
        ``(C, min_dim, min_dim)`` crop, the crop region in the original image
        as an integer xyxy box, and ``min_dim = min(H, W)``.  (The original
        annotation claimed a bare ``np.ndarray`` return, which was wrong.)
    """
    h, w = image.shape[1:]
    min_dim = min(h, w)
    top = (h - min_dim) // 2
    left = (w - min_dim) // 2
    cropped_image = image[:, top : top + min_dim, left : left + min_dim]
    # bbox_xywh: the cropped region
    bbox_xywh = torch.tensor([left, top, min_dim, min_dim])
    # convert from xywh to xyxy, clamped/rounded to the image bounds
    bbox_xyxy = _clamp_box_to_image_bounds_and_round(
        _get_clamp_bbox(
            bbox_xywh,
            box_crop_context=0.0,
        ),
        image_size_hw=(h, w),
    )
    return cropped_image, bbox_xyxy, min_dim
def _get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
) -> torch.Tensor:
    """Optionally expand an xywh box and convert it to xyxy.

    Args:
        bbox: Box as ``[x, y, w, h]``.
        box_crop_context: Rate of expansion for the bbox; 0.0 keeps it as-is.

    Returns:
        The (possibly expanded) box in xyxy format; width and height are
        clamped to a minimum of 2 pixels before conversion.

    Raises:
        ValueError: If the box width or height is <= 1 pixel.
    """
    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        # The original raised an f-string with no placeholders; include the
        # offending box so the error is actionable.
        raise ValueError(
            f"squashed image!! The bounding box {bbox.tolist()} contains no pixels."
        )

    bbox[2:] = torch.clamp(
        bbox[2:], 2
    )  # set min height, width to 2 along both axes
    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
    return bbox_xyxy
def _bbox_xywh_to_xyxy( | |
xywh: torch.Tensor, clamp_size: Optional[int] = None | |
) -> torch.Tensor: | |
xyxy = xywh.clone() | |
if clamp_size is not None: | |
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) | |
xyxy[2:] += xyxy[:2] | |
return xyxy | |
def _clamp_box_to_image_bounds_and_round( | |
bbox_xyxy: torch.Tensor, | |
image_size_hw: Tuple[int, int], | |
) -> torch.LongTensor: | |
bbox_xyxy = bbox_xyxy.clone() | |
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) | |
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) | |
if not isinstance(bbox_xyxy, torch.LongTensor): | |
bbox_xyxy = bbox_xyxy.round().long() | |
return bbox_xyxy # pyre-ignore [7] | |
if __name__ == "__main__":
    # Example usage:
    folder_path = "path/to/your/folder"
    image_size = 224
    # load_and_preprocess_images returns (tensor, info); the original assigned
    # the whole tuple to images_tensor, so .shape raised AttributeError.
    images_tensor, image_info = load_and_preprocess_images(folder_path, image_size)
    print(images_tensor.shape)