# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv.image import tensor2imgs
from mmcv.parallel import DataContainer
from mmdet.core import encode_mask_results

from .utils import tensor2grayimgs


def retrieve_img_tensor_and_meta(data):
    """Retrieval img_tensor, img_metas and img_norm_cfg.

    Args:
        data (dict): One batch data from data_loader.

    Returns:
        tuple: Returns (img_tensor, img_metas, img_norm_cfg).

            - | img_tensor (Tensor): Input image tensor with shape
                :math:`(N, C, H, W)`.
            - | img_metas (list[dict]): The metadata of images.
            - | img_norm_cfg (dict): Config for image normalization.
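
    Example:
        Minimal illustrative sketch (assumes ``data_loader`` is built from an
        MMOCR test pipeline whose ``meta_keys`` keep ``img_norm_cfg``,
        ``ori_filename``, ``img_shape`` and ``ori_shape``):

        >>> for data in data_loader:  # doctest: +SKIP
        ...     out = retrieve_img_tensor_and_meta(data)
        ...     img_tensor, img_metas, img_norm_cfg = out
        ...     assert img_tensor.size(0) == len(img_metas)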
    """

    if isinstance(data['img'], torch.Tensor):
        # for textrecog with batch_size > 1
        # and not use 'DefaultFormatBundle' in pipeline
        img_tensor = data['img']
        img_metas = data['img_metas'].data[0]
    elif isinstance(data['img'], list):
        if isinstance(data['img'][0], torch.Tensor):
            # for textrecog with aug_test and batch_size = 1
            img_tensor = data['img'][0]
        elif isinstance(data['img'][0], DataContainer):
            # for textdet with 'MultiScaleFlipAug'
            # and 'DefaultFormatBundle' in pipeline
            img_tensor = data['img'][0].data[0]
        img_metas = data['img_metas'][0].data[0]
    elif isinstance(data['img'], DataContainer):
        # for textrecog with 'DefaultFormatBundle' in pipeline
        img_tensor = data['img'].data[0]
        img_metas = data['img_metas'].data[0]

    must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
    for key in must_keys:
        if key not in img_metas[0]:
            raise KeyError(
                f'Please add {key} to the "meta_keys" in the pipeline')

    img_norm_cfg = img_metas[0]['img_norm_cfg']
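    # Some configs specify normalization mean/std in the [0, 1] range;
    # rescale them to [0, 255] so that tensor2imgs/tensor2grayimgs can
    # denormalize the tensor back to a displayable image.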
    if max(img_norm_cfg['mean']) <= 1:
        img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
        img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]

    return img_tensor, img_metas, img_norm_cfg


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    is_kie=False,
                    show_score_thr=0.3):
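    """Test a model with a single GPU and optionally visualize the results.

    Args:
        model (nn.Module): The model to be tested (typically wrapped in a
            parallel wrapper such as ``MMDataParallel``, since
            ``model.module.show_result`` is accessed for visualization).
        data_loader (DataLoader): PyTorch data loader.
        show (bool): Whether to show the visualization on-the-fly.
        out_dir (str, optional): Directory to save visualized images.
        is_kie (bool): Whether the model is a KIE model, which is visualized
            together with its ground-truth boxes.
        show_score_thr (float): Score threshold for visualization; only used
            for non-KIE models.

    Returns:
        list: Test results, one entry per image; tuple results have their
            mask component encoded with ``encode_mask_results``.
    """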
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        batch_size = len(result)
        if show or out_dir:
            if is_kie:
                img_tensor = data['img'].data[0]
                if img_tensor.shape[0] != 1:
                    raise KeyError('Visualizing KIE outputs in batches is '
                                   'currently not supported.')
                gt_bboxes = data['gt_bboxes'].data[0]
                img_metas = data['img_metas'].data[0]
                must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
                for key in must_keys:
                    if key not in img_metas[0]:
                        raise KeyError(
                            f'Please add {key} to the "meta_keys" in config.')
                # For KIE models without a visual branch the image tensor is
                # empty, so load the original images from disk instead.
                if np.prod(img_tensor.shape) == 0:
                    imgs = []
                    for img_meta in img_metas:
                        try:
                            img = mmcv.imread(img_meta['filename'])
                        except Exception as e:
                            print(f'Failed to load image: {e}, '
                                  'using a blank image instead.')
                            img = np.ones(
                                img_meta['img_shape'], dtype=np.uint8)
                        imgs.append(img)
                else:
                    imgs = tensor2imgs(img_tensor,
                                       **img_metas[0]['img_norm_cfg'])
                for i, img in enumerate(imgs):
                    h, w, _ = img_metas[i]['img_shape']
                    img_show = img[:h, :w, :]
                    if out_dir:
                        out_file = osp.join(out_dir,
                                            img_metas[i]['ori_filename'])
                    else:
                        out_file = None

                    model.module.show_result(
                        img_show,
                        result[i],
                        gt_bboxes[i],
                        show=show,
                        out_file=out_file)
            else:
                img_tensor, img_metas, img_norm_cfg = \
                    retrieve_img_tensor_and_meta(data)

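                # Recognition models may take single-channel input; pick the
                # grayscale or color converter based on the channel dim.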
                if img_tensor.size(1) == 1:
                    imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
                else:
                    imgs = tensor2imgs(img_tensor, **img_norm_cfg)
                assert len(imgs) == len(img_metas)

                for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                    img_shape, ori_shape = img_meta['img_shape'], img_meta[
                        'ori_shape']
                    img_show = img[:img_shape[0], :img_shape[1]]
                    img_show = mmcv.imresize(img_show,
                                             (ori_shape[1], ori_shape[0]))

                    if out_dir:
                        out_file = osp.join(out_dir, img_meta['ori_filename'])
                    else:
                        out_file = None

                    model.module.show_result(
                        img_show,
                        result[j],
                        show=show,
                        out_file=out_file,
                        score_thr=show_score_thr)

        # encode mask results
        if isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        for _ in range(batch_size):
            prog_bar.update()
    return results