File size: 3,625 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

import numpy as np
import pytest

from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets


def _create_dummy_dict_file(dict_file):
    """Write the digits ``0`` through ``9`` to ``dict_file``, one per line.

    Args:
        dict_file: Path handed to ``open`` in write mode; an existing file
            at that path is overwritten.
    """
    digits = '0123456789'
    with open(dict_file, 'w') as out:
        # Each digit is followed by a newline, including the last one.
        out.writelines(digit + '\n' for digit in digits)


def test_ocr_segm_targets():
    """Check ``OCRSegTargets`` construction, kernel generation and call.

    The test covers three areas:

    * invalid constructor arguments raise ``AssertionError``;
    * ``generate_kernels`` raises ``AssertionError`` for malformed
      ``img_size``, ``char_boxes`` and ``char_idxs`` arguments, and returns
      the expected 8x10 target for both ``binary=True`` and
      ``binary=False``;
    * calling the transform on a results dict sets ``mask_fields`` to
      ``['gt_kernels']`` and stores both targets in ``gt_kernels.masks``.
    """
    # The context manager removes the temporary directory even when one of
    # the assertions below fails; a trailing ``cleanup()`` call would be
    # skipped in that case and leave the directory behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy dict file
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)
        # dummy label convertor
        label_convertor = dict(
            type='SegConvertor',
            dict_file=dict_file,
            with_unknown=True,
            lower=True)
        # test init
        with pytest.raises(AssertionError):
            OCRSegTargets(None, 0.5, 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, '1by2', 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, 0.5, 2)

        ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)
        # test generate kernels
        img_size = (8, 8)
        pad_size = (8, 10)
        char_boxes = [[2, 2, 6, 6]]
        char_idxs = [2]

        # img_size given as a bare int
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs,
                                         0.5, True)
        # char_boxes given as a flat list instead of a list of boxes
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
                                         char_idxs, 0.5, True)
        # char_idxs given as a bare int instead of a list
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2,
                                         0.5, True)

        # In both expected targets the two columns beyond the 8-pixel image
        # width (pad width is 10) hold 255.
        attn_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
        expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(attn_tgt,
                           np.array(expect_attn_tgt, dtype=np.int32))

        # With binary=False the box region holds the char index (2), not 1.
        segm_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
        expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(segm_tgt,
                           np.array(expect_segm_tgt, dtype=np.int32))

        # test __call__
        results = {
            'img_shape': (4, 4, 3),
            'resize_shape': (8, 8, 3),
            'pad_shape': (8, 10),
            'ann_info': {
                'char_rects': [[1, 1, 3, 3]],
                'chars': ['1'],
            },
        }

        results = ocr_seg_tgt(results)
        assert results['mask_fields'] == ['gt_kernels']
        assert np.allclose(results['gt_kernels'].masks[0],
                           np.array(expect_attn_tgt, dtype=np.int32))
        assert np.allclose(results['gt_kernels'].masks[1],
                           np.array(expect_segm_tgt, dtype=np.int32))