Spaces:
Build error
Build error
commit
Browse files- yolov6/utils/nms.py +106 -0
yolov6/utils/nms.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# The code is based on
|
4 |
+
# https://github.com/ultralytics/yolov5/blob/master/utils/general.py
|
5 |
+
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import torchvision
|
12 |
+
|
13 |
+
|
14 |
+
# Settings
# NOTE: these statements run at import time and change process-wide state
# (tensor/array print formatting, OpenCV threading, NumExpr thread cap).
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 1 so that min() does not raise TypeError while importing this module.
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count() or 1, 8))  # NumExpr max threads
|
19 |
+
|
20 |
+
|
21 |
+
def xywh2xyxy(x):
    """Convert boxes from center format [x, y, w, h] to corner format [x1, y1, x2, y2].

    (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x: torch.Tensor or numpy.ndarray whose last dimension holds
            (center x, center y, width, height) in its first four entries.
            Any number of leading dimensions is accepted, e.g. [4], [n, 4]
            or [batch, n, 4].

    Returns:
        A new tensor/array with the same shape as ``x`` (a clone for tensors,
        a numpy copy otherwise); ``x`` itself is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Ellipsis indexing addresses the last axis regardless of how many leading
    # dimensions the input has, so single boxes and batched boxes both work.
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y
|
29 |
+
|
30 |
+
|
31 |
+
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    This code is borrowed from: https://github.com/ultralytics/yolov5/blob/47233e1698b89fc437a4fb9463c815e9171be955/utils/general.py#L775

    Args:
        prediction: (tensor), with shape [batch, N, 5 + num_classes], N is the number of bboxes per image.
            Each row is read as [center x, center y, width, height, objectness, class scores...].
        conf_thres: (float) confidence threshold, applied to the objectness and to the final score.
        iou_thres: (float) iou threshold.
        classes: (None or list[int]), if a list is provided, nms only keeps the classes you provide.
        agnostic: (bool), when it is set to True, we do class-independent nms, otherwise, each class is suppressed separately.
        multi_label: (bool), when it is set to True, one box can have multiple labels, otherwise, one box only has one label.
            It is ignored when there is only one class.
        max_det: (int), max number of output bboxes per image.

    Returns:
        list of detections with one entry per image; each item is one tensor with shape (num_boxes, 6),
        6 is for [xyxy, conf, cls]. Images without detections, and images not reached because the
        time limit was exceeded, get an empty tensor of shape (0, 6).

    Raises:
        AssertionError: if conf_thres or iou_thres is outside [0, 1].
    """

    # Check the parameters before they are used.
    assert 0 <= conf_thres <= 1, f'conf_thresh must be in 0.0 to 1.0, however {conf_thres} is provided.'
    assert 0 <= iou_thres <= 1, f'iou_thres must be in 0.0 to 1.0, however {iou_thres} is provided.'

    num_classes = prediction.shape[2] - 5  # number of classes
    pred_candidates = prediction[..., 4] > conf_thres  # candidates

    # Function settings.
    max_wh = 4096  # maximum box width and height
    max_nms = 30000  # maximum number of boxes put into torchvision.ops.nms()
    time_limit = 10.0  # quit the function when nms cost time exceed the limit time.
    multi_label &= num_classes > 1  # multiple labels per box

    # The class filter does not depend on the image, so build it once.
    class_filter = None if classes is None else torch.tensor(classes, device=prediction.device)

    tik = time.time()
    # One independent empty tensor per image (a list multiplication would alias a single tensor).
    output = [torch.zeros((0, 6), device=prediction.device) for _ in range(prediction.shape[0])]
    for img_idx, x in enumerate(prediction):  # image index, image inference
        x = x[pred_candidates[img_idx]]  # confidence; boolean-mask indexing yields a copy

        # If no box remains, skip the next process.
        if not x.shape[0]:
            continue

        # confidence multiply the objectness
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix's shape is (n,6), each row represents (xyxy, conf, cls)
        if multi_label:
            # One output row per (box, class) pair whose score passes the threshold.
            box_idx, class_idx = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[box_idx], x[box_idx, class_idx + 5, None], class_idx[:, None].float()), 1)
        else:  # Only keep the class with highest scores.
            conf, class_idx = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, class_idx.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class, only keep boxes whose category is in classes.
        if class_filter is not None:
            x = x[(x[:, 5:6] == class_filter).any(1)]

        # Check shape
        num_box = x.shape[0]  # number of boxes
        if not num_box:  # no boxes kept.
            continue
        elif num_box > max_nms:  # excess max boxes' number.
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Shifting each box by class_id * max_wh keeps boxes of different classes
        # from overlapping, so a single nms call suppresses per class.
        class_offset = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + class_offset, x[:, 4]  # boxes (offset by class), scores
        keep_box_idx = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if keep_box_idx.shape[0] > max_det:  # limit detections
            keep_box_idx = keep_box_idx[:max_det]

        output[img_idx] = x[keep_box_idx]
        if (time.time() - tik) > time_limit:
            print(f'WARNING: NMS cost time exceed the limited {time_limit}s.')
            break  # time limit exceeded

    return output
|