# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def load_and_preprocess_images(
    folder_path: str, image_size: int = 224, mode: str = "bilinear"
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Load all images in a folder, center-crop them to squares, and resize.

    Returns:
        images_tensor: (N, 3, image_size, image_size) float tensor in [0, 1].
        image_info: dict with the square crop size, the per-image crop boxes
            (xyxy, in original image coordinates), and the resize scales.
    """
    image_paths = [
        os.path.join(folder_path, file)
        for file in os.listdir(folder_path)
        if file.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    image_paths.sort()

    if not image_paths:
        raise ValueError(f"no images found in {folder_path}")

    images = []
    bboxes_xyxy = []
    scales = []

    for path in image_paths:
        image = _load_image(path)
        image, bbox_xyxy, min_hw = _center_crop_square(image)
        minscale = image_size / min_hw

        imre = F.interpolate(
            torch.from_numpy(image)[None],
            size=(image_size, image_size),
            mode=mode,
            align_corners=False if mode == "bilinear" else None,
        )[0]

        images.append(imre.numpy())
        bboxes_xyxy.append(bbox_xyxy.numpy())
        scales.append(minscale)

    images_tensor = torch.from_numpy(np.stack(images))

    # assume all the images have the same shape for GGS
    image_info = {
        "size": (min_hw, min_hw),
        "bboxes_xyxy": np.stack(bboxes_xyxy),
        "resized_scales": np.stack(scales),
    }
    return images_tensor, image_info


# helper functions


def _load_image(path: str) -> np.ndarray:
    """Read an image as a (3, H, W) float32 array with values in [0, 1]."""
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def _center_crop_square(
    image: np.ndarray,
) -> Tuple[np.ndarray, torch.Tensor, int]:
    """Center-crop to a square; return the crop, its xyxy box, and its side."""
    h, w = image.shape[1:]
    min_dim = min(h, w)
    top = (h - min_dim) // 2
    left = (w - min_dim) // 2
    cropped_image = image[:, top : top + min_dim, left : left + min_dim]

    # bbox_xywh: the cropped region
    bbox_xywh = torch.tensor([left, top, min_dim, min_dim])
    # convert the format from xywh to xyxy
    bbox_xyxy = _clamp_box_to_image_bounds_and_round(
        _get_clamp_bbox(bbox_xywh, box_crop_context=0.0),
        image_size_hw=(h, w),
    )
    return cropped_image, bbox_xyxy, min_dim
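

# Illustrative sketch, not part of the original module: the metadata returned
# by load_and_preprocess_images is enough to map a point from the resized
# square crop back into the original image frame. Undo the resize scale, then
# add back the crop's top-left corner. The helper name `_uncrop_points` and
# its signature are assumptions made here for illustration only.
def _uncrop_points(
    points_xy: np.ndarray,
    bbox_xyxy: np.ndarray,
    resized_scale: float,
) -> np.ndarray:
    # points_xy: (N, 2) xy coordinates inside the resized square crop
    # bbox_xyxy: (4,) crop box in original image coordinates
    # resized_scale: image_size / min(h, w), as stored in "resized_scales"
    return points_xy / resized_scale + bbox_xyxy[:2]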


def _get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
) -> torch.Tensor:
    """Optionally expand an xywh box by `box_crop_context`, then convert it to xyxy."""
    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError("squashed image: the bounding box contains no pixels.")

    bbox[2:] = torch.clamp(bbox[2:], 2)  # set min height, width to 2 along both axes
    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
    return bbox_xyxy


def _bbox_xywh_to_xyxy(
    xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
    """Convert an xywh box to xyxy, optionally clamping w/h to a minimum size."""
    xyxy = xywh.clone()
    if clamp_size is not None:
        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
    xyxy[2:] += xyxy[:2]
    return xyxy


def _clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
    """Clamp an xyxy box to the image bounds and round it to integer coordinates."""
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]


if __name__ == "__main__":
    # Example usage:
    folder_path = "path/to/your/folder"
    image_size = 224
    images_tensor, image_info = load_and_preprocess_images(folder_path, image_size)
    print(images_tensor.shape)
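
    # Illustrative only: map the crop center back into the original image
    # frame with the hypothetical _uncrop_points sketch defined above.
    center = np.array([[image_size / 2.0, image_size / 2.0]])
    print(
        _uncrop_points(
            center,
            image_info["bboxes_xyxy"][0],
            image_info["resized_scales"][0],
        )
    )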