# MMOCR / tests/test_dataset/test_ocr_seg_target.py
# (page residue from the hosting site: uploaded by tomofi,
#  commit 2366e36 "Add application file", 3.63 kB)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
import numpy as np
import pytest
from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets
def _create_dummy_dict_file(dict_file):
chars = list('0123456789')
with open(dict_file, 'w') as fw:
for char in chars:
fw.write(char + '\n')
def test_ocr_segm_targets():
    """Test ``OCRSegTargets`` argument validation, kernels and ``__call__``.

    A throw-away dictionary of the digits 0-9 is written to a temporary
    directory and used to configure a ``SegConvertor``. The directory is
    removed when the test ends, whether it passes or fails.
    """
    # Context-manager form: the temporary directory is cleaned up even if an
    # assertion below fails (an explicit trailing ``cleanup()`` call would
    # be skipped in that case).
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy dict file
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        # dummy label convertor
        label_convertor = dict(
            type='SegConvertor',
            dict_file=dict_file,
            with_unknown=True,
            lower=True)

        # test init: each call passes one argument the constructor is
        # expected to reject (None convertor, string ratio, value 2).
        with pytest.raises(AssertionError):
            OCRSegTargets(None, 0.5, 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, '1by2', 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, 0.5, 2)

        ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)

        # test generate kernels
        img_size = (8, 8)
        pad_size = (8, 10)
        char_boxes = [[2, 2, 6, 6]]
        char_idxs = [2]
        # img_size given as an int instead of a tuple
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs,
                                         0.5, True)
        # char_boxes given as a single flat box instead of a list of boxes
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
                                         char_idxs, 0.5, True)
        # char_idxs given as an int instead of a list
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2,
                                         0.5, True)

        # Binary target: the box is expected to become a block of 1s at
        # rows/cols 3-5; columns 8-9 fall outside img_size (width 8) but
        # inside pad_size (width 10) and are expected to be 255.
        attn_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
        expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(attn_tgt,
                           np.array(expect_attn_tgt, dtype=np.int32))

        # Non-binary target: same region, filled with the char index (2).
        segm_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
        expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(segm_tgt,
                           np.array(expect_segm_tgt, dtype=np.int32))

        # test __call__: the pipeline is expected to reproduce the two
        # targets above from an annotation in original-image coordinates
        # (presumably rect [1, 1, 3, 3] scaled 4->8 becomes [2, 2, 6, 6],
        # and char '1' maps to index 2 — inferred from the expected
        # arrays, not verified here).
        results = {}
        results['img_shape'] = (4, 4, 3)
        results['resize_shape'] = (8, 8, 3)
        results['pad_shape'] = (8, 10)
        results['ann_info'] = {}
        results['ann_info']['char_rects'] = [[1, 1, 3, 3]]
        results['ann_info']['chars'] = ['1']

        results = ocr_seg_tgt(results)
        assert results['mask_fields'] == ['gt_kernels']
        assert np.allclose(results['gt_kernels'].masks[0],
                           np.array(expect_attn_tgt, dtype=np.int32))
        assert np.allclose(results['gt_kernels'].masks[1],
                           np.array(expect_segm_tgt, dtype=np.int32))