|
import argparse |
|
import os |
|
import sys |
|
import warnings |
|
from io import BytesIO |
|
from pathlib import Path |
|
|
|
import onnx |
|
import torch |
|
from mmdet.apis import init_detector |
|
from mmengine.config import ConfigDict |
|
from mmengine.logging import print_log |
|
from mmengine.utils.path import mkdir_or_exist |
|
|
|
|
|
sys.path.append(str(Path(__file__).resolve().parents[3])) |
|
from projects.easydeploy.model import DeployModel, MMYOLOBackend |
|
|
|
# Silence tracer/scripting warnings and generic user/future/resource
# warnings, presumably to keep the export log readable. Filters are added in
# the same order as before, so the resulting filter list is identical.
for _ignored_category in (torch.jit.TracerWarning, torch.jit.ScriptWarning,
                          UserWarning, FutureWarning, ResourceWarning):
    warnings.filterwarnings(action='ignore', category=_ignored_category)
del _ignored_category
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the ONNX export script.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace: Parsed options. ``img_size`` is always a
        two-element ``[height, width]`` list; a single ``--img-size`` value
        is used for both dimensions.

    Raises:
        SystemExit: via ``parser.error`` when ``--img-size`` is given more
            than two values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--model-only', action='store_true', help='Export model only')
    parser.add_argument(
        '--work-dir', default='./work_dir', help='Path to save export model')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Simplify onnx model by onnx-sim')
    parser.add_argument(
        '--opset', type=int, default=11, help='ONNX opset version')
    parser.add_argument(
        '--backend',
        type=str,
        default='onnxruntime',
        help='Backend for export onnx')
    parser.add_argument(
        '--pre-topk',
        type=int,
        default=1000,
        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument(
        '--keep-topk',
        type=int,
        default=100,
        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument(
        '--iou-threshold',
        type=float,
        default=0.65,
        help='IoU threshold for NMS')
    parser.add_argument(
        '--score-threshold',
        type=float,
        default=0.25,
        help='Score threshold for NMS')
    args = parser.parse_args(argv)
    if len(args.img_size) == 1:
        # A single value means a square input. Build a new list rather than
        # using in-place ``*=`` on what may be argparse's default object.
        args.img_size = args.img_size * 2
    elif len(args.img_size) != 2:
        # The size is unpacked as ``torch.randn(batch, 3, *img_size)``, so
        # anything other than (H, W) would trace a non-4D input.
        parser.error('--img-size expects 1 or 2 values (height [width]), '
                     f'got {len(args.img_size)}')
    return args
|
|
|
|
|
def build_model_from_cfg(config_path, checkpoint_path, device):
    """Create the detector from a config file and load its checkpoint.

    Args:
        config_path: Config file passed straight to ``init_detector``.
        checkpoint_path: Checkpoint file passed straight to ``init_detector``.
        device: Device string forwarded as ``init_detector(device=...)``.

    Returns:
        The detector returned by ``init_detector``, switched to eval mode.
    """
    detector = init_detector(config_path, checkpoint_path, device=device)
    detector.eval()
    return detector
|
|
|
|
|
def main(): |
|
args = parse_args() |
|
mkdir_or_exist(args.work_dir) |
|
backend = MMYOLOBackend(args.backend.lower()) |
|
if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO, |
|
MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7): |
|
if not args.model_only: |
|
print_log('Export ONNX with bbox decoder and NMS ...') |
|
else: |
|
args.model_only = True |
|
print_log(f'Can not export postprocess for {args.backend.lower()}.\n' |
|
f'Set "args.model_only=True" default.') |
|
if args.model_only: |
|
postprocess_cfg = None |
|
output_names = None |
|
else: |
|
postprocess_cfg = ConfigDict( |
|
pre_top_k=args.pre_topk, |
|
keep_top_k=args.keep_topk, |
|
iou_threshold=args.iou_threshold, |
|
score_threshold=args.score_threshold) |
|
output_names = ['num_dets', 'boxes', 'scores', 'labels'] |
|
baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device) |
|
|
|
deploy_model = DeployModel( |
|
baseModel=baseModel, backend=backend, postprocess_cfg=postprocess_cfg) |
|
deploy_model.eval() |
|
|
|
fake_input = torch.randn(args.batch_size, 3, |
|
*args.img_size).to(args.device) |
|
|
|
deploy_model(fake_input) |
|
|
|
save_onnx_path = os.path.join( |
|
args.work_dir, |
|
os.path.basename(args.checkpoint).replace('pth', 'onnx')) |
|
|
|
with BytesIO() as f: |
|
torch.onnx.export( |
|
deploy_model, |
|
fake_input, |
|
f, |
|
input_names=['images'], |
|
output_names=output_names, |
|
opset_version=args.opset) |
|
f.seek(0) |
|
onnx_model = onnx.load(f) |
|
onnx.checker.check_model(onnx_model) |
|
|
|
|
|
if not args.model_only and backend in (MMYOLOBackend.TENSORRT8, |
|
MMYOLOBackend.TENSORRT7): |
|
shapes = [ |
|
args.batch_size, 1, args.batch_size, args.keep_topk, 4, |
|
args.batch_size, args.keep_topk, args.batch_size, |
|
args.keep_topk |
|
] |
|
for i in onnx_model.graph.output: |
|
for j in i.type.tensor_type.shape.dim: |
|
j.dim_param = str(shapes.pop(0)) |
|
if args.simplify: |
|
try: |
|
import onnxsim |
|
onnx_model, check = onnxsim.simplify(onnx_model) |
|
assert check, 'assert check failed' |
|
except Exception as e: |
|
print_log(f'Simplify failure: {e}') |
|
onnx.save(onnx_model, save_onnx_path) |
|
print_log(f'ONNX export success, save into {save_onnx_path}') |
|
|
|
|
|
if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()
|
|