# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
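"""Utilities for loading and preprocessing a folder of images.

Each image is center-cropped to a square, resized to a fixed resolution, and
returned as part of a batched tensor together with the crop bounding boxes
and resize scales.
"""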
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from typing import Any, Dict, Optional, Tuple


def load_and_preprocess_images(
    folder_path: str, image_size: int = 224, mode: str = "bilinear"
) -> Tuple[torch.Tensor, Dict[str, Any]]:
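    """Load every .png/.jpg/.jpeg image in `folder_path`, center-crop each to
    a square, and resize it to (image_size, image_size).

    Returns:
        images_tensor: (N, 3, image_size, image_size) float tensor in [0, 1].
        image_info: dict with the square crop size ("size"), the per-image
            crop boxes in the original images' coordinates ("bboxes_xyxy"),
            and the per-image resize scales ("resized_scales").
    """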
image_paths = [
os.path.join(folder_path, file)
for file in os.listdir(folder_path)
if file.lower().endswith((".png", ".jpg", ".jpeg"))
]
    image_paths.sort()
    if not image_paths:
        raise ValueError(f"No images found in {folder_path}")
    images = []
bboxes_xyxy = []
scales = []
for path in image_paths:
        image = _load_image(path)
        # crop to a centered square; bbox_xyxy is the crop region in the
        # original image's pixel coordinates
        image, bbox_xyxy, min_hw = _center_crop_square(image)
        minscale = image_size / min_hw
imre = F.interpolate(
torch.from_numpy(image)[None],
size=(image_size, image_size),
mode=mode,
align_corners=False if mode == "bilinear" else None,
)[0]
images.append(imre.numpy())
bboxes_xyxy.append(bbox_xyxy.numpy())
scales.append(minscale)
images_tensor = torch.from_numpy(np.stack(images))
    # assume all the images have the same shape for GGS
    # (min_hw is taken from the last crop)
image_info = {
"size": (min_hw, min_hw),
"bboxes_xyxy": np.stack(bboxes_xyxy),
"resized_scales": np.stack(scales),
}
return images_tensor, image_info


# helper functions
def _load_image(path: str) -> np.ndarray:
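    """Read an image file as a (3, H, W) float32 RGB array scaled to [0, 1]."""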
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im


def _center_crop_square(
    image: np.ndarray,
) -> Tuple[np.ndarray, torch.Tensor, int]:
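    """Center-crop a (3, H, W) image to a square with side min(H, W).

    Returns the cropped image, the crop's xyxy box in the original image's
    coordinates, and the square side length.
    """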
h, w = image.shape[1:]
min_dim = min(h, w)
top = (h - min_dim) // 2
left = (w - min_dim) // 2
cropped_image = image[:, top : top + min_dim, left : left + min_dim]
    # bbox_xywh: the cropped region as (left, top, width, height)
    bbox_xywh = torch.tensor([left, top, min_dim, min_dim])
    # convert from xywh to xyxy and clamp to the original image bounds
bbox_xyxy = _clamp_box_to_image_bounds_and_round(
_get_clamp_bbox(
bbox_xywh,
box_crop_context=0.0,
),
image_size_hw=(h, w),
)
return cropped_image, bbox_xyxy, min_dim


def _get_clamp_bbox(
bbox: torch.Tensor,
box_crop_context: float = 0.0,
) -> torch.Tensor:
    """Optionally expand an xywh box by `box_crop_context` (rate of expansion)
    and return it converted to a possibly float-valued xyxy box.
    """
bbox = bbox.clone() # do not edit bbox in place
# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.float()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c
    if (bbox[2:] <= 1.0).any():
        raise ValueError("Squashed box: the bounding box contains no pixels.")
    # enforce a minimum box width/height of 2 pixels
    bbox[2:] = torch.clamp(bbox[2:], 2)
    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
return bbox_xyxy


def _bbox_xywh_to_xyxy(
xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
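    """Convert an xywh box to xyxy, optionally clamping the width/height to at
    least `clamp_size` first."""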
xyxy = xywh.clone()
if clamp_size is not None:
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
xyxy[2:] += xyxy[:2]
return xyxy


def _clamp_box_to_image_bounds_and_round(
bbox_xyxy: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
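    """Clamp an xyxy box to the (H, W) image bounds and round coordinates to
    integers."""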
bbox_xyxy = bbox_xyxy.clone()
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
if not isinstance(bbox_xyxy, torch.LongTensor):
bbox_xyxy = bbox_xyxy.round().long()
return bbox_xyxy # pyre-ignore [7]


if __name__ == "__main__":
    # Example usage: point folder_path at a directory of images.
    folder_path = "path/to/your/folder"
    image_size = 224
    images_tensor, image_info = load_and_preprocess_images(folder_path, image_size)
    print(images_tensor.shape)  # e.g. torch.Size([N, 3, 224, 224])