import warnings

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector


def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str, optional): The device where the model is put on.
            Default: 'cuda:0'.
        cfg_options (dict, optional): Options to override some settings in the
            used config.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
    if checkpoint is not None:
        map_loc = 'cpu' if device == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
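

# Usage sketch for ``init_detector`` (the config and checkpoint paths below
# are hypothetical placeholders; any valid mmdetection config file with a
# matching checkpoint works):
#
#   config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
#   checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco.pth'
#   model = init_detector(config_file, checkpoint_file, device='cuda:0')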


class LoadImage(object):
    """Deprecated.

    A simple pipeline to load image.
    """

    def __call__(self, results):
        """Call function to load images into results.

        Args:
            results (dict): A result dict containing the file name
                of the image to be read.

        Returns:
            dict: ``results`` will be returned containing loaded image.
        """
        warnings.simplefilter('once')
        warnings.warn('`LoadImage` is deprecated and will be removed in '
                      'future releases. You may use `LoadImageFromWebcam` '
                      'from `mmdet.datasets.pipelines` instead.')
        if isinstance(results['img'], str):
            results['filename'] = results['img']
            results['ori_filename'] = results['img']
        else:
            results['filename'] = None
            results['ori_filename'] = None
        img = mmcv.imread(results['img'])
        results['img'] = img
        results['img_fields'] = ['img']
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        return results


def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        If imgs is a list or tuple, a list of detection results of the same
        length is returned; otherwise the detection result is returned
        directly.
    """

    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)

    if not is_batch:
        return results[0]
    else:
        return results
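

# Usage sketch for ``inference_detector`` (assumes ``model`` comes from
# ``init_detector`` above; 'demo/demo.jpg' is a hypothetical placeholder path):
#
#   result = inference_detector(model, 'demo/demo.jpg')
#   # passing a list of inputs returns a list of results of equal length
#   results = inference_detector(model, ['demo/demo.jpg', 'demo/demo.jpg'])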


async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        Awaitable detection results.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # Grad mode is disabled globally and not restored, since concurrent
    # inference calls can overlap.
    torch.set_grad_enabled(False)
    results = await model.aforward_test(rescale=True, **data)
    return results
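

# Usage sketch for ``async_inference_detector`` (it must be awaited from a
# coroutine; ``model`` and the image path follow the earlier sketches):
#
#   import asyncio
#
#   async def main():
#       return await async_inference_detector(model, 'demo/demo.jpg')
#
#   results = asyncio.run(main())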


def show_result_pyplot(model,
                       img,
                       result,
                       score_thr=0.3,
                       title='result',
                       wait_time=0):
    """Visualize the detection results on the image.

    Args:
        model (nn.Module): The loaded detector.
        img (str or np.ndarray): Image filename or loaded image.
        result (tuple[list] or list): The detection result, can be either
            (bbox, segm) or just bbox.
        score_thr (float): The threshold to visualize the bboxes and masks.
        title (str): Title of the pyplot figure.
        wait_time (float): Value of waitKey param. Default: 0.
    """
    if hasattr(model, 'module'):
        model = model.module
    model.show_result(
        img,
        result,
        score_thr=score_thr,
        show=True,
        wait_time=wait_time,
        win_name=title,
        bbox_color=(72, 101, 241),
        text_color=(72, 101, 241))
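

# Usage sketch for ``show_result_pyplot`` (assumes ``model`` and ``result``
# were produced by the sketches above; the image path is a placeholder):
#
#   show_result_pyplot(model, 'demo/demo.jpg', result, score_thr=0.3)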