# Page header residue (not part of the module) - Spaces: Runtime error; File size: 3,503 Bytes; 2366e36
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import tempfile
import torch
from mmocr.datasets.ner_dataset import NerDataset
from mmocr.models.ner.convertors.ner_convertor import NerConvertor
from mmocr.utils import list_to_file
def _create_dummy_ann_file(ann_file):
    """Write a one-sample NER annotation to ``ann_file``.

    The sample is serialized as a single JSON string (non-ASCII characters
    kept as-is) holding a sentence under ``'text'`` and, under ``'label'``,
    one ``'address'`` span and one ``'name'`` span.

    Args:
        ann_file: Destination handed to ``list_to_file``.
    """
    entity_spans = {
        'address': {'台湾': [[15, 16]]},
        'name': {'彭小军': [[0, 2]]},
    }
    sample = dict(
        text='彭小军认为,国内银行现在走的是台湾的发卡模式',
        label=entity_spans,
    )
    serialized = json.dumps(sample, ensure_ascii=False)
    list_to_file(ann_file, [serialized])
def _create_dummy_vocab_file(vocab_file):
    """Write a dummy vocabulary of the letters ``a``-``z`` to ``vocab_file``.

    Args:
        vocab_file: Destination handed to ``list_to_file``.
    """
    letters = [chr(code) for code in range(ord('a'), ord('z') + 1)]
    # Write every entry in ONE call: `list_to_file` is documented to
    # create/overwrite its target, so calling it once per letter left only
    # the final entry in the file. The letters are written as plain tokens;
    # JSON-encoding `char + '\n'` stored a quoted string with an escaped
    # newline rather than the bare character.
    list_to_file(vocab_file, letters)
def _create_dummy_loader():
loader = dict(
type='HardDiskLoader',
repeat=1,
parser=dict(type='LineJsonParser', keys=['text', 'label']))
return loader
def test_ner_dataset():
    """Exercise ``NerDataset`` and ``NerConvertor`` on a one-sample fixture.

    Builds a dataset from dummy annotation and vocab files, then calls
    ``pre_pipeline``, ``prepare_train_img`` and ``evaluate`` on it, and
    finally checks the shape of the entities returned by
    ``NerConvertor.convert_pred2entities``.
    """
    # test initialization
    loader = _create_dummy_loader()
    categories = [
        'address', 'book', 'company', 'game', 'government', 'movie', 'name',
        'organization', 'position', 'scene'
    ]
    max_len = 128

    # create dummy data
    # The context manager removes the temporary directory even when a call
    # or assertion below raises; a trailing `cleanup()` would be skipped.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ann_file = osp.join(tmp_dir, 'fake_data.txt')
        vocab_file = osp.join(tmp_dir, 'fake_vocab.txt')
        _create_dummy_ann_file(ann_file)
        _create_dummy_vocab_file(vocab_file)

        ner_convertor = dict(
            type='NerConvertor',
            annotation_type='bio',
            vocab_file=vocab_file,
            categories=categories,
            max_len=max_len)
        test_pipeline = [
            dict(
                type='NerTransform',
                label_convertor=ner_convertor,
                max_len=max_len),
            dict(type='ToTensorNER')
        ]
        dataset = NerDataset(ann_file, loader, pipeline=test_pipeline)

        # test pre_pipeline
        img_info = dataset.data_infos[0]
        results = dict(img_info=img_info)
        dataset.pre_pipeline(results)

        # test prepare_train_img
        dataset.prepare_train_img(0)

        # test evaluation
        result = [[['address', 15, 16], ['name', 0, 2]]]
        dataset.evaluate(result)

        # test pred convert2entity function
        pred = [
            21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, 11,
            21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1,
            11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21,
            21, 21
        ]
        # Truncate the raw prediction to the sequence length used above.
        preds = [pred[:max_len]]
        # Only the first 10 positions are flagged with 1; the rest stay 0.
        # NOTE(review): presumably the convertor ignores positions whose
        # mask is 0 - confirm against NerConvertor.
        mask = [1] * 10 + [0] * (max_len - 10)
        assert len(preds[0]) == len(mask)
        masks = torch.tensor([mask])
        convertor = NerConvertor(
            annotation_type='bio',
            vocab_file=vocab_file,
            categories=categories,
            max_len=max_len)
        all_entities = convertor.convert_pred2entities(
            preds=preds, masks=masks)
        # The first entity of the first sample must have three fields
        # (presumably [category, start, end], like `result` above - confirm).
        assert len(all_entities[0][0]) == 3
|