PKaushik commited on
Commit
7dc2bfe
1 Parent(s): 6d7be64
Files changed (1) hide show
  1. yolov6/utils/nms.py +106 -0
yolov6/utils/nms.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # The code is based on
4
+ # https://github.com/ultralytics/yolov5/blob/master/utils/general.py
5
+
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import cv2
10
+ import torch
11
+ import torchvision
12
+
13
+
14
+ # Settings
15
+ torch.set_printoptions(linewidth=320, precision=5, profile='long')
16
+ np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
17
+ cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
18
+ os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
19
+
20
+
21
+ def xywh2xyxy(x):
22
+ # Convert boxes with shape [n, 4] from [x, y, w, h] to [x1, y1, x2, y2] where x1y1 is top-left, x2y2=bottom-right
23
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
24
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
25
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
26
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
27
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
28
+ return y
29
+
30
+
31
+ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, max_det=300):
32
+ """Runs Non-Maximum Suppression (NMS) on inference results.
33
+ This code is borrowed from: https://github.com/ultralytics/yolov5/blob/47233e1698b89fc437a4fb9463c815e9171be955/utils/general.py#L775
34
+ Args:
35
+ prediction: (tensor), with shape [N, 5 + num_classes], N is the number of bboxes.
36
+ conf_thres: (float) confidence threshold.
37
+ iou_thres: (float) iou threshold.
38
+ classes: (None or list[int]), if a list is provided, nms only keep the classes you provide.
39
+ agnostic: (bool), when it is set to True, we do class-independent nms, otherwise, different class would do nms respectively.
40
+ multi_label: (bool), when it is set to True, one box can have multi labels, otherwise, one box only huave one label.
41
+ max_det:(int), max number of output bboxes.
42
+
43
+ Returns:
44
+ list of detections, echo item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls].
45
+ """
46
+
47
+ num_classes = prediction.shape[2] - 5 # number of classes
48
+ pred_candidates = prediction[..., 4] > conf_thres # candidates
49
+
50
+ # Check the parameters.
51
+ assert 0 <= conf_thres <= 1, f'conf_thresh must be in 0.0 to 1.0, however {conf_thres} is provided.'
52
+ assert 0 <= iou_thres <= 1, f'iou_thres must be in 0.0 to 1.0, however {iou_thres} is provided.'
53
+
54
+ # Function settings.
55
+ max_wh = 4096 # maximum box width and height
56
+ max_nms = 30000 # maximum number of boxes put into torchvision.ops.nms()
57
+ time_limit = 10.0 # quit the function when nms cost time exceed the limit time.
58
+ multi_label &= num_classes > 1 # multiple labels per box
59
+
60
+ tik = time.time()
61
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
62
+ for img_idx, x in enumerate(prediction): # image index, image inference
63
+ x = x[pred_candidates[img_idx]] # confidence
64
+
65
+ # If no box remains, skip the next process.
66
+ if not x.shape[0]:
67
+ continue
68
+
69
+ # confidence multiply the objectness
70
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
71
+
72
+ # (center x, center y, width, height) to (x1, y1, x2, y2)
73
+ box = xywh2xyxy(x[:, :4])
74
+
75
+ # Detections matrix's shape is (n,6), each row represents (xyxy, conf, cls)
76
+ if multi_label:
77
+ box_idx, class_idx = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
78
+ x = torch.cat((box[box_idx], x[box_idx, class_idx + 5, None], class_idx[:, None].float()), 1)
79
+ else: # Only keep the class with highest scores.
80
+ conf, class_idx = x[:, 5:].max(1, keepdim=True)
81
+ x = torch.cat((box, conf, class_idx.float()), 1)[conf.view(-1) > conf_thres]
82
+
83
+ # Filter by class, only keep boxes whose category is in classes.
84
+ if classes is not None:
85
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
86
+
87
+ # Check shape
88
+ num_box = x.shape[0] # number of boxes
89
+ if not num_box: # no boxes kept.
90
+ continue
91
+ elif num_box > max_nms: # excess max boxes' number.
92
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
93
+
94
+ # Batched NMS
95
+ class_offset = x[:, 5:6] * (0 if agnostic else max_wh) # classes
96
+ boxes, scores = x[:, :4] + class_offset, x[:, 4] # boxes (offset by class), scores
97
+ keep_box_idx = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
98
+ if keep_box_idx.shape[0] > max_det: # limit detections
99
+ keep_box_idx = keep_box_idx[:max_det]
100
+
101
+ output[img_idx] = x[keep_box_idx]
102
+ if (time.time() - tik) > time_limit:
103
+ print(f'WARNING: NMS cost time exceed the limited {time_limit}s.')
104
+ break # time limit exceeded
105
+
106
+ return output