File size: 4,169 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os

import pytest
from mmcv import Config

from mmocr.apis.utils import (disable_text_recog_aug_test,
                              replace_image_to_tensor)


@pytest.mark.parametrize('cfg_file', [
    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
])
def test_disable_text_recog_aug_test(cfg_file):
    """Test that ``disable_text_recog_aug_test`` removes MultiRotateAugOCR.

    A real text-recognition config is loaded and ``cfg.data.test`` is
    rebuilt in several layouts: a single dataset, ``UniformConcatDataset``
    with a flat / nested / absent uniform pipeline, and ``ConcatDataset``.
    For each layout the helper is called with ``set_types=['test']`` and
    the test checks that the second step of each inspected pipeline is no
    longer ``MultiRotateAugOCR``.

    Every scenario works on deep copies of the loaded config and of its
    first test dataset, so no scenario can observe changes made by another.

    Args:
        cfg_file (str): Config path, resolved against the grandparent
            directory of this test file.
    """
    tests_root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    config_file = os.path.join(tests_root, cfg_file)

    cfg = Config.fromfile(config_file)
    test = cfg.data.test.datasets[0]

    # cfg.data.test is a single dataset (type 'OCRDataset')
    cfg1 = copy.deepcopy(cfg)
    test1 = copy.deepcopy(test)
    test1.pipeline = cfg1.data.test.pipeline
    cfg1.data.test = test1
    cfg1 = disable_text_recog_aug_test(cfg1, set_types=['test'])
    assert cfg1.data.test.pipeline[1].type != 'MultiRotateAugOCR'

    # cfg.data.test.type is 'UniformConcatDataset'
    # and cfg.data.test.pipeline is list[dict]
    cfg2 = copy.deepcopy(cfg)
    test2 = copy.deepcopy(test)
    test2.pipeline = cfg2.data.test.pipeline
    cfg2.data.test.datasets = [test2]
    cfg2 = disable_text_recog_aug_test(cfg2, set_types=['test'])
    assert cfg2.data.test.pipeline[1].type != 'MultiRotateAugOCR'
    assert cfg2.data.test.datasets[0].pipeline[1].type != 'MultiRotateAugOCR'

    # cfg.data.test.type is 'ConcatDataset'
    cfg3 = copy.deepcopy(cfg)
    test3 = copy.deepcopy(test)
    test3.pipeline = cfg3.data.test.pipeline
    cfg3.data.test = Config(dict(type='ConcatDataset', datasets=[test3]))
    cfg3 = disable_text_recog_aug_test(cfg3, set_types=['test'])
    assert cfg3.data.test.datasets[0].pipeline[1].type != 'MultiRotateAugOCR'

    # cfg.data.test.type is 'UniformConcatDataset'
    # and cfg.data.test.pipeline is list[list[dict]]
    cfg4 = copy.deepcopy(cfg)
    test4 = copy.deepcopy(test)
    test4.pipeline = cfg4.data.test.pipeline
    # Deep-copy the second dataset too: inserting ``test`` itself would put
    # an object owned by the shared ``cfg`` into cfg4, so any in-place change
    # made by the helper would leak back into ``cfg``.
    cfg4.data.test.datasets = [[test4], [copy.deepcopy(test)]]
    cfg4.data.test.pipeline = [
        cfg4.data.test.pipeline, cfg4.data.test.pipeline
    ]
    cfg4 = disable_text_recog_aug_test(cfg4, set_types=['test'])
    assert cfg4.data.test.datasets[0][0].pipeline[1].type != \
        'MultiRotateAugOCR'

    # cfg.data.test.type is 'UniformConcatDataset'
    # and cfg.data.test.pipeline is None
    cfg5 = copy.deepcopy(cfg)
    test5 = copy.deepcopy(test)
    test5.pipeline = copy.deepcopy(cfg5.data.test.pipeline)
    cfg5.data.test.datasets = [test5]
    cfg5.data.test.pipeline = None
    cfg5 = disable_text_recog_aug_test(cfg5, set_types=['test'])
    assert cfg5.data.test.datasets[0].pipeline[1].type != 'MultiRotateAugOCR'


@pytest.mark.parametrize('cfg_file', [
    '../configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py',
])
def test_replace_image_to_tensor(cfg_file):
    """Test ``replace_image_to_tensor`` on flat and nested test configs.

    A real text-detection config is loaded. Its test section is rebuilt
    twice:

    - with a flat ``list[dict]`` pipeline and dataset list;
    - with a nested ``list[list[dict]]`` pipeline and dataset list.

    After each call (``set_types=['test']``), the fourth transform nested
    under the second pipeline step must have type ``DefaultFormatBundle``.
    This is checked both in the uniform pipeline and in the first dataset's
    own pipeline.

    Args:
        cfg_file (str): Config path, resolved against the grandparent
            directory of this test file.
    """
    root_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    base_cfg = Config.fromfile(os.path.join(root_dir, cfg_file))
    template = base_cfg.data.test.datasets[0]
    wanted = 'DefaultFormatBundle'

    # Flat layout: pipeline is list[dict], datasets is list[dict].
    flat_cfg = copy.deepcopy(base_cfg)
    flat_ds = copy.deepcopy(template)
    flat_ds.pipeline = copy.deepcopy(base_cfg.data.test.pipeline)
    flat_cfg.data.test.datasets = [flat_ds]
    flat_cfg = replace_image_to_tensor(flat_cfg, set_types=['test'])
    assert flat_cfg.data.test.pipeline[1]['transforms'][3]['type'] == wanted
    flat_ds_step = flat_cfg.data.test.datasets[0].pipeline[1]
    assert flat_ds_step['transforms'][3]['type'] == wanted

    # Nested layout: pipeline is list[list[dict]], datasets is
    # list[list[dict]]; both inner entries reuse the same objects.
    nested_cfg = copy.deepcopy(base_cfg)
    nested_ds = copy.deepcopy(template)
    nested_ds.pipeline = copy.deepcopy(base_cfg.data.test.pipeline)
    nested_cfg.data.test.datasets = [[nested_ds], [nested_ds]]
    inner_pipeline = nested_cfg.data.test.pipeline
    nested_cfg.data.test.pipeline = [inner_pipeline, inner_pipeline]
    nested_cfg = replace_image_to_tensor(nested_cfg, set_types=['test'])
    nested_step = nested_cfg.data.test.pipeline[0][1]
    assert nested_step['transforms'][3]['type'] == wanted
    nested_ds_step = nested_cfg.data.test.datasets[0][0].pipeline[1]
    assert nested_ds_step['transforms'][3]['type'] == wanted