Spaces:
Runtime error
Runtime error
File size: 6,706 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmocr.models.builder import CONVERTORS
from mmocr.utils import list_from_file
@CONVERTORS.register_module()
class NerConvertor:
    """Convert between text, index and tensor for NER pipeline.

    Args:
        annotation_type (str): Tagging scheme. ``'bio'`` (B-begin, I-inside,
            O-outside) is implemented; ``'bioes'`` (B-begin, I-inside,
            O-outside, E-end, S-single) passes validation but raises
            ``NotImplementedError``.
        vocab_file (str): File to convert words to ids. The position of a
            word in the file is used as its id.
        categories (list[str]): All entity categories supported by the model.
        max_len (int): The maximum length of the input text, including the
            start and end markers. Must be greater than 2.
        unknown_id (int): For words that do not appear in the vocab file.
        start_id (int): ID written before the first character of each input.
        end_id (int): ID written after the last character of each input.
    """

    def __init__(self,
                 annotation_type='bio',
                 vocab_file=None,
                 categories=None,
                 max_len=None,
                 unknown_id=100,
                 start_id=101,
                 end_id=102):
        self.annotation_type = annotation_type
        self.categories = categories
        self.max_len = max_len
        self.unknown_id = unknown_id
        self.start_id = start_id
        self.end_id = end_id
        assert self.max_len > 2
        assert self.annotation_type in ['bio', 'bioes']

        # Each vocab entry maps to its position in the vocab file. If an
        # entry is repeated, the last occurrence wins.
        vocabs = list_from_file(vocab_file)
        self.vocab_size = len(vocabs)
        self.word2ids = {vocab: idx for idx, vocab in enumerate(vocabs)}

        if self.annotation_type == 'bio':
            self.label2id_dict, self.id2label, self.ignore_id = \
                self._generate_labelid_dict()
        elif self.annotation_type == 'bioes':
            raise NotImplementedError('Bioes format is not supported yet!')

        assert self.ignore_id is not None
        assert self.id2label is not None
        self.num_labels = len(self.id2label)

    def _generate_labelid_dict(self):
        """Generate a dictionary that maps input to ID and ID to output.

        With ``N = len(self.categories)`` the label ids are laid out as:
        ``0`` -> ``'X'``, ``1..N`` -> ``'B-<category>'``,
        ``N+1..2N`` -> ``'I-<category>'``, ``2N+1`` -> ``'O'``,
        ``2N+2`` -> ``'[START]'`` and ``2N+3`` -> ``'[END]'``.

        Returns:
            tuple: ``(label2id_dict, id2label_dict, ignore_id)`` where
                ``label2id_dict`` maps a category to its
                ``[begin_id, inside_id]`` pair, ``id2label_dict`` maps every
                label id to its tag string and ``ignore_id`` is the id of
                the ``'O'`` tag.
        """
        num_classes = len(self.categories)
        ignore_id = 2 * num_classes + 1
        label2id_dict = {}
        id2label_dict = {
            0: 'X',
            ignore_id: 'O',
            2 * num_classes + 2: '[START]',
            2 * num_classes + 3: '[END]'
        }
        for index, category in enumerate(self.categories):
            begin_id = index + 1
            inside_id = begin_id + num_classes
            label2id_dict[category] = [begin_id, inside_id]
            id2label_dict[begin_id] = 'B-' + category
            id2label_dict[inside_id] = 'I-' + category
        return label2id_dict, id2label_dict, ignore_id

    def convert_text2id(self, text):
        """Convert characters to ids.

        The text is lowercased before the vocabulary lookup. The result is
        ``start_id``, the character ids, ``end_id``, then zero padding up to
        ``max_len`` entries.

        Args:
            text (str): Text of one paragraph.

        Returns:
            input_ids (list): Corresponding IDs after conversion, always of
                length ``max_len``.
        """
        ids = [
            self.word2ids.get(char, self.unknown_id) for char in text.lower()
        ]
        # Text that exceeds the maximum length is truncated. Two slots are
        # reserved for the start and end markers, so at most ``max_len - 2``
        # characters fit; this also keeps the end marker inside the list.
        valid_len = min(len(text), self.max_len - 2)
        input_ids = [self.start_id] + ids[:valid_len] + [self.end_id]
        input_ids += [0] * (self.max_len - len(input_ids))
        return input_ids

    def convert_entity2label(self, label, text_len):
        """Convert labeled entities to ids.

        Args:
            label (dict): Labels of entities, nested as
                ``{category: {entity_text: [[start, end], ...]}}`` where
                ``start`` and ``end`` are inclusive character positions.
            text_len (int): The length of input text.

        Returns:
            labels (list): Label ids of an input text, of length ``max_len``.
        """
        labels = [0] * self.max_len
        # The text plus its two marker positions default to the 'O' label;
        # whatever is left keeps the padding label 0.
        for j in range(min(text_len + 2, self.max_len)):
            labels[j] = self.ignore_id

        for key, entities in label.items():
            for places in entities.values():
                for place in places:
                    # Positions are shifted by one because index 0 belongs
                    # to the start marker.
                    # Remove the label position beyond the maximum length.
                    if place[0] + 1 < len(labels):
                        labels[place[0] + 1] = self.label2id_dict[key][0]
                    for i in range(place[0] + 1, place[1] + 1):
                        if i + 1 < len(labels):
                            labels[i + 1] = self.label2id_dict[key][1]
        return labels

    def convert_pred2entities(self, preds, masks):
        """Gets entities from preds.

        Args:
            preds (list): Sequence of preds, one sequence of label ids per
                sample.
            masks (tensor): The valid part is 1 and the invalid part is 0.

        Returns:
            pred_entities (list): List of [[[entity_type,
                entity_start, entity_end]]]. Positions are indices into
                the prediction sequence with its first element removed.

        Raises:
            NotImplementedError: If ``annotation_type`` is not ``'bio'``.
        """
        masks = masks.detach().cpu().numpy()
        pred_entities = []
        assert isinstance(preds, list)
        for sample_idx, pred in enumerate(preds):
            entities = []
            # Entity under construction: [type, start, end]; -1 means unset.
            entity = [-1, -1, -1]
            # Drop the first position and zero out masked positions; label
            # id 0 decodes to 'X', which closes any open entity below.
            results = (masks[sample_idx][1:] * np.array(pred[1:])).tolist()
            last_pos = len(results) - 1
            for pos, tag in enumerate(results):
                if not isinstance(tag, str):
                    tag = self.id2label[tag]
                if self.annotation_type != 'bio':
                    raise NotImplementedError(
                        'The data format is not supported yet!')

                # NOTE(review): an entity is only emitted when
                # start < end, so single-token entities are dropped.
                # Kept as in the original; confirm this is intended.
                if tag.startswith('B-'):
                    if entity[2] != -1 and entity[1] < entity[2]:
                        entities.append(entity)
                    # Split once so that category names containing '-'
                    # are kept whole.
                    entity = [tag.split('-', 1)[1], pos, pos]
                elif tag.startswith('I-') and entity[1] != -1:
                    if tag.split('-', 1)[1] == entity[0]:
                        entity[2] = pos
                    # Flush the open entity at the end of the sequence.
                    if pos == last_pos and entity[1] < entity[2]:
                        entities.append(entity)
                else:
                    if entity[2] != -1 and entity[1] < entity[2]:
                        entities.append(entity)
                    entity = [-1, -1, -1]
            pred_entities.append(entities)
        return pred_entities
|