|
|
|
from projects.easydeploy.model import ORTWrapper, TRTWrapper |
|
import os |
|
import random |
|
from argparse import ArgumentParser |
|
|
|
import cv2 |
|
import mmcv |
|
import numpy as np |
|
import torch |
|
from mmcv.transforms import Compose |
|
from mmdet.utils import get_test_pipeline_cfg |
|
from mmengine.config import Config, ConfigDict |
|
from mmengine.utils import ProgressBar, path |
|
|
|
from mmyolo.utils import register_all_modules |
|
from mmyolo.utils.misc import get_file_list |
|
|
|
|
|
def parse_args():
    """Define the demo's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: positional ``img``, ``config`` and ``checkpoint``
        plus the optional ``out_dir``, ``device`` and ``show`` settings.
    """
    cli = ArgumentParser()

    # Positional arguments, registered in the order users must supply them.
    positionals = (
        ('img', 'Image path, include image file, dir and URL.'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for dest, text in positionals:
        cli.add_argument(dest, help=text)

    cli.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--show', action='store_true', help='Show the detection results')

    return cli.parse_args()
|
|
|
|
|
def preprocess(config):
    """Build a module that batches and normalizes a single image tensor.

    Mean and std are read from ``config['model']['data_preprocessor']``.
    A missing or ``None`` entry falls back to the identity transform
    (mean 0, std 1 per channel).

    Args:
        config: Mapping-like config (e.g. ``mmengine.Config``) that
            supports ``.get``.

    Returns:
        torch.nn.Module: Module in eval mode. Its ``forward(x)`` adds a
        leading batch dimension, casts to float32 and returns
        ``(x - mean) / std``. The input tensor is never modified.
    """
    model_cfg = config.get('model') or {}
    data_preprocess = model_cfg.get('data_preprocessor') or {}
    mean = data_preprocess.get('mean')
    std = data_preprocess.get('std')
    # Configs may set mean/std explicitly to None; treat that as "not set".
    if mean is None:
        mean = [0., 0., 0.]
    if std is None:
        std = [1., 1., 1.]
    # One value per channel, broadcast over (N, C, H, W). Using -1 instead
    # of 3 keeps configs with another channel count (e.g. grayscale) working.
    mean = torch.tensor(mean, dtype=torch.float32).reshape(1, -1, 1, 1)
    std = torch.tensor(std, dtype=torch.float32).reshape(1, -1, 1, 1)

    class PreProcess(torch.nn.Module):
        """Normalize ``x[None]`` with fixed per-channel mean and std."""

        def __init__(self, mean, std):
            super().__init__()
            # Buffers move with the module on ``.to(device)``.
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

        def forward(self, x):
            x = x[None].float()
            # Out-of-place arithmetic: ``.float()`` is a no-op on float32
            # input, so in-place ``-=`` / ``/=`` would overwrite the
            # caller's tensor through the view.
            return (x - self.mean.to(x.device)) / self.std.to(x.device)

    return PreProcess(mean, std).eval()
|
|
|
|
|
def _build_model(checkpoint, device):
    """Wrap an exported detector checkpoint for inference on ``device``.

    ``.onnx`` files are loaded with ``ORTWrapper``; ``.engine`` and ``.plan``
    files are loaded with ``TRTWrapper``.

    Raises:
        NotImplementedError: For any other checkpoint suffix.
    """
    if checkpoint.endswith('.onnx'):
        model = ORTWrapper(checkpoint, device)
    elif checkpoint.endswith(('.engine', '.plan')):
        model = TRTWrapper(checkpoint, device)
    else:
        raise NotImplementedError(
            f'Unsupported checkpoint format: {checkpoint!r} '
            "(expected '.onnx', '.engine' or '.plan')")
    model.to(device)
    return model


def _output_name(file, img_root, is_dir):
    """Return a flat output file name for ``file``.

    For directory inputs the path relative to ``img_root`` is kept, with
    separators replaced by ``_``, so every result lands directly in the
    output directory.
    """
    if not is_dir:
        return os.path.basename(file)
    rel_path = os.path.relpath(file, img_root)
    # relpath joins components with os.sep, which is '\\' on Windows.
    return rel_path.replace(os.sep, '_').replace('/', '_')


def _postprocess(result, pad_param, scale_factor, ori_h, ori_w):
    """Map raw detections of batch item 0 back to original-image pixels.

    Args:
        result: ``(num_dets, bboxes, scores, labels)`` from the wrapped model.
        pad_param: Letterbox padding. Index 0 is used as the vertical offset
            and index 2 as the horizontal offset (presumably
            ``[top, bottom, left, right]``; confirm against the pipeline).
        scale_factor: Resize factor, presumably ``(w_scale, h_scale)``. It is
            repeated to match ``(x1, y1, x2, y2)``. A single value is
            applied to both axes.
        ori_h: Original image height, used to clamp y coordinates.
        ori_w: Original image width, used to clamp x coordinates.

    Returns:
        tuple: ``(bboxes, scores, labels)`` on CPU. ``bboxes`` holds
        rounded ``int32`` coordinates.
    """
    num_dets, bboxes, scores, labels = result
    # Use a plain int count instead of slicing with a tensor.
    n = int(num_dets.reshape(-1)[0])
    bboxes = bboxes[0, :n]
    scores = scores[0, :n]
    labels = labels[0, :n]

    device = bboxes.device
    left, top = float(pad_param[2]), float(pad_param[0])
    offset = torch.tensor([left, top, left, top],
                          dtype=torch.float32,
                          device=device)

    scale = np.asarray(scale_factor, dtype=np.float32).reshape(-1)
    if scale.size == 1:
        scale = np.repeat(scale, 2)
    # Repeat (w, h) -> (w, h, w, h) element-wise. ``seq * 2`` would instead
    # double the values whenever scale_factor is an ndarray.
    scale = torch.as_tensor(np.tile(scale[:2], 2), device=device)

    bboxes = (bboxes - offset) / scale
    bboxes[:, 0::2].clamp_(0, ori_w)
    bboxes[:, 1::2].clamp_(0, ori_h)
    # Move to CPU once so drawing does not force a device sync per box.
    return bboxes.round().int().cpu(), scores.cpu(), labels.cpu()


def _draw_detections(img, bboxes, scores, labels, colors, class_names):
    """Draw each box and a ``cls:<name>_score:<score>`` caption onto ``img``.

    ``img`` is modified in place. ``class_names`` may be ``None``, in which
    case the numeric label is shown instead.
    """
    for bbox, score, label in zip(bboxes.tolist(), scores.tolist(),
                                  labels.tolist()):
        label = int(label)
        cls_name = label if class_names is None else class_names[label]
        caption = f'cls:{cls_name}_score:{score:0.4f}'
        cv2.rectangle(img, tuple(bbox[:2]), tuple(bbox[2:]), colors[label],
                      2)
        cv2.putText(
            img,
            caption, (bbox[0], bbox[1] - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            2.0, [225, 255, 255],
            thickness=3)


def main():
    """Run an exported ONNX/TensorRT detector over images and draw results.

    With ``--show`` each annotated image is displayed. Otherwise it is
    written into ``--out-dir``.
    """
    args = parse_args()
    register_all_modules()

    # One random colour per class id (supports up to 1000 classes).
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]

    model = _build_model(args.checkpoint, args.device)

    cfg = Config.fromfile(args.config)
    class_names = cfg.get('class_name')

    test_pipeline = get_test_pipeline_cfg(cfg)
    # Images are decoded here, so the pipeline receives arrays, not paths.
    test_pipeline[0] = ConfigDict({'type': 'mmdet.LoadImageFromNDArray'})
    test_pipeline = Compose(test_pipeline)

    pre_pipeline = preprocess(cfg)

    if not args.show:
        path.mkdir_or_exist(args.out_dir)

    files, source_type = get_file_list(args.img)

    progress_bar = ProgressBar(len(files))
    for i, file in enumerate(files):
        bgr = mmcv.imread(file)
        rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
        packed = test_pipeline(dict(img=rgb, img_id=i))
        # Look results up by key instead of relying on dict value order.
        data, samples = packed['inputs'], packed['data_samples']

        pad_param = samples.get('pad_param',
                                np.array([0, 0, 0, 0], dtype=np.float32))
        h, w = samples.get('ori_shape', rgb.shape[:2])
        scale_factor = samples.get('scale_factor', [1., 1])

        # Move the raw image to the device, then normalize it there.
        data = pre_pipeline(data.to(args.device))
        result = model(data)

        bboxes, scores, labels = _postprocess(result, pad_param,
                                              scale_factor, h, w)
        _draw_detections(bgr, bboxes, scores, labels, colors, class_names)

        if args.show:
            mmcv.imshow(bgr, 'result', 0)
        else:
            filename = _output_name(file, args.img, source_type['is_dir'])
            mmcv.imwrite(bgr, os.path.join(args.out_dir, filename))
        progress_bar.update()
|
|
|
|
|
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|