# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv.image import tensor2imgs
from mmcv.parallel import DataContainer

from mmdet.core import encode_mask_results

from .utils import tensor2grayimgs


def retrieve_img_tensor_and_meta(data):
    """Retrieve img_tensor, img_metas and img_norm_cfg from a batch.

    Args:
        data (dict): One batch of data from the data_loader. The layout of
            ``data['img']`` varies with the pipeline (see branch comments).

    Returns:
        tuple: Returns ``(img_tensor, img_metas, img_norm_cfg)``.

            - | img_tensor (Tensor): Input image tensor with shape
                :math:`(N, C, H, W)`.
            - | img_metas (list[dict]): The metadata of images.
            - | img_norm_cfg (dict): Config for image normalization
                (``mean``/``std`` rescaled to the 0-255 range if needed).

    Raises:
        TypeError: If ``data['img']`` has an unsupported container type.
        KeyError: If a required meta key is missing from ``img_metas``.
    """
    if isinstance(data['img'], torch.Tensor):
        # for textrecog with batch_size > 1
        # and not use 'DefaultFormatBundle' in pipeline
        img_tensor = data['img']
        img_metas = data['img_metas'].data[0]
    elif isinstance(data['img'], list):
        if isinstance(data['img'][0], torch.Tensor):
            # for textrecog with aug_test and batch_size = 1
            img_tensor = data['img'][0]
        elif isinstance(data['img'][0], DataContainer):
            # for textdet with 'MultiScaleFlipAug'
            # and 'DefaultFormatBundle' in pipeline
            img_tensor = data['img'][0].data[0]
        else:
            # Fail loudly instead of leaving img_tensor unbound, which
            # would surface later as a confusing NameError.
            raise TypeError('Unsupported type of data["img"][0]: '
                            f'{type(data["img"][0])}')
        img_metas = data['img_metas'][0].data[0]
    elif isinstance(data['img'], DataContainer):
        # for textrecog with 'DefaultFormatBundle' in pipeline
        img_tensor = data['img'].data[0]
        img_metas = data['img_metas'].data[0]
    else:
        raise TypeError(
            f'Unsupported type of data["img"]: {type(data["img"])}')

    must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
    for key in must_keys:
        if key not in img_metas[0]:
            raise KeyError(
                f'Please add {key} to the "meta_keys" in the pipeline')

    img_norm_cfg = img_metas[0]['img_norm_cfg']
    # If mean/std were given in the [0, 1] range, rescale them to [0, 255]
    # so they match the pixel range expected by tensor2(gray)imgs.
    # NOTE: this mutates the shared meta dict in place; the rescale is a
    # no-op on subsequent calls because the new mean exceeds 1.
    if max(img_norm_cfg['mean']) <= 1:
        img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
        img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]

    return img_tensor, img_metas, img_norm_cfg


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    is_kie=False,
                    show_score_thr=0.3):
    """Run inference with a single GPU and optionally visualize results.

    Args:
        model (nn.Module): Model wrapped by ``MMDataParallel`` (results are
            drawn via ``model.module.show_result``).
        data_loader (DataLoader): PyTorch data loader.
        show (bool): Whether to show the visualization on screen.
        out_dir (str, optional): Directory to save visualization images.
        is_kie (bool): Whether the model is a KIE model, which uses a
            different data layout and visualization path.
        show_score_thr (float): Score threshold for drawing bboxes
            (non-KIE models only).

    Returns:
        list: Results for every sample in the dataset, in loader order.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        batch_size = len(result)
        if show or out_dir:
            if is_kie:
                img_tensor = data['img'].data[0]
                if img_tensor.shape[0] != 1:
                    # BUGFIX: the two string fragments previously joined
                    # without a space ("...batches iscurrently...").
                    raise KeyError('Visualizing KIE outputs in batches is '
                                   'currently not supported.')
                gt_bboxes = data['gt_bboxes'].data[0]
                img_metas = data['img_metas'].data[0]
                must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
                for key in must_keys:
                    if key not in img_metas[0]:
                        raise KeyError(
                            f'Please add {key} to the "meta_keys" in config.')
                # for no visual model
                if np.prod(img_tensor.shape) == 0:
                    imgs = []
                    for img_meta in img_metas:
                        try:
                            img = mmcv.imread(img_meta['filename'])
                        except Exception as e:
                            # Best-effort: fall back to a blank image so a
                            # single bad path does not abort the whole run.
                            print(f'Load image with error: {e}, '
                                  'use empty image instead.')
                            img = np.ones(
                                img_meta['img_shape'], dtype=np.uint8)
                        imgs.append(img)
                else:
                    imgs = tensor2imgs(img_tensor,
                                       **img_metas[0]['img_norm_cfg'])
                for i, img in enumerate(imgs):
                    h, w, _ = img_metas[i]['img_shape']
                    img_show = img[:h, :w, :]
                    if out_dir:
                        out_file = osp.join(out_dir,
                                            img_metas[i]['ori_filename'])
                    else:
                        out_file = None
                    model.module.show_result(
                        img_show,
                        result[i],
                        gt_bboxes[i],
                        show=show,
                        out_file=out_file)
            else:
                img_tensor, img_metas, img_norm_cfg = \
                    retrieve_img_tensor_and_meta(data)

                # Single-channel tensors need the grayscale conversion path.
                if img_tensor.size(1) == 1:
                    imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
                else:
                    imgs = tensor2imgs(img_tensor, **img_norm_cfg)
                assert len(imgs) == len(img_metas)

                for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                    img_shape, ori_shape = img_meta['img_shape'], img_meta[
                        'ori_shape']
                    # Crop the padded region, then resize back to the
                    # original image size for display.
                    img_show = img[:img_shape[0], :img_shape[1]]
                    img_show = mmcv.imresize(img_show,
                                             (ori_shape[1], ori_shape[0]))

                    if out_dir:
                        out_file = osp.join(out_dir, img_meta['ori_filename'])
                    else:
                        out_file = None

                    model.module.show_result(
                        img_show,
                        result[j],
                        show=show,
                        out_file=out_file,
                        score_thr=show_score_thr)

        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        for _ in range(batch_size):
            prog_bar.update()
    return results