File size: 4,324 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

import mmcv
import numpy as np
import torch
from mmdet.datasets import replace_ImageToTensor

from mmocr.utils import is_2dlist, is_type_list


def update_pipeline(cfg, idx=None):
    """Run ``replace_ImageToTensor`` over the pipeline stored on ``cfg``.

    Args:
        cfg: Config node that may expose a ``pipeline`` attribute. It is
            modified in place.
        idx (int, optional): If given, ``cfg.pipeline`` is indexed with
            ``idx`` and only that entry is converted and written back.
            If None, the whole ``cfg.pipeline`` is converted.
            Defaults to None.
    """
    if idx is not None:
        cfg.pipeline[idx] = replace_ImageToTensor(cfg.pipeline[idx])
        return

    # A node without any ``pipeline`` entry is handled like
    # ``pipeline=None`` (nothing to convert) instead of raising, which
    # mirrors how ``update_pipeline_recog`` looks the pipeline up.
    pipeline = getattr(cfg, 'pipeline', None)
    if pipeline is not None:
        cfg.pipeline = replace_ImageToTensor(pipeline)


def replace_image_to_tensor(cfg, set_types=None):
    """Replace 'ImageToTensor' to 'DefaultFormatBundle'.

    The input config is deep-copied first; the copy is updated through
    ``update_pipeline`` and returned, so ``cfg`` itself is left untouched.

    Args:
        cfg: Input config providing ``cfg.data[<set_type>]`` nodes.
        set_types (list[str], optional): Dataset splits to process. Should
            be None or a sublist of ['val', 'test']. None means both.

    Returns:
        The updated deep copy of ``cfg``.
    """
    assert set_types is None or isinstance(set_types, list)
    splits = ['val', 'test'] if set_types is None else set_types

    new_cfg = copy.deepcopy(cfg)
    for split in splits:
        assert split in ['val', 'test']
        split_cfg = new_cfg.data[split]

        # The split-level pipeline is either one flat list of transform
        # dicts or a list of such lists.
        shared_pipeline = split_cfg.get('pipeline', None)
        if is_type_list(shared_pipeline, dict):
            update_pipeline(split_cfg)
        elif is_2dlist(shared_pipeline):
            for position, _ in enumerate(shared_pipeline):
                update_pipeline(split_cfg, position)

        # Entries of ``datasets`` may be single dataset configs or lists
        # of dataset configs; handle both uniformly.
        for entry in split_cfg.get('datasets', []):
            members = entry if isinstance(entry, list) else [entry]
            for member in members:
                update_pipeline(member)

    return new_cfg


def update_pipeline_recog(cfg, idx=None):
    """Unwrap a 'MultiRotateAugOCR' step found at position 1 of a pipeline.

    When the second step of the pipeline has type 'MultiRotateAugOCR', the
    pipeline is replaced in place by its first step followed by that
    wrapper's ``transforms``, and a warning is emitted. Pipelines with
    fewer than two steps are left unchanged.

    Args:
        cfg: With ``idx=None``, a config node whose optional ``pipeline``
            entry is inspected. Otherwise an indexable container of
            pipelines, of which ``cfg[idx]`` is inspected.
        idx (int, optional): Index of the pipeline inside ``cfg``.
            Defaults to None.
    """
    warning_msg = 'Remove "MultiRotateAugOCR" to support batch ' + \
        'inference since samples_per_gpu > 1.'
    if idx is None:
        pipeline = cfg.get('pipeline', None)
        # Guard the ``[1]`` lookup: a missing, empty or single-step
        # pipeline cannot contain the wrapper at position 1.
        if not pipeline or len(pipeline) < 2:
            return
        if pipeline[1].type == 'MultiRotateAugOCR':
            warnings.warn(warning_msg)
            cfg.pipeline = [pipeline[0], *pipeline[1].transforms]
    else:
        pipeline = cfg[idx]
        if len(pipeline) >= 2 and pipeline[1].type == 'MultiRotateAugOCR':
            warnings.warn(warning_msg)
            cfg[idx] = [pipeline[0], *pipeline[1].transforms]


def disable_text_recog_aug_test(cfg, set_types=None):
    """Remove aug_test from test pipeline for text recognition.

    Works on a deep copy of ``cfg``. Splits whose dataset ``type`` is not
    one of the supported dataset types are skipped unchanged.

    Args:
        cfg (mmcv.Config): Input config.
        set_types (list[str]): Type of dataset source. Should be
            None or sublist of ['test', 'val'].

    Returns:
        The updated deep copy of ``cfg``.
    """
    assert set_types is None or isinstance(set_types, list)
    splits = ['val', 'test'] if set_types is None else set_types

    supported_types = ('ConcatDataset', 'UniformConcatDataset',
                       'OCRDataset', 'OCRSegDataset')

    new_cfg = copy.deepcopy(cfg)
    # NOTE: this changes the process-wide warning filter.
    warnings.simplefilter('once')
    for split in splits:
        assert split in ['val', 'test']
        split_cfg = new_cfg.data[split]
        if split_cfg.type not in supported_types:
            continue

        # The split-level pipeline is either one flat list of transform
        # dicts or a list of such lists.
        shared_pipeline = split_cfg.get('pipeline', None)
        if is_type_list(shared_pipeline, dict):
            update_pipeline_recog(split_cfg)
        elif is_2dlist(shared_pipeline):
            for position, _ in enumerate(shared_pipeline):
                update_pipeline_recog(split_cfg.pipeline, position)

        # Entries of ``datasets`` may be single dataset configs or lists
        # of dataset configs; handle both uniformly.
        for entry in split_cfg.get('datasets', []):
            members = entry if isinstance(entry, list) else [entry]
            for member in members:
                update_pipeline_recog(member)

    return new_cfg


def tensor2grayimgs(tensor, mean=(127, ), std=(127, ), **kwargs):
    """Convert tensor to 1-channel gray images.

    Each image is moved to the CPU, rearranged to channel-last layout,
    denormalized with ``mmcv.imdenormalize`` (no channel reordering) and
    cast to ``np.uint8``.

    Args:
        tensor (torch.Tensor): Tensor that contains multiple images, shape (
            N, C, H, W). C must be 1.
        mean (tuple[float], optional): Mean of images. Must have length 1.
            Defaults to (127, ).
        std (tuple[float], optional): Standard deviation of images. Must
            have length 1. Defaults to (127, ).
        **kwargs: Accepted and ignored.

    Returns:
        list[np.ndarray]: A list that contains multiple images.
    """

    assert torch.is_tensor(tensor) and tensor.ndim == 4
    assert tensor.size(1) == len(mean) == len(std) == 1

    mean = np.array(mean, dtype=np.float32)
    std = np.array(std, dtype=np.float32)

    # ``detach()`` lets tensors that still require grad be converted to
    # numpy. The transfer is done once for the whole batch, then viewed as
    # (N, H, W, C) so each item is a channel-last image.
    batch = tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)

    imgs = []
    for img in batch:
        img = mmcv.imdenormalize(img, mean, std, to_bgr=False)
        imgs.append(np.ascontiguousarray(img.astype(np.uint8)))
    return imgs