Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
# The code is based on | |
# https://github.com/ultralytics/yolov5/blob/master/utils/general.py | |
import os | |
import time | |
import numpy as np | |
import cv2 | |
import torch | |
import torchvision | |
# Settings
# Wide, fixed-precision printing for tensors and arrays.
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to 1
# so that min() does not raise TypeError at import time.
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count() or 1, 8))  # NumExpr max threads
def xywh2xyxy(x):
    """Convert boxes from [x, y, w, h] to [x1, y1, x2, y2].

    (x, y) is the box center; (x1, y1) is the top-left corner and (x2, y2) the
    bottom-right corner.

    Args:
        x: torch.Tensor or numpy.ndarray whose last axis starts with
            [center_x, center_y, width, height]. Any number of leading
            dimensions is accepted (e.g. [4], [n, 4] or [batch, n, 4]).

    Returns:
        A copy of ``x`` (same type and shape) whose first four entries along
        the last axis hold [x1, y1, x2, y2]. The input is left unmodified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    # Ellipsis indexing keeps the [n, 4] behavior and also handles other ranks.
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    This code is borrowed from: https://github.com/ultralytics/yolov5/blob/47233e1698b89fc437a4fb9463c815e9171be955/utils/general.py#L775

    Args:
        prediction: (tensor), with shape [batch, N, 5 + num_classes], where N is the number of
            bboxes per image and the last axis is [center x, center y, width, height,
            objectness, class scores...].
        conf_thres: (float) confidence threshold, must be in [0, 1].
        iou_thres: (float) iou threshold, must be in [0, 1].
        classes: (None or list[int]), if a list is provided, only boxes of these classes are kept.
        agnostic: (bool), when True NMS is class-independent, otherwise NMS is done per class.
        multi_label: (bool), when True one box can have multiple labels, otherwise one box
            only has one label. Ignored when there is a single class.
        max_det: (int), max number of output bboxes per image.

    Returns:
        list of detections, one tensor per image with shape (num_boxes, 6),
        6 is for [xyxy, conf, cls]. Images with no detections (or not reached
        before the time limit) get an empty (0, 6) tensor.

    Raises:
        AssertionError: if conf_thres or iou_thres is outside [0, 1].
    """
    # Check the parameters before they are used.
    assert 0 <= conf_thres <= 1, f'conf_thresh must be in 0.0 to 1.0, however {conf_thres} is provided.'
    assert 0 <= iou_thres <= 1, f'iou_thres must be in 0.0 to 1.0, however {iou_thres} is provided.'

    num_classes = prediction.shape[2] - 5  # number of classes
    pred_candidates = prediction[..., 4] > conf_thres  # candidates

    # Function settings.
    max_wh = 4096  # maximum box width and height
    max_nms = 30000  # maximum number of boxes put into torchvision.ops.nms()
    time_limit = 10.0  # quit the function when nms cost time exceed the limit time.
    multi_label &= num_classes > 1  # multiple labels per box

    tik = time.time()
    # One independent empty tensor per image, so the entries never alias each other.
    output = [torch.zeros((0, 6), device=prediction.device) for _ in range(prediction.shape[0])]
    for img_idx, x in enumerate(prediction):  # image index, image inference
        x = x[pred_candidates[img_idx]]  # confidence (boolean indexing returns a copy)

        # If no box remains, skip the next process.
        if not x.shape[0]:
            continue

        # confidence multiply the objectness
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix's shape is (n,6), each row represents (xyxy, conf, cls)
        if multi_label:
            box_idx, class_idx = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[box_idx], x[box_idx, class_idx + 5, None], class_idx[:, None].float()), 1)
        else:  # Only keep the class with highest scores.
            conf, class_idx = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, class_idx.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class, only keep boxes whose category is in classes.
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        num_box = x.shape[0]  # number of boxes
        if not num_box:  # no boxes kept.
            continue
        elif num_box > max_nms:  # excess max boxes' number.
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS: shifting boxes by class index * max_wh keeps different classes
        # from overlapping, so a single nms call handles all classes.
        class_offset = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + class_offset, x[:, 4]  # boxes (offset by class), scores
        keep_box_idx = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if keep_box_idx.shape[0] > max_det:  # limit detections
            keep_box_idx = keep_box_idx[:max_det]

        output[img_idx] = x[keep_box_idx]
        if (time.time() - tik) > time_limit:
            print(f'WARNING: NMS cost time exceed the limited {time_limit}s.')
            break  # time limit exceeded

    return output