# stevengrove
# initial commit
# 186701e
# Copyright (c) OpenMMLab. All rights reserved.
from projects.easydeploy.model import ORTWrapper, TRTWrapper # isort:skip
import os
import random
from argparse import ArgumentParser
import cv2
import mmcv
import numpy as np
import torch
from mmcv.transforms import Compose
from mmdet.utils import get_test_pipeline_cfg
from mmengine.config import Config, ConfigDict
from mmengine.utils import ProgressBar, path
from mmyolo.utils import register_all_modules
from mmyolo.utils.misc import get_file_list
def parse_args():
    """Parse the command-line options of this demo script.

    Returns:
        argparse.Namespace: Holds the positionals ``img``, ``config`` and
        ``checkpoint`` plus the options ``out_dir`` (default ``'./output'``),
        ``device`` (default ``'cuda:0'``) and ``show`` (default ``False``).
    """
    cli = ArgumentParser()

    # Required positional arguments, in the order they must be supplied.
    positionals = (
        ('img', 'Image path, include image file, dir and URL.'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for dest, help_text in positionals:
        cli.add_argument(dest, help=help_text)

    # Optional string-valued settings together with their defaults.
    string_options = (
        ('--out-dir', './output', 'Path to output file'),
        ('--device', 'cuda:0', 'Device used for inference'),
    )
    for flag, fallback, help_text in string_options:
        cli.add_argument(flag, default=fallback, help=help_text)

    # Boolean switch: present -> True, absent -> False.
    cli.add_argument(
        '--show', action='store_true', help='Show the detection results')

    return cli.parse_args()
def preprocess(config):
    """Build a module that normalizes a single image tensor.

    The statistics are read from ``config['model']['data_preprocessor']``.
    A missing ``mean`` falls back to zeros and a missing ``std`` to ones,
    which makes the normalization an identity (apart from the float cast).

    Args:
        config: Mapping-like object supporting ``get`` (for example an
            mmengine ``Config`` or a plain ``dict``).

    Returns:
        torch.nn.Module: Module in eval mode. Its ``forward(x)`` prepends a
        batch dimension to ``x``, casts the result to ``float32`` and returns
        ``(x - mean) / std``. The tensor passed in is never modified.
    """
    data_preprocess = config.get('model', {}).get('data_preprocessor', {})
    mean = data_preprocess.get('mean', [0., 0., 0.])
    std = data_preprocess.get('std', [1., 1., 1.])

    # ``-1`` lets the channel count follow the configured statistics instead
    # of being fixed at 3; the result broadcasts against (1, C, H, W) inputs.
    mean = torch.tensor(mean, dtype=torch.float32).reshape(1, -1, 1, 1)
    std = torch.tensor(std, dtype=torch.float32).reshape(1, -1, 1, 1)

    class PreProcess(torch.nn.Module):
        """Adds a batch dimension and applies ``(x - mean) / std``."""

        def __init__(self):
            super().__init__()
            # Non-persistent buffers: they follow ``module.to(...)`` but do
            # not show up in ``state_dict()``.
            self.register_buffer('mean', mean, persistent=False)
            self.register_buffer('std', std, persistent=False)

        def forward(self, x):
            x = x[None].float()
            # Out-of-place arithmetic on purpose: when ``x`` is already
            # float32, ``x[None].float()`` is a view of the caller's tensor,
            # so ``-=`` / ``/=`` would overwrite the caller's data.
            return (x - self.mean.to(x.device)) / self.std.to(x.device)

    return PreProcess().eval()
def main():
    """Run a deployed detector on images and draw the predicted boxes.

    Steps:
        1. Parse the command line and load the backend wrapper that matches
           the checkpoint extension (``.onnx`` -> ``ORTWrapper``,
           ``.engine`` / ``.plan`` -> ``TRTWrapper``).
        2. Build the test pipeline from the config, replacing its first
           transform with ``mmdet.LoadImageFromNDArray`` because images are
           read here and handed over as arrays.
        3. For every input image: run the pipeline and the model, undo the
           padding and scaling on the boxes, draw them, then either show the
           image (``--show``) or write it into ``--out-dir``.

    Raises:
        NotImplementedError: If the checkpoint does not end in ``.onnx``,
            ``.engine`` or ``.plan``.
    """
    args = parse_args()

    # Register the modules provided by ``mmyolo.utils.register_all_modules``
    # so the config's ``type`` strings can be resolved.
    register_all_modules()

    # One random colour per class id; ids beyond the table wrap around below.
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]

    # Pick the backend wrapper from the checkpoint's file extension.
    if args.checkpoint.endswith('.onnx'):
        model = ORTWrapper(args.checkpoint, args.device)
    elif args.checkpoint.endswith(('.engine', '.plan')):
        model = TRTWrapper(args.checkpoint, args.device)
    else:
        raise NotImplementedError(
            f'Unsupported checkpoint "{args.checkpoint}": expected a file '
            'ending in .onnx, .engine or .plan')

    model.to(args.device)

    cfg = Config.fromfile(args.config)
    # ``None`` when the config defines no ``class_name``; numeric ids are
    # printed in that case.
    class_names = cfg.get('class_name')

    test_pipeline = get_test_pipeline_cfg(cfg)
    # Images are loaded below with mmcv and passed in as ndarrays.
    test_pipeline[0] = ConfigDict({'type': 'mmdet.LoadImageFromNDArray'})
    test_pipeline = Compose(test_pipeline)

    pre_pipeline = preprocess(cfg)

    if not args.show:
        path.mkdir_or_exist(args.out_dir)

    # get file list
    files, source_type = get_file_list(args.img)

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for i, file in enumerate(files):
        bgr = mmcv.imread(file)
        rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
        # NOTE(review): relies on the pipeline returning exactly two values,
        # in the order (input tensor, data sample) - confirm against the
        # packing transform used by the config.
        data, samples = test_pipeline(dict(img=rgb, img_id=i)).values()

        pad_param = samples.get('pad_param',
                                np.array([0, 0, 0, 0], dtype=np.float32))
        h, w = samples.get('ori_shape', rgb.shape[:2])
        # Entry 2 is subtracted from the x coordinates and entry 0 from the
        # y coordinates - presumably left/top padding; verify against the
        # pipeline's ``pad_param`` layout.
        pad_param = torch.asarray(
            [pad_param[2], pad_param[0], pad_param[2], pad_param[0]],
            device=args.device)

        scale_factor = samples.get('scale_factor', [1., 1])
        # Repeat the two factors to cover all four box coordinates.
        # ``list(...)`` guarantees repetition even for an array-like, where
        # ``* 2`` would multiply element-wise instead.
        scale_factor = torch.asarray(
            list(scale_factor) * 2, device=args.device)

        data = pre_pipeline(data).to(args.device)
        result = model(data)

        if source_type['is_dir']:
            # Flatten the path relative to the input directory into a single
            # file name. ``os.sep`` is replaced too so that this also works
            # where the separator is not '/'.
            filename = os.path.relpath(file, args.img)
            filename = filename.replace(os.sep, '_').replace('/', '_')
        else:
            filename = os.path.basename(file)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # Get candidate predict info by num_dets
        num_dets, bboxes, scores, labels = result
        scores = scores[0, :num_dets]
        bboxes = bboxes[0, :num_dets]
        labels = labels[0, :num_dets]

        # Map boxes from the network input back to the original image:
        # remove the padding offset first, then undo the resize.
        bboxes -= pad_param
        bboxes /= scale_factor

        # Even columns are clamped to the width, odd columns to the height.
        bboxes[:, 0::2].clamp_(0, w)
        bboxes[:, 1::2].clamp_(0, h)
        bboxes = bboxes.round().int()

        for (bbox, score, label) in zip(bboxes, scores, labels):
            bbox = bbox.tolist()
            # Plain int for indexing; wrap so class ids past the end of the
            # colour table do not raise IndexError.
            label = int(label)
            color = colors[label % len(colors)]

            if class_names is not None:
                label_name = class_names[label]
                name = f'cls:{label_name}_score:{score:0.4f}'
            else:
                name = f'cls:{label}_score:{score:0.4f}'

            cv2.rectangle(bgr, bbox[:2], bbox[2:], color, 2)
            cv2.putText(
                bgr,
                name, (bbox[0], bbox[1] - 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                2.0, [225, 255, 255],
                thickness=3)

        if args.show:
            mmcv.imshow(bgr, 'result', 0)
        else:
            mmcv.imwrite(bgr, out_file)
        progress_bar.update()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()