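# Convert an MMDetection model exported to ONNX into a TensorRT engine, and
# optionally verify that TensorRT and ONNX Runtime produce matching detections.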
import argparse
import os
import os.path as osp

import numpy as np
import onnx
import onnxruntime as ort
import torch
from mmcv.ops import get_onnxruntime_op_path
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
                           save_trt_engine)
from mmcv.visualization.image import imshow_det_bboxes

from mmdet.core import get_classes, preprocess_example_input


def get_GiB(x: int):
    """Return `x` GiB expressed in bytes."""
    return x * (1 << 30)


def onnx2tensorrt(onnx_file,
                  trt_file,
                  input_config,
                  verify=False,
                  show=False,
                  dataset='coco',
                  workspace_size=1):
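    """Convert an ONNX model to a TensorRT engine and save it to `trt_file`.

    If `verify` is True, run both the TensorRT engine and the original ONNX
    model (via ONNX Runtime) on one preprocessed example image and check that
    the predicted boxes and labels match.
    """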
    onnx_model = onnx.load(onnx_file)
    input_shape = input_config['input_shape']
    # Create the TensorRT engine with a fixed input shape profile
    opt_shape_dict = {'input': [input_shape, input_shape, input_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        fp16_mode=False,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')
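
    # Verification needs a CUDA device: the example image is moved to the GPU
    # for TensorRT inference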
    if verify:
        one_img, one_meta = preprocess_example_input(input_config)
        input_img_cpu = one_img.detach().cpu().numpy()
        input_img_cuda = one_img.cuda()
        img = one_meta['show_img']

        # Get results from TensorRT
        trt_model = TRTWraper(trt_file, ['input'], ['boxes', 'labels'])
        with torch.no_grad():
            trt_outputs = trt_model({'input': input_img_cuda})
        trt_boxes = trt_outputs['boxes'].detach().cpu().numpy()
        trt_labels = trt_outputs['labels'].detach().cpu().numpy()

        # Get results from ONNXRuntime
        ort_custom_op_path = get_onnxruntime_op_path()
        session_options = ort.SessionOptions()
        # Register mmcv's ONNX Runtime custom ops (e.g. NMS) if the custom-op
        # library has been compiled
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        onnx_outputs = sess.run(None, {
            'input': input_img_cpu,
        })
        ort_boxes, ort_labels = onnx_outputs

        # Show detection outputs
        if show:
            CLASSES = get_classes(dataset)
            score_thr = 0.35
            imshow_det_bboxes(
                img.copy(),
                trt_boxes,
                trt_labels,
                CLASSES,
                score_thr=score_thr,
                win_name='TensorRT')
            imshow_det_bboxes(
                img.copy(),
                ort_boxes,
                ort_labels,
                CLASSES,
                score_thr=score_thr,
                win_name='ONNXRuntime')

        # Compare results
        np.testing.assert_allclose(
            ort_boxes, trt_boxes, rtol=1e-03, atol=1e-05)
        np.testing.assert_allclose(ort_labels, trt_labels)
        print('The numerical values are the same ' +
              'between ONNXRuntime and TensorRT')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models from ONNX to TensorRT')
    parser.add_argument('model', help='Filename of input ONNX model')
    parser.add_argument(
        '--trt-file',
        type=str,
        default='tmp.trt',
        help='Filename of output TensorRT engine')
    parser.add_argument(
        '--input-img', type=str, default='', help='Image for test')
    parser.add_argument(
        '--show', action='store_true', help='Whether to show output results')
    parser.add_argument(
        '--dataset', type=str, default='coco', help='Dataset name')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Verify the outputs of ONNXRuntime and TensorRT')
    parser.add_argument(
        '--to-rgb',
        action='store_false',
        help='Feed the model an RGB image (enabled by default). Pass this '
        'flag to skip the BGR-to-RGB conversion and feed a BGR image.')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[400, 600],
        help='Input size of the model')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[123.675, 116.28, 103.53],
        help='Mean values used to normalize the input image')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[58.395, 57.12, 57.375],
        help='Standard deviation values used to normalize the input image')
    parser.add_argument(
        '--workspace-size',
        type=int,
        default=1,
        help='Max workspace size in GiB')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
    args = parse_args()

    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.jpg')

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    assert len(args.mean) == 3
    assert len(args.std) == 3

    normalize_cfg = {'mean': args.mean, 'std': args.std, 'to_rgb': args.to_rgb}
    input_config = {
        'input_shape': input_shape,
        'input_path': args.input_img,
        'normalize_cfg': normalize_cfg
    }

    # Create TensorRT engine
    onnx2tensorrt(
        args.model,
        args.trt_file,
        input_config,
        verify=args.verify,
        show=args.show,
        dataset=args.dataset,
        workspace_size=args.workspace_size)
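
# Example invocation (the script name and file paths below are placeholders;
# adjust them to your own checkout and model):
#   python onnx2tensorrt.py checkpoints/model.onnx \
#       --trt-file checkpoints/model.trt \
#       --shape 400 600 --verify --show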