MMOCR / mmocr /apis /inference.py
tomofi's picture
Add application file
2366e36
raw
history blame
8.08 kB
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmocr.models import build_detector
from mmocr.utils import is_2dlist
from .utils import disable_text_recog_aug_test
def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Build a detector from a config and optionally load its weights.

    Args:
        config (str or :obj:`mmcv.Config`): Path of a config file, or an
            already loaded config object. A passed-in config object is
            modified in place (options merged, ``model.pretrained`` and
            ``model.train_cfg`` reset).
        checkpoint (str, optional): Checkpoint path. When ``None`` no weights
            are loaded and ``model.CLASSES`` is not set.
        device (str): Device the model is moved to. Defaults to 'cuda:0'.
        cfg_options (dict, optional): Overrides merged into the config before
            the model is built.

    Returns:
        nn.Module: The constructed detector, switched to eval mode, with the
        config stored as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    # Reject unsupported config types up front, then load from file if a
    # path was given.
    if not isinstance(config, str) and not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)

    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # Inference only: drop any pretrained-weights entry and training config.
    if config.model.get('pretrained'):
        config.model.pretrained = None
    config.model.train_cfg = None

    model = build_detector(config.model, test_cfg=config.get('test_cfg'))

    if checkpoint is not None:
        loaded = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = loaded.get('meta', {})
        if 'CLASSES' not in meta:
            # NOTE: simplefilter() changes the process-wide warnings filter.
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
        else:
            model.CLASSES = meta['CLASSES']

    # Keep the config on the model so inference helpers can read it later.
    model.cfg = config
    model.to(device)
    model.eval()
    return model
def _normalize_inference_imgs(imgs):
    """Validate ``imgs`` and normalise it to a non-empty sequence.

    Args:
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            A single image or a sequence of images.

    Returns:
        tuple: ``(imgs, is_batch, is_ndarray)`` where ``imgs`` is a sequence,
        ``is_batch`` tells whether the caller passed a list/tuple, and
        ``is_ndarray`` tells whether the images are numpy arrays.

    Raises:
        Exception: If an empty list/tuple is given.
        AssertionError: If any image is neither a str nor an ndarray, or if
            strings and ndarrays are mixed in one call.
    """
    if isinstance(imgs, (list, tuple)):
        if len(imgs) == 0:
            raise Exception('empty imgs provided, please check and try again')
        # Check every element, not just the first one, so that a bad entry
        # fails here instead of somewhere inside the data pipeline.
        if not all(isinstance(img, (np.ndarray, str)) for img in imgs):
            raise AssertionError('imgs must be strings or numpy arrays')
        is_ndarray = isinstance(imgs[0], np.ndarray)
        # One loading pipeline is built per call, so inputs must be uniform.
        if any(isinstance(img, np.ndarray) != is_ndarray for img in imgs):
            raise AssertionError('imgs must be all strings or all numpy '
                                 'arrays, not a mix of both')
        return imgs, True, is_ndarray
    if isinstance(imgs, (np.ndarray, str)):
        return [imgs], False, isinstance(imgs, np.ndarray)
    raise AssertionError('imgs must be strings or numpy arrays')


