# Copyright (c) OpenMMLab. All rights reserved.
from projects.easydeploy.model import ORTWrapper, TRTWrapper  # isort:skip

import os
import random
from argparse import ArgumentParser

import cv2
import mmcv
import numpy as np
import torch
from mmcv.transforms import Compose
from mmdet.utils import get_test_pipeline_cfg
from mmengine.config import Config, ConfigDict
from mmengine.utils import ProgressBar, path

from mmyolo.utils import register_all_modules
from mmyolo.utils.misc import get_file_list


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    args = parser.parse_args()
    return args


def preprocess(config):
    # Build a normalization module from the data_preprocessor settings in the
    # config, so the exported backend model sees the same inputs as the
    # original PyTorch model.
    data_preprocess = config.get('model', {}).get('data_preprocessor', {})
    mean = data_preprocess.get('mean', [0., 0., 0.])
    std = data_preprocess.get('std', [1., 1., 1.])
    mean = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1, 1)
    std = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1, 1)

    class PreProcess(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x):
            # x is a CHW image tensor; add a batch dim and normalize
            x = x[None].float()
            x -= mean.to(x.device)
            x /= std.to(x.device)
            return x

    return PreProcess().eval()


def main():
    args = parse_args()

    # register all modules in mmdet into the registries
    register_all_modules()

    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]

    # build the model from a config file and a checkpoint file
    if args.checkpoint.endswith('.onnx'):
        model = ORTWrapper(args.checkpoint, args.device)
    elif args.checkpoint.endswith('.engine') or args.checkpoint.endswith(
            '.plan'):
        model = TRTWrapper(args.checkpoint, args.device)
    else:
        raise NotImplementedError

    model.to(args.device)

    cfg = Config.fromfile(args.config)
    class_names = cfg.get('class_name')

    test_pipeline = get_test_pipeline_cfg(cfg)
    test_pipeline[0] = ConfigDict({'type': 'mmdet.LoadImageFromNDArray'})
    test_pipeline = Compose(test_pipeline)

    pre_pipeline = preprocess(cfg)

    if not args.show:
        path.mkdir_or_exist(args.out_dir)

    # get file list
    files, source_type = get_file_list(args.img)

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for i, file in enumerate(files):
        bgr = mmcv.imread(file)
        rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
        data, samples = test_pipeline(dict(img=rgb, img_id=i)).values()
        pad_param = samples.get('pad_param',
                                np.array([0, 0, 0, 0], dtype=np.float32))
        h, w = samples.get('ori_shape', rgb.shape[:2])
        # turn the letterbox padding into an (x1, y1, x2, y2) offset so it can
        # be subtracted from the predicted boxes
        pad_param = torch.asarray(
            [pad_param[2], pad_param[0], pad_param[2], pad_param[0]],
            device=args.device)
        scale_factor = samples.get('scale_factor', [1., 1.])
        scale_factor = torch.asarray(scale_factor * 2, device=args.device)
        data = pre_pipeline(data).to(args.device)

        result = model(data)

        if source_type['is_dir']:
            filename = os.path.relpath(file, args.img).replace('/', '_')
        else:
            filename = os.path.basename(file)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # Get candidate predict info by num_dets
        num_dets, bboxes, scores, labels = result
        scores = scores[0, :num_dets]
        bboxes = bboxes[0, :num_dets]
        labels = labels[0, :num_dets]

        # rescale boxes back to the original image and clip to its bounds
        bboxes -= pad_param
        bboxes /= scale_factor
        bboxes[:, 0::2].clamp_(0, w)
        bboxes[:, 1::2].clamp_(0, h)
        bboxes = bboxes.round().int()

        for (bbox, score, label) in zip(bboxes, scores, labels):
            bbox = bbox.tolist()
            color = colors[label]

            if class_names is not None:
                label_name = class_names[label]
                name = f'cls:{label_name}_score:{score:0.4f}'
            else:
                name = f'cls:{label}_score:{score:0.4f}'

            cv2.rectangle(bgr, bbox[:2], bbox[2:], color, 2)
            cv2.putText(
                bgr,
                name, (bbox[0], bbox[1] - 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                2.0, [225, 255, 255],
                thickness=3)

        if args.show:
            mmcv.imshow(bgr, 'result', 0)
        else:
            mmcv.imwrite(bgr, out_file)
        progress_bar.update()


if __name__ == '__main__':
    main()