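# NOTE: the following lines are residue from the web page this file was
# copied from; they are not part of the module.
# 3v324v23's picture
# Add files
# c9019cd
# raw
# history blame
# 8 kB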
import copy
import warnings

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector
def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object. A passed-in config object is modified in place
            (``cfg_options`` are merged, ``model.pretrained`` and
            ``model.train_cfg`` are set to None).
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str or :obj:`torch.device`): Device the constructed model is
            moved to. Default: 'cuda:0'.
        cfg_options (dict, optional): Options to override some settings in the
            used config.

    Returns:
        nn.Module: The constructed detector, switched to eval mode, with the
            config stored as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # Inference only: no pretrained-backbone loading, no training settings.
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get('test_cfg'))

    if checkpoint is not None:
        # Always deserialize onto the CPU, independent of how ``device`` is
        # spelled. The weights are copied into the (still CPU-resident) model
        # and moved to ``device`` by ``model.to(device)`` below, so loading
        # never requires the GPU the checkpoint was saved from.
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
class LoadImage:
    """Deprecated.

    Minimal loading step that reads ``results['img']`` with ``mmcv.imread``
    and fills in the image-related entries of the result dict.
    """

    def __call__(self, results):
        """Load the image referenced by ``results['img']`` into ``results``.

        Args:
            results (dict): Result dict whose ``'img'`` entry is either a
                file name (str) or something else accepted by
                ``mmcv.imread``.

        Returns:
            dict: The same ``results`` object, with ``'filename'``,
                ``'ori_filename'``, ``'img'``, ``'img_fields'``,
                ``'img_shape'`` and ``'ori_shape'`` set.
        """
        warnings.simplefilter('once')
        warnings.warn('`LoadImage` is deprecated and will be removed in '
                      'future releases. You may use `LoadImageFromWebcam` '
                      'from `mmdet.datasets.pipelines.` instead.')

        source = results['img']
        # Only a string input carries a file name; anything else has none.
        name = source if isinstance(source, str) else None
        results['filename'] = name
        results['ori_filename'] = name

        img = mmcv.imread(source)
        shape = img.shape
        results['img'] = img
        results['img_fields'] = ['img']
        results['img_shape'] = shape
        results['ori_shape'] = shape
        return results
def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. Its ``cfg`` attribute is read
            to build the test pipeline and is not modified.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images. The type of the first
            element decides which loading step the pipeline starts with.

    Returns:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
        An empty list or tuple yields an empty list.
    """
    is_batch = isinstance(imgs, (list, tuple))
    if not is_batch:
        imgs = [imgs]
    elif not imgs:
        # Nothing to run; avoids indexing imgs[0] on an empty batch.
        return []

    device = next(model.parameters()).device  # model device

    # Work on a private copy of the pipeline config. Editing model.cfg in
    # place would make the loader switch below stick for later calls, so a
    # call with file names after a call with arrays would use the wrong
    # loading step.
    pipeline_cfg = copy.deepcopy(model.cfg.data.test.pipeline)
    if isinstance(imgs[0], np.ndarray):
        # set loading pipeline type
        pipeline_cfg[0].type = 'LoadImageFromWebcam'
    test_pipeline = Compose(replace_ImageToTensor(pipeline_cfg))

    datas = []
    for img in imgs:
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # run the data pipeline on this sample
        datas.append(test_pipeline(data))

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [metas.data[0] for metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]

    if device.type == 'cuda':
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)

    return results if is_batch else results[0]
async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. Its ``cfg`` attribute is read
            to build the test pipeline and is not modified.
        imgs (str | ndarray | list | tuple): Either image files or loaded
            images; a single item is wrapped into a one-element list. The
            type of the first element decides which loading step the
            pipeline starts with.

    Returns:
        Awaitable detection results. An empty list or tuple yields an empty
        list.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
    elif not imgs:
        # Nothing to run; avoids indexing imgs[0] on an empty batch.
        return []

    device = next(model.parameters()).device  # model device

    # Work on a private copy of the pipeline config. Editing model.cfg in
    # place would make the loader switch below stick for later calls, so a
    # call with file names after a call with arrays would use the wrong
    # loading step.
    pipeline_cfg = copy.deepcopy(model.cfg.data.test.pipeline)
    if isinstance(imgs[0], np.ndarray):
        # set loading pipeline type
        pipeline_cfg[0].type = 'LoadImageFromWebcam'
    test_pipeline = Compose(replace_ImageToTensor(pipeline_cfg))

    datas = []
    for img in imgs:
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # run the data pipeline on this sample
        datas.append(test_pipeline(data))

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [metas.data[0] for metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]

    if device.type == 'cuda':
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    results = await model.aforward_test(rescale=True, **data)
    return results
def show_result_pyplot(model,
                       img,
                       result,
                       score_thr=0.3,
                       title='result',
                       wait_time=0):
    """Display detection results drawn on an image.

    Delegates to the detector's ``show_result`` method with ``show=True``
    and a fixed box/text colour.

    Args:
        model (nn.Module): The loaded detector. If it has a ``module``
            attribute (a wrapped model), that inner object is used.
        img (str or np.ndarray): Image filename or loaded image.
        result (tuple[list] or list): The detection result, can be either
            (bbox, segm) or just bbox.
        score_thr (float): The threshold to visualize the bboxes and masks.
        title (str): Title of the pyplot figure, passed as ``win_name``.
        wait_time (float): Value of waitKey param.
            Default: 0.
    """
    # Unwrap a wrapper object exposing the real detector as ``.module``.
    detector = getattr(model, 'module', model)
    highlight = (72, 101, 241)  # shared colour for boxes and labels
    detector.show_result(
        img,
        result,
        score_thr=score_thr,
        show=True,
        wait_time=wait_time,
        win_name=title,
        bbox_color=highlight,
        text_color=highlight)