def model_inference(model,
                    imgs,
                    ann=None,
                    batch_mode=False,
                    return_data=False):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. Its ``cfg`` attribute is used
            to build the test pipeline.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images. All images of one call must
            be of the same kind.
        ann (dict, optional): Annotation info for key information extraction.
            Its entries are also merged into the pipeline input.
        batch_mode (bool): If True, use batch mode for inference.
        return_data (bool): If True, also return the pipeline output(s).

    Returns:
        The model output. For a single image this is ``results[0]``; for a
        list/tuple it is the full ``results``. With ``return_data=True`` a
        tuple ``(result, data)`` is returned instead, where ``data`` is the
        pipeline output for the single image or the list of them.

    Raises:
        Exception: If ``imgs`` is an empty list/tuple, or if the pipeline
            produces augmented (list) images for more than one input.
        AssertionError: If ``imgs`` has an unsupported or mixed type, or if
            a CPU model contains an ``RoIPool`` module.
    """
    # Local import keeps this block self-contained.
    import copy

    imgs, is_batch, is_ndarray = _normalize_inference_imgs(imgs)

    cfg = model.cfg

    if batch_mode:
        cfg = disable_text_recog_aug_test(cfg, set_types=['test'])

    first_param = next(model.parameters())
    device = first_param.device  # model device

    # Resolve the test pipeline: fall back to the first dataset's pipeline
    # and unwrap a nested (2-d) pipeline list.
    if cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(cfg.data.test.datasets):
            cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
        else:
            cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
    if is_2dlist(cfg.data.test.pipeline):
        cfg.data.test.pipeline = cfg.data.test.pipeline[0]

    if is_ndarray:
        # Deep copy before overriding the loader: a shallow copy would share
        # the nested pipeline entries with ``model.cfg`` and the override
        # would leak into later calls that pass file paths.
        cfg = copy.deepcopy(cfg)
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    # Loop-invariant: in batch mode the MultiScaleFlipAug wrapper yields
    # one-element lists that have to be unwrapped before collating.
    unwrap_aug = (
        batch_mode
        and cfg.data.test.pipeline[1].type == 'MultiScaleFlipAug')

    datas = []
    for img in imgs:
        # prepare data
        if is_ndarray:
            # directly add img
            data = dict(
                img=img,
                ann_info=ann,
                img_info=dict(width=img.shape[1], height=img.shape[0]),
                bbox_fields=[])
        else:
            # add information into dict
            data = dict(
                img_info=dict(filename=img),
                img_prefix=None,
                ann_info=ann,
                bbox_fields=[])
        if ann is not None:
            data.update(ann)

        # build the data pipeline
        data = test_pipeline(data)
        # get tensor from list to stack for batch mode (text detection)
        if unwrap_aug:
            for key, value in data.items():
                data[key] = value[0]
        datas.append(data)

    if isinstance(datas[0]['img'], list) and len(datas) > 1:
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(datas)}')

    data = collate(datas, samples_per_gpu=len(imgs))

    # process img_metas
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data

    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data

    # for KIE models
    if ann is not None:
        data['relations'] = data['relations'].data[0]
        data['gt_bboxes'] = data['gt_bboxes'].data[0]
        data['texts'] = data['texts'].data[0]
        data['img'] = data['img'][0]
        data['img_metas'] = data['img_metas'][0]

    if first_param.is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    elif any(isinstance(m, RoIPool) for m in model.modules()):
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        raise AssertionError(
            'CPU inference with RoIPool is not supported currently.')

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)

    if is_batch:
        return (results, datas) if return_data else results
    return (results[0], datas[0]) if return_data else results[0]
def text_model_inference(model, input_sentence):
    """Inference text(s) with the entity recognizer.

    Args:
        model (nn.Module): The loaded recognizer. Its ``cfg`` attribute is
            used to build the test pipeline.
        input_sentence (str): A text entered by the user.

    Returns:
        The output of ``model(None, img_metas, return_loss=False)``.

    Raises:
        AssertionError: If ``input_sentence`` is not a str, or if the
            pipeline does not yield a dict of ``img_metas``.
    """
    # Explicit raise instead of ``assert`` so the check survives ``-O``.
    if not isinstance(input_sentence, str):
        raise AssertionError('input_sentence must be a str, '
                             f'but got {type(input_sentence)}')

    cfg = model.cfg
    # Resolve the test pipeline: fall back to the first dataset's pipeline
    # and unwrap a nested (2-d) pipeline list.
    if cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(cfg.data.test.datasets):
            cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
        else:
            cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
    if is_2dlist(cfg.data.test.pipeline):
        cfg.data.test.pipeline = cfg.data.test.pipeline[0]
    test_pipeline = Compose(cfg.data.test.pipeline)

    # build the data pipeline
    data = test_pipeline({'text': input_sentence, 'label': {}})

    # ``img_metas`` is either a plain dict or a container exposing the dict
    # through its ``data`` attribute.
    img_metas = data['img_metas']
    if not isinstance(img_metas, dict):
        img_metas = img_metas.data
    if not isinstance(img_metas, dict):
        raise AssertionError('img_metas produced by the test pipeline must '
                             f'be a dict, but got {type(img_metas)}')

    # Add a leading batch dimension of size 1 to every model input.
    batched_keys = ('input_ids', 'attention_masks', 'token_type_ids',
                    'labels')
    img_metas = {key: img_metas[key].unsqueeze(0) for key in batched_keys}

    # forward the model
    with torch.no_grad():
        result = model(None, img_metas, return_loss=False)
    return result