PKaushik committed on
Commit
a6aa138
1 Parent(s): e31779b
Files changed (1)
  1. yolov6/core/inferer.py +206 -0
yolov6/core/inferer.py ADDED
@@ -0,0 +1,206 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+ import os
+ import os.path as osp
+ import math
+ from tqdm import tqdm
+ import numpy as np
+ import cv2
+ import torch
+ from PIL import ImageFont
+
+ from yolov6.utils.events import LOGGER, load_yaml
+ from yolov6.layers.common import DetectBackend
+ from yolov6.data.data_augment import letterbox
+ from yolov6.utils.nms import non_max_suppression
+ from yolov6.utils.torch_utils import get_model_info
+
+
+ class Inferer:
+     def __init__(self, source, weights, device, yaml, img_size, half):
+         import glob
+         from yolov6.data.datasets import IMG_FORMATS
+
+         self.__dict__.update(locals())  # store all constructor arguments as attributes
+
+         # Init model
+         self.device = device
+         self.img_size = img_size
+         cuda = self.device != 'cpu' and torch.cuda.is_available()
+         self.device = torch.device('cuda:0' if cuda else 'cpu')
+         self.model = DetectBackend(weights, device=self.device)
+         self.stride = self.model.stride
+         self.class_names = load_yaml(yaml)['names']
+         self.img_size = self.check_img_size(self.img_size, s=self.stride)  # check image size
+
+         # Half precision
+         if half and self.device.type != 'cpu':
+             self.model.model.half()
+         else:
+             self.model.model.float()
+             half = False
+         self.half = half  # keep the attribute in sync after the CPU fallback above
+
+         if self.device.type != 'cpu':
+             self.model(torch.zeros(1, 3, *self.img_size).to(self.device).type_as(next(self.model.model.parameters())))  # warmup
+
+         # Load data
+         if os.path.isdir(source):
+             img_paths = sorted(glob.glob(os.path.join(source, '*.*')))  # dir
+         elif os.path.isfile(source):
+             img_paths = [source]  # single file
+         else:
+             raise Exception(f'Invalid path: {source}')
+         self.img_paths = [img_path for img_path in img_paths if img_path.split('.')[-1].lower() in IMG_FORMATS]
+
+         # Switch model to deploy status
+         self.model_switch(self.model, self.img_size)
+
+     def model_switch(self, model, img_size):
+         '''Switch the model to deploy mode (fuse RepVGGBlock branches).'''
+         from yolov6.layers.common import RepVGGBlock
+         for layer in model.modules():
+             if isinstance(layer, RepVGGBlock):
+                 layer.switch_to_deploy()
+
+         LOGGER.info("Switched model to deploy mode.")
+
+     def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf):
+         '''Run model inference and visualize the results.'''
+
+         for img_path in tqdm(self.img_paths):
+             img, img_src = self.process_image(img_path, self.img_size, self.stride, self.half)
+             img = img.to(self.device)
+             if len(img.shape) == 3:
+                 img = img[None]  # expand for batch dim
+             pred_results = self.model(img)
+             det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
+
+             save_path = osp.join(save_dir, osp.basename(img_path))  # im.jpg
+             txt_path = osp.join(save_dir, 'labels', osp.splitext(osp.basename(img_path))[0])
+
+             gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]]  # normalization gain whwh
+             img_ori = img_src
+
+             # check image and font
+             assert img_ori.data.contiguous, 'Image needs to be contiguous. Please apply np.ascontiguousarray(im) to the input images.'
+             self.font_check()
+
+             if len(det):
+                 det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
+
+                 for *xyxy, conf, cls in reversed(det):
+                     if save_txt:  # Write to file
+                         xywh = (self.box_convert(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
+                         line = (cls, *xywh, conf)
+                         with open(txt_path + '.txt', 'a') as f:
+                             f.write(('%g ' * len(line)).rstrip() % line + '\n')
+
+                     if save_img:
+                         class_num = int(cls)  # integer class
+                         label = None if hide_labels else (self.class_names[class_num] if hide_conf else f'{self.class_names[class_num]} {conf:.2f}')
+
+                         self.plot_box_and_label(img_ori, max(round(sum(img_ori.shape) / 2 * 0.003), 2), xyxy, label, color=self.generate_colors(class_num, True))
+
+                 img_src = np.asarray(img_ori)
+
+             # Save results (image with detections)
+             if save_img:
+                 cv2.imwrite(save_path, img_src)
+
+     @staticmethod
+     def process_image(path, img_size, stride, half):
+         '''Process an image before inference.'''
+         try:
+             img_src = cv2.imread(path)
+             assert img_src is not None, f'Invalid image: {path}'
+         except Exception as e:
+             LOGGER.warning(e)
+             raise  # img_src is undefined past this point, so re-raise instead of continuing
+         image = letterbox(img_src, img_size, stride=stride)[0]
+
+         # Convert
+         image = image.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
+         image = torch.from_numpy(np.ascontiguousarray(image))
+         image = image.half() if half else image.float()  # uint8 to fp16/32
+         image /= 255  # 0 - 255 to 0.0 - 1.0
+
+         return image, img_src
+
+     @staticmethod
+     def rescale(ori_shape, boxes, target_shape):
+         '''Rescale the output boxes to the original image shape (undo letterbox scaling and padding).'''
+         ratio = min(ori_shape[0] / target_shape[0], ori_shape[1] / target_shape[1])
+         padding = (ori_shape[1] - target_shape[1] * ratio) / 2, (ori_shape[0] - target_shape[0] * ratio) / 2
+
+         boxes[:, [0, 2]] -= padding[0]
+         boxes[:, [1, 3]] -= padding[1]
+         boxes[:, :4] /= ratio
+
+         boxes[:, 0].clamp_(0, target_shape[1])  # x1
+         boxes[:, 1].clamp_(0, target_shape[0])  # y1
+         boxes[:, 2].clamp_(0, target_shape[1])  # x2
+         boxes[:, 3].clamp_(0, target_shape[0])  # y2
+
+         return boxes
+
+     def check_img_size(self, img_size, s=32, floor=0):
+         """Make sure the image size is a multiple of stride s in each dimension; return the adjusted size as a list."""
+         if isinstance(img_size, int):  # integer, i.e. img_size=640
+             new_size = max(self.make_divisible(img_size, int(s)), floor)
+         elif isinstance(img_size, list):  # list, i.e. img_size=[640, 480]
+             new_size = [max(self.make_divisible(x, int(s)), floor) for x in img_size]
+         else:
+             raise Exception(f"Unsupported type of img_size: {type(img_size)}")
+
+         if new_size != img_size:
+             print(f'WARNING: --img-size {img_size} must be a multiple of max stride {s}, updating to {new_size}')
+         return new_size if isinstance(img_size, list) else [new_size] * 2
+
+     def make_divisible(self, x, divisor):
+         # Round x up to the nearest multiple of divisor.
+         return math.ceil(x / divisor) * divisor
+
+     @staticmethod
+     def plot_box_and_label(image, lw, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
+         # Add one xyxy box to the image, with an optional label
+         p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
+         cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
+         if label:
+             tf = max(lw - 1, 1)  # font thickness
+             w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
+             outside = p1[1] - h - 3 >= 0  # label fits outside box
+             p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
+             cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)  # filled
+             cv2.putText(image, label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), 0, lw / 3, txt_color,
+                         thickness=tf, lineType=cv2.LINE_AA)
+
+     @staticmethod
+     def font_check(font='./yolov6/utils/Arial.ttf', size=10):
+         # Return a PIL TrueType font loaded from the given path
+         # (the path is asserted to exist, so no fallback or download logic is needed)
+         assert osp.exists(font), f'font path does not exist: {font}'
+         return ImageFont.truetype(str(font), size)
+
+     @staticmethod
+     def box_convert(x):
+         # Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h], where x1y1 = top-left and x2y2 = bottom-right
+         y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+         y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
+         y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
+         y[:, 2] = x[:, 2] - x[:, 0]  # width
+         y[:, 3] = x[:, 3] - x[:, 1]  # height
+         return y
+
+     @staticmethod
+     def generate_colors(i, bgr=False):
+         # Deterministically map a class index to an RGB (or BGR) color from a fixed palette
+         hex_colors = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
+                       '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
+         palette = []
+         for hex_color in hex_colors:
+             h = '#' + hex_color
+             palette.append(tuple(int(h[1 + j:1 + j + 2], 16) for j in (0, 2, 4)))
+         num = len(palette)
+         color = palette[int(i) % num]
+         return (color[2], color[1], color[0]) if bgr else color
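
For reference, a minimal usage sketch of the class added above. The paths, checkpoint name, and threshold values are illustrative assumptions, not part of the commit; note also that infer() appends label files under save_dir/labels but never creates directories, so save_dir and save_dir/labels must exist beforehand.

    from yolov6.core.inferer import Inferer

    # All argument values below are hypothetical examples.
    inferer = Inferer(source='data/images',    # directory of images, or a single image file
                      weights='yolov6s.pt',    # checkpoint loaded via DetectBackend
                      device='0',              # anything other than 'cpu' selects cuda:0 when available
                      yaml='data/coco.yaml',   # dataset yaml providing the class 'names'
                      img_size=640,            # int or [h, w]; rounded up to a stride multiple
                      half=False)              # fp16 inference (GPU only)
    inferer.infer(conf_thres=0.25, iou_thres=0.45, classes=None, agnostic_nms=False,
                  max_det=1000, save_dir='runs/inference', save_txt=True, save_img=True,
                  hide_labels=False, hide_conf=False)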