Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from functools import partial | |
import mmcv | |
import numpy as np | |
import torch | |
from mmcv.runner import load_checkpoint | |
def generate_inputs_and_wrap_model(config_path,
                                   checkpoint_path,
                                   input_config,
                                   cfg_options=None):
    """Prepare sample input and wrap model for ONNX export.

    The ONNX export API only accepts positional args, and all inputs should
    be ``torch.Tensor`` or corresponding types (such as a tuple of tensors).
    So this function should be called before exporting. It will:

    1. generate the inputs which are used to execute the model.
    2. wrap the model's forward function.

    For example, the MMDet models' forward function has a parameter
    ``return_loss: bool``. We want to set it to False, but the export API
    supports neither bool arguments nor kwargs, so the forward method is
    replaced like ``model.forward = partial(model.forward, return_loss=False)``.

    Args:
        config_path (str): the OpenMMLab config for the model we want to
            export to ONNX.
        checkpoint_path (str): Path to the corresponding checkpoint.
        input_config (dict): the exact data in this dict depends on the
            framework. For MMSeg, we can just declare the input shape,
            and generate the dummy data accordingly. However, for MMDet,
            we may pass the real img path, or the NMS will return None
            as there is no legal bbox.
        cfg_options (dict, optional): overrides forwarded to
            ``build_model_from_cfg`` and merged into the loaded config.
            Defaults to None.

    Returns:
        tuple: (model, tensor_data) wrapped model which can be called by
        ``model(*tensor_data)`` and a list of inputs which are used to
        execute the model while exporting.

    Raises:
        NotImplementedError: if ``register_extra_symbolics`` cannot be
            imported from ``mmcv.onnx.symbolic`` (mmcv older than v1.0.4).
    """
    model = build_model_from_cfg(
        config_path, checkpoint_path, cfg_options=cfg_options)
    one_img, one_meta = preprocess_example_input(input_config)
    tensor_data = [one_img]
    # Bind the non-tensor arguments up front so the exporter only has to
    # pass tensors positionally.
    model.forward = partial(
        model.forward, img_metas=[[one_meta]], return_loss=False)

    # pytorch 1.3 has bugs in some ops, so they are replaced by the extra
    # symbolics that mmcv registers for this opset.
    opset_version = 11
    # Import lazily so that importing this module does not require
    # mmcv.onnx unless this function is actually used.
    try:
        from mmcv.onnx.symbolic import register_extra_symbolics
    except ImportError as err:
        # ImportError (the parent of ModuleNotFoundError) also covers an
        # older mmcv.onnx.symbolic module that exists but lacks the name.
        raise NotImplementedError(
            'please update mmcv to version>=v1.0.4') from err
    register_extra_symbolics(opset_version)
    return model, tensor_data
def build_model_from_cfg(config_path, checkpoint_path, cfg_options=None):
    """Build a model from config and load the given checkpoint.

    Args:
        config_path (str): the OpenMMLab config for the model we want to
            export to ONNX.
        checkpoint_path (str): Path to the corresponding checkpoint.
        cfg_options (dict, optional): key/value overrides merged into the
            loaded config via ``cfg.merge_from_dict``. Defaults to None.

    Returns:
        torch.nn.Module: the built detector, moved to CPU and switched to
        eval mode, with ``CLASSES`` taken from the checkpoint meta when
        present, otherwise from the config's test dataset class.

    Raises:
        KeyError: if the checkpoint meta has no ``CLASSES`` and the test
            dataset type named in the config is not registered in
            ``DATASETS``.
    """
    from mmdet.models import build_detector

    cfg = mmcv.Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # Presumably disables loading separate pretrained weights at build
    # time; the weights are loaded from checkpoint_path below.
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the model
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    checkpoint = load_checkpoint(model, checkpoint_path, map_location='cpu')
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        from mmdet.datasets import DATASETS
        dataset_type = cfg.data.test['type']
        dataset = DATASETS.get(dataset_type)
        # Explicit check rather than ``assert`` so it is not stripped when
        # Python runs with -O.
        if dataset is None:
            raise KeyError(
                f'{dataset_type} is not in the DATASETS registry')
        model.CLASSES = dataset.CLASSES
    model.cpu().eval()
    return model
def preprocess_example_input(input_config):
    """Prepare an example input image for ``generate_inputs_and_wrap_model``.

    Args:
        input_config (dict): customized config describing the example input.
            Required keys:

            - ``input_path``: image file read with ``mmcv.imread``.
            - ``input_shape``: 4-tuple ``(N, C, H, W)``. The image is resized
              to ``H x W``. ``N`` is ignored: the returned tensor always has
              batch size 1. ``C`` is only used to build the meta shapes.

            Optional key ``normalize_cfg`` (dict) with ``mean``, ``std`` and
            ``to_rgb`` (default True), forwarded to ``mmcv.imnormalize``.

    Returns:
        tuple: (one_img, one_meta). ``one_img`` is a float tensor of shape
        ``(1, channels, H, W)`` with ``requires_grad`` set. ``one_meta`` is a
        dict with ``img_shape``/``ori_shape``/``pad_shape`` = ``(H, W, C)``,
        ``filename``, ``scale_factor`` (four float32 ones), ``flip`` (False),
        ``flip_direction`` (None) and ``show_img``, the resized image before
        normalization.

    Examples:
        >>> from mmdet.core.export import preprocess_example_input
        >>> input_config = {
        >>>         'input_shape': (1,3,224,224),
        >>>         'input_path': 'demo/demo.jpg',
        >>>         'normalize_cfg': {
        >>>             'mean': (123.675, 116.28, 103.53),
        >>>             'std': (58.395, 57.12, 57.375)
        >>>             }
        >>>         }
        >>> one_img, one_meta = preprocess_example_input(input_config)
        >>> print(one_img.shape)
        torch.Size([1, 3, 224, 224])
        >>> print(sorted(one_meta))
        ['filename', 'flip', 'flip_direction', 'img_shape', 'ori_shape',
         'pad_shape', 'scale_factor', 'show_img']
        >>> print(one_meta['img_shape'], one_meta['scale_factor'])
        (224, 224, 3) [1. 1. 1. 1.]
    """
    input_path = input_config['input_path']
    input_shape = input_config['input_shape']
    one_img = mmcv.imread(input_path)
    # input_shape[2:] is (H, W); reversed to (W, H), the order mmcv.imresize
    # takes for its target size.
    one_img = mmcv.imresize(one_img, input_shape[2:][::-1])
    # Un-normalized copy of the resized image, stored in the meta.
    show_img = one_img.copy()
    if 'normalize_cfg' in input_config:
        normalize_cfg = input_config['normalize_cfg']
        mean = np.array(normalize_cfg['mean'], dtype=np.float32)
        std = np.array(normalize_cfg['std'], dtype=np.float32)
        to_rgb = normalize_cfg.get('to_rgb', True)
        one_img = mmcv.imnormalize(one_img, mean, std, to_rgb=to_rgb)
    # (H, W, channels) -> (channels, H, W), then add a batch dim of 1.
    one_img = one_img.transpose(2, 0, 1)
    one_img = torch.from_numpy(one_img).unsqueeze(0).float().requires_grad_(
        True)
    (_, C, H, W) = input_shape
    one_meta = {
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': np.ones(4, dtype=np.float32),
        'flip': False,
        'show_img': show_img,
        'flip_direction': None
    }
    return one_img, one_meta