Spaces:
Build error
Build error
commit
Browse files- inferer.py +238 -0
inferer.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import math
|
4 |
+
import os.path as osp
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
from PIL import Image, ImageFont
|
11 |
+
|
12 |
+
from yolov6.data.data_augment import letterbox
|
13 |
+
from yolov6.layers.common import DetectBackend
|
14 |
+
from yolov6.utils.events import LOGGER, load_yaml
|
15 |
+
from yolov6.utils.nms import non_max_suppression
|
16 |
+
|
17 |
+
|
18 |
+
class Inferer:
|
19 |
+
def __init__(self, model_id, device="cpu", yaml="coco.yaml", img_size=640, half=False):
|
20 |
+
self.__dict__.update(locals())
|
21 |
+
|
22 |
+
# Init model
|
23 |
+
self.img_size = img_size
|
24 |
+
cuda = device != "cpu" and torch.cuda.is_available()
|
25 |
+
self.device = torch.device("cuda:0" if cuda else "cpu")
|
26 |
+
self.model = DetectBackend(hf_hub_download(model_id, "model.pt"), device=self.device)
|
27 |
+
self.stride = self.model.stride
|
28 |
+
self.class_names = load_yaml(yaml)["names"]
|
29 |
+
self.img_size = self.check_img_size(self.img_size, s=self.stride) # check image size
|
30 |
+
|
31 |
+
# Half precision
|
32 |
+
if half & (self.device.type != "cpu"):
|
33 |
+
self.model.model.half()
|
34 |
+
else:
|
35 |
+
self.model.model.float()
|
36 |
+
half = False
|
37 |
+
|
38 |
+
if self.device.type != "cpu":
|
39 |
+
self.model(
|
40 |
+
torch.zeros(1, 3, *self.img_size).to(self.device).type_as(next(self.model.model.parameters()))
|
41 |
+
) # warmup
|
42 |
+
|
43 |
+
# Switch model to deploy status
|
44 |
+
self.model_switch(self.model, self.img_size)
|
45 |
+
|
46 |
+
def model_switch(self, model, img_size):
|
47 |
+
"""Model switch to deploy status"""
|
48 |
+
from yolov6.layers.common import RepVGGBlock
|
49 |
+
|
50 |
+
for layer in model.modules():
|
51 |
+
if isinstance(layer, RepVGGBlock):
|
52 |
+
layer.switch_to_deploy()
|
53 |
+
|
54 |
+
LOGGER.info("Switch model to deploy modality.")
|
55 |
+
|
56 |
+
def __call__(
|
57 |
+
self,
|
58 |
+
path_or_image,
|
59 |
+
conf_thres=0.25,
|
60 |
+
iou_thres=0.45,
|
61 |
+
classes=None,
|
62 |
+
agnostic_nms=False,
|
63 |
+
max_det=1000,
|
64 |
+
hide_labels=False,
|
65 |
+
hide_conf=False,
|
66 |
+
):
|
67 |
+
"""Model Inference and results visualization"""
|
68 |
+
|
69 |
+
img, img_src = self.precess_image(path_or_image, self.img_size, self.stride, self.half)
|
70 |
+
img = img.to(self.device)
|
71 |
+
if len(img.shape) == 3:
|
72 |
+
img = img[None]
|
73 |
+
# expand for batch dim
|
74 |
+
pred_results = self.model(img)
|
75 |
+
det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
|
76 |
+
|
77 |
+
gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]] # normalization gain whwh
|
78 |
+
img_ori = img_src
|
79 |
+
|
80 |
+
# check image and font
|
81 |
+
assert (
|
82 |
+
img_ori.data.contiguous
|
83 |
+
), "Image needs to be contiguous. Please apply to input images with np.ascontiguousarray(im)."
|
84 |
+
self.font_check()
|
85 |
+
|
86 |
+
if len(det):
|
87 |
+
det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
|
88 |
+
|
89 |
+
for *xyxy, conf, cls in reversed(det):
|
90 |
+
class_num = int(cls) # integer class
|
91 |
+
label = (
|
92 |
+
None
|
93 |
+
if hide_labels
|
94 |
+
else (self.class_names[class_num] if hide_conf else f"{self.class_names[class_num]} {conf:.2f}")
|
95 |
+
)
|
96 |
+
|
97 |
+
self.plot_box_and_label(
|
98 |
+
img_ori,
|
99 |
+
max(round(sum(img_ori.shape) / 2 * 0.003), 2),
|
100 |
+
xyxy,
|
101 |
+
label,
|
102 |
+
color=self.generate_colors(class_num, True),
|
103 |
+
)
|
104 |
+
|
105 |
+
img_src = np.asarray(img_ori)
|
106 |
+
|
107 |
+
return img_src
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
def precess_image(path_or_image, img_size, stride, half):
|
111 |
+
"""Process image before image inference."""
|
112 |
+
if isinstance(path_or_image, str):
|
113 |
+
try:
|
114 |
+
img_src = cv2.imread(path_or_image)
|
115 |
+
assert img_src is not None, f"Invalid image: {path_or_image}"
|
116 |
+
except Exception as e:
|
117 |
+
LOGGER.warning(e)
|
118 |
+
elif isinstance(path_or_image, np.ndarray):
|
119 |
+
img_src = path_or_image
|
120 |
+
elif isinstance(path_or_image, Image.Image):
|
121 |
+
img_src = np.array(path_or_image)
|
122 |
+
|
123 |
+
image = letterbox(img_src, img_size, stride=stride)[0]
|
124 |
+
|
125 |
+
# Convert
|
126 |
+
image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
127 |
+
image = torch.from_numpy(np.ascontiguousarray(image))
|
128 |
+
image = image.half() if half else image.float() # uint8 to fp16/32
|
129 |
+
image /= 255 # 0 - 255 to 0.0 - 1.0
|
130 |
+
|
131 |
+
return image, img_src
|
132 |
+
|
133 |
+
@staticmethod
|
134 |
+
def rescale(ori_shape, boxes, target_shape):
|
135 |
+
"""Rescale the output to the original image shape"""
|
136 |
+
ratio = min(ori_shape[0] / target_shape[0], ori_shape[1] / target_shape[1])
|
137 |
+
padding = (ori_shape[1] - target_shape[1] * ratio) / 2, (ori_shape[0] - target_shape[0] * ratio) / 2
|
138 |
+
|
139 |
+
boxes[:, [0, 2]] -= padding[0]
|
140 |
+
boxes[:, [1, 3]] -= padding[1]
|
141 |
+
boxes[:, :4] /= ratio
|
142 |
+
|
143 |
+
boxes[:, 0].clamp_(0, target_shape[1]) # x1
|
144 |
+
boxes[:, 1].clamp_(0, target_shape[0]) # y1
|
145 |
+
boxes[:, 2].clamp_(0, target_shape[1]) # x2
|
146 |
+
boxes[:, 3].clamp_(0, target_shape[0]) # y2
|
147 |
+
|
148 |
+
return boxes
|
149 |
+
|
150 |
+
def check_img_size(self, img_size, s=32, floor=0):
|
151 |
+
"""Make sure image size is a multiple of stride s in each dimension, and return a new shape list of image."""
|
152 |
+
if isinstance(img_size, int): # integer i.e. img_size=640
|
153 |
+
new_size = max(self.make_divisible(img_size, int(s)), floor)
|
154 |
+
elif isinstance(img_size, list): # list i.e. img_size=[640, 480]
|
155 |
+
new_size = [max(self.make_divisible(x, int(s)), floor) for x in img_size]
|
156 |
+
else:
|
157 |
+
raise Exception(f"Unsupported type of img_size: {type(img_size)}")
|
158 |
+
|
159 |
+
if new_size != img_size:
|
160 |
+
print(f"WARNING: --img-size {img_size} must be multiple of max stride {s}, updating to {new_size}")
|
161 |
+
return new_size if isinstance(img_size, list) else [new_size] * 2
|
162 |
+
|
163 |
+
def make_divisible(self, x, divisor):
|
164 |
+
# Upward revision the value x to make it evenly divisible by the divisor.
|
165 |
+
return math.ceil(x / divisor) * divisor
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def plot_box_and_label(image, lw, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)):
|
169 |
+
# Add one xyxy box to image with label
|
170 |
+
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
171 |
+
cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
|
172 |
+
if label:
|
173 |
+
tf = max(lw - 1, 1) # font thickness
|
174 |
+
w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0] # text width, height
|
175 |
+
outside = p1[1] - h - 3 >= 0 # label fits outside box
|
176 |
+
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
177 |
+
cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
|
178 |
+
cv2.putText(
|
179 |
+
image,
|
180 |
+
label,
|
181 |
+
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
182 |
+
0,
|
183 |
+
lw / 3,
|
184 |
+
txt_color,
|
185 |
+
thickness=tf,
|
186 |
+
lineType=cv2.LINE_AA,
|
187 |
+
)
|
188 |
+
|
189 |
+
@staticmethod
|
190 |
+
def font_check(font="./yolov6/utils/Arial.ttf", size=10):
|
191 |
+
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
192 |
+
assert osp.exists(font), f"font path not exists: {font}"
|
193 |
+
try:
|
194 |
+
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
195 |
+
except Exception as e: # download if missing
|
196 |
+
return ImageFont.truetype(str(font), size)
|
197 |
+
|
198 |
+
@staticmethod
|
199 |
+
def box_convert(x):
|
200 |
+
# Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right
|
201 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
202 |
+
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
203 |
+
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
204 |
+
y[:, 2] = x[:, 2] - x[:, 0] # width
|
205 |
+
y[:, 3] = x[:, 3] - x[:, 1] # height
|
206 |
+
return y
|
207 |
+
|
208 |
+
@staticmethod
|
209 |
+
def generate_colors(i, bgr=False):
|
210 |
+
hex = (
|
211 |
+
"FF3838",
|
212 |
+
"FF9D97",
|
213 |
+
"FF701F",
|
214 |
+
"FFB21D",
|
215 |
+
"CFD231",
|
216 |
+
"48F90A",
|
217 |
+
"92CC17",
|
218 |
+
"3DDB86",
|
219 |
+
"1A9334",
|
220 |
+
"00D4BB",
|
221 |
+
"2C99A8",
|
222 |
+
"00C2FF",
|
223 |
+
"344593",
|
224 |
+
"6473FF",
|
225 |
+
"0018EC",
|
226 |
+
"8438FF",
|
227 |
+
"520085",
|
228 |
+
"CB38FF",
|
229 |
+
"FF95C8",
|
230 |
+
"FF37C7",
|
231 |
+
)
|
232 |
+
palette = []
|
233 |
+
for iter in hex:
|
234 |
+
h = "#" + iter
|
235 |
+
palette.append(tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4)))
|
236 |
+
num = len(palette)
|
237 |
+
color = palette[int(i) % num]
|
238 |
+
return (color[2], color[1], color[0]) if bgr else color
|