File size: 3,503 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import tempfile

import torch

from mmocr.datasets.ner_dataset import NerDataset
from mmocr.models.ner.convertors.ner_convertor import NerConvertor
from mmocr.utils import list_to_file


def _create_dummy_ann_file(ann_file):
    """Write a one-line JSON annotation file holding a single NER sample.

    The sample sentence carries two labelled entities: an ``address`` span
    covering character offsets 15-16 and a ``name`` span covering offsets
    0-2.

    Args:
        ann_file (str): Destination path of the annotation file.
    """
    entity_labels = {
        'address': {'台湾': [[15, 16]]},
        'name': {'彭小军': [[0, 2]]},
    }
    sample = dict(
        text='彭小军认为,国内银行现在走的是台湾的发卡模式', label=entity_labels)
    # ensure_ascii=False keeps the Chinese characters readable in the file.
    encoded = json.dumps(sample, ensure_ascii=False)
    list_to_file(ann_file, [encoded])


def _create_dummy_vocab_file(vocab_file):
    """Write a dummy vocabulary file containing the letters ``a``-``z``.

    Args:
        vocab_file (str): Destination path; one letter is written per line.
    """
    # Hand the complete list to ``list_to_file`` in a single call. Calling it
    # once per character passes a one-line list each time and (presumably
    # truncating the file on every call) leaves only the last letter behind.
    # Entries are plain characters: ``json.dumps(char + '\n')`` would wrap
    # each token in quotes and embed a literal backslash-n escape.
    letters = [chr(code) for code in range(ord('a'), ord('z') + 1)]
    list_to_file(vocab_file, letters)


def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(type='LineJsonParser', keys=['text', 'label']))
    return loader


def test_ner_dataset():
    """Smoke-test ``NerDataset`` and ``NerConvertor.convert_pred2entities``.

    A one-sample annotation file and a small vocabulary are written to a
    temporary directory. The dataset is then run through ``pre_pipeline``,
    ``prepare_train_img`` and ``evaluate``. Finally, a fixed prediction
    sequence is decoded into entities.
    """
    loader = _create_dummy_loader()
    categories = [
        'address', 'book', 'company', 'game', 'government', 'movie', 'name',
        'organization', 'position', 'scene'
    ]
    max_len = 128

    # Fixed label-id sequence used to exercise convert_pred2entities; only
    # its first ``max_len`` ids are used.
    pred = [
        21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, 11,
        21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1,
        11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21,
        21, 21
    ]
    preds = [pred[:max_len]]
    # Mark only the first 10 positions as valid.
    mask = [1] * 10 + [0] * (max_len - 10)
    assert len(preds[0]) == len(mask)
    masks = torch.tensor([mask])

    # The context manager removes the temporary files even when an assertion
    # below fails; a trailing ``cleanup()`` call would be skipped in that case.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ann_file = osp.join(tmp_dir, 'fake_data.txt')
        vocab_file = osp.join(tmp_dir, 'fake_vocab.txt')
        _create_dummy_ann_file(ann_file)
        _create_dummy_vocab_file(vocab_file)

        ner_convertor = dict(
            type='NerConvertor',
            annotation_type='bio',
            vocab_file=vocab_file,
            categories=categories,
            max_len=max_len)

        test_pipeline = [
            dict(
                type='NerTransform',
                label_convertor=ner_convertor,
                max_len=max_len),
            dict(type='ToTensorNER')
        ]
        dataset = NerDataset(ann_file, loader, pipeline=test_pipeline)

        # test pre_pipeline
        img_info = dataset.data_infos[0]
        results = dict(img_info=img_info)
        dataset.pre_pipeline(results)

        # test prepare_train_img
        dataset.prepare_train_img(0)

        # test evaluation with spans matching the dummy annotation file
        result = [[['address', 15, 16], ['name', 0, 2]]]
        dataset.evaluate(result)

        # test NerConvertor.convert_pred2entities
        convertor = NerConvertor(
            annotation_type='bio',
            vocab_file=vocab_file,
            categories=categories,
            max_len=max_len)
        all_entities = convertor.convert_pred2entities(
            preds=preds, masks=masks)
        # Presumably each entity is [category, start, end], mirroring the
        # evaluation results above — TODO confirm against NerConvertor.
        assert len(all_entities[0][0]) == 3