# mmocr/apis/test.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv.image import tensor2imgs
from mmcv.parallel import DataContainer
from mmdet.core import encode_mask_results

from .utils import tensor2grayimgs


def retrieve_img_tensor_and_meta(data):
"""Retrieval img_tensor, img_metas and img_norm_cfg.
Args:
data (dict): One batch data from data_loader.
Returns:
tuple: Returns (img_tensor, img_metas, img_norm_cfg).
- | img_tensor (Tensor): Input image tensor with shape
:math:`(N, C, H, W)`.
- | img_metas (list[dict]): The metadata of images.
- | img_norm_cfg (dict): Config for image normalization.
"""
if isinstance(data['img'], torch.Tensor):
        # for textrecog with batch_size > 1 and without
        # 'DefaultFormatBundle' in the pipeline
img_tensor = data['img']
img_metas = data['img_metas'].data[0]
elif isinstance(data['img'], list):
if isinstance(data['img'][0], torch.Tensor):
# for textrecog with aug_test and batch_size = 1
img_tensor = data['img'][0]
elif isinstance(data['img'][0], DataContainer):
# for textdet with 'MultiScaleFlipAug'
# and 'DefaultFormatBundle' in pipeline
img_tensor = data['img'][0].data[0]
img_metas = data['img_metas'][0].data[0]
elif isinstance(data['img'], DataContainer):
# for textrecog with 'DefaultFormatBundle' in pipeline
img_tensor = data['img'].data[0]
img_metas = data['img_metas'].data[0]
must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
for key in must_keys:
if key not in img_metas[0]:
raise KeyError(
f'Please add {key} to the "meta_keys" in the pipeline')
img_norm_cfg = img_metas[0]['img_norm_cfg']
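    # tensor2imgs/tensor2grayimgs expect mean/std on the 0-255 scale, so
    # rescale configs that were specified in the 0-1 range.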
if max(img_norm_cfg['mean']) <= 1:
img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]
return img_tensor, img_metas, img_norm_cfg
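

# Usage sketch (not part of the original file): the helper expects one batch
# as yielded by an MMOCR/MMDet test data loader inside a test loop, e.g.
#
#     for data in data_loader:
#         img_tensor, img_metas, img_norm_cfg = \
#             retrieve_img_tensor_and_meta(data)
#         # img_tensor: (N, C, H, W); img_metas: N metadata dicts;
#         # img_norm_cfg: 'mean'/'std' (0-255 scale) and 'to_rgb'.

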
def single_gpu_test(model,
data_loader,
show=False,
out_dir=None,
is_kie=False,
show_score_thr=0.3):
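    """Test a model over the dataset with a single GPU, optionally
    visualizing the predictions.

    Args:
        model (nn.Module): Model to test, typically wrapped in
            ``MMDataParallel`` (``model.module.show_result`` is used for
            visualization).
        data_loader (DataLoader): PyTorch data loader for the test set.
        show (bool): Whether to show visualizations on screen.
        out_dir (str, optional): Directory in which to save visualizations.
        is_kie (bool): Whether the model is a KIE model, which uses a
            dedicated visualization path.
        show_score_thr (float): Score threshold used when drawing
            detection results.

    Returns:
        list: Results for the whole dataset, one entry per image.
    """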
model.eval()
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for data in data_loader:
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
batch_size = len(result)
if show or out_dir:
if is_kie:
img_tensor = data['img'].data[0]
if img_tensor.shape[0] != 1:
                    raise KeyError('Visualizing KIE outputs in batches is '
                                   'currently not supported.')
gt_bboxes = data['gt_bboxes'].data[0]
img_metas = data['img_metas'].data[0]
must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
for key in must_keys:
if key not in img_metas[0]:
raise KeyError(
f'Please add {key} to the "meta_keys" in config.')
                # For KIE models with no visual branch the image tensor is
                # empty, so read the original images from disk instead.
if np.prod(img_tensor.shape) == 0:
imgs = []
for img_meta in img_metas:
try:
img = mmcv.imread(img_meta['filename'])
except Exception as e:
                            print(f'Failed to load image: {e}, '
                                  'using a placeholder image instead.')
img = np.ones(
img_meta['img_shape'], dtype=np.uint8)
imgs.append(img)
else:
imgs = tensor2imgs(img_tensor,
**img_metas[0]['img_norm_cfg'])
for i, img in enumerate(imgs):
h, w, _ = img_metas[i]['img_shape']
img_show = img[:h, :w, :]
if out_dir:
out_file = osp.join(out_dir,
img_metas[i]['ori_filename'])
else:
out_file = None
model.module.show_result(
img_show,
result[i],
gt_bboxes[i],
show=show,
out_file=out_file)
else:
img_tensor, img_metas, img_norm_cfg = \
retrieve_img_tensor_and_meta(data)
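                # De-normalize the batched tensor back into displayable
                # images; single-channel inputs go through the grayscale
                # helper, three-channel inputs through tensor2imgs.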
if img_tensor.size(1) == 1:
imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
else:
imgs = tensor2imgs(img_tensor, **img_norm_cfg)
assert len(imgs) == len(img_metas)
for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
img_shape, ori_shape = img_meta['img_shape'], img_meta[
'ori_shape']
img_show = img[:img_shape[0], :img_shape[1]]
img_show = mmcv.imresize(img_show,
(ori_shape[1], ori_shape[0]))
if out_dir:
out_file = osp.join(out_dir, img_meta['ori_filename'])
else:
out_file = None
model.module.show_result(
img_show,
result[j],
show=show,
out_file=out_file,
score_thr=show_score_thr)
# encode mask results
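        # (encode_mask_results converts binary masks to RLE, which keeps the
        # accumulated results compact)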
if isinstance(result[0], tuple):
result = [(bbox_results, encode_mask_results(mask_results))
for bbox_results, mask_results in result]
results.extend(result)
for _ in range(batch_size):
prog_bar.update()
return results
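

# Usage sketch (an assumption, not part of the original file), mirroring the
# usual MMDet/MMOCR single-GPU test flow; ``model`` is assumed to be a built
# detector with a loaded checkpoint and ``data_loader`` a test-set loader:
#
#     from mmcv.parallel import MMDataParallel
#
#     model = MMDataParallel(model, device_ids=[0])
#     outputs = single_gpu_test(model, data_loader, show=False,
#                               out_dir='vis/', show_score_thr=0.3)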