Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import numpy as np | |
from mmocr.models.builder import CONVERTORS | |
from mmocr.utils import list_from_file | |
# NOTE(review): ``CONVERTORS`` is imported by this module but this class is not
# decorated with ``@CONVERTORS.register_module()`` here; if it is meant to be
# built from a config by type name, confirm it is registered somewhere.
class NerConvertor:
    """Convert between text, index and tensor for NER pipeline.

    Args:
        annotation_type (str): Tagging scheme, ``'bio'`` (B-begin, I-inside,
            O-outside) or ``'bioes'`` (B-begin, I-inside, O-outside, E-end,
            S-single). Only ``'bio'`` is implemented; ``'bioes'`` raises
            ``NotImplementedError``.
        vocab_file (str): File to convert words to ids. The i-th entry read
            by ``list_from_file`` gets id ``i``.
        categories (list[str]): All entity categories supported by the model.
        max_len (int): Length of every id / label sequence, including the
            start and end ids. Must be greater than 2.
        unknown_id (int): Id for characters that do not appear in the
            vocabulary.
        start_id (int): Id placed before the text.
        end_id (int): Id placed after the (possibly truncated) text.
    """

    def __init__(self,
                 annotation_type='bio',
                 vocab_file=None,
                 categories=None,
                 max_len=None,
                 unknown_id=100,
                 start_id=101,
                 end_id=102):
        self.annotation_type = annotation_type
        self.categories = categories
        self.max_len = max_len
        self.unknown_id = unknown_id
        self.start_id = start_id
        self.end_id = end_id
        # One slot for the start id, one for the end id and at least one for
        # text. Checking for None first avoids a confusing TypeError on the
        # default value.
        assert self.max_len is not None and self.max_len > 2, \
            'max_len must be an integer greater than 2'
        assert self.annotation_type in ['bio', 'bioes']
        assert self.categories is not None, 'categories must be provided'
        vocabs = list_from_file(vocab_file)
        self.vocab_size = len(vocabs)
        # For duplicated entries the last occurrence wins.
        self.word2ids = {vocab: idx for idx, vocab in enumerate(vocabs)}
        if self.annotation_type == 'bio':
            self.label2id_dict, self.id2label, self.ignore_id = \
                self._generate_labelid_dict()
        else:  # 'bioes'
            raise NotImplementedError('Bioes format is not supported yet!')
        self.num_labels = len(self.id2label)

    def _generate_labelid_dict(self):
        """Build the label <-> id mappings for the BIO scheme.

        With ``n = len(self.categories)`` the ids are laid out as
        ``0 -> 'X'`` (padding / masked positions), ``1..n -> 'B-<cat>'``,
        ``n+1..2n -> 'I-<cat>'``, ``2n+1 -> 'O'`` (also the ignore id),
        ``2n+2 -> '[START]'`` and ``2n+3 -> '[END]'``.

        Returns:
            tuple: ``(label2id_dict, id2label_dict, ignore_id)`` where
            ``label2id_dict`` maps a category to ``[B-id, I-id]``.
        """
        num_classes = len(self.categories)
        ignore_id = 2 * num_classes + 1
        id2label_dict = {
            0: 'X',
            ignore_id: 'O',
            2 * num_classes + 2: '[START]',
            2 * num_classes + 3: '[END]'
        }
        label2id_dict = {}
        for index, category in enumerate(self.categories):
            start_label = index + 1
            end_label = index + 1 + num_classes
            label2id_dict[category] = [start_label, end_label]
            id2label_dict[start_label] = 'B-' + category
            id2label_dict[end_label] = 'I-' + category
        return label2id_dict, id2label_dict, ignore_id

    def convert_text2id(self, text):
        """Convert characters to ids.

        Every character is lower-cased before the vocabulary lookup.
        Characters missing from the vocabulary map to ``self.unknown_id``.
        The result is ``[start_id, id_1, ..., id_k, end_id, 0, ..., 0]`` of
        length ``self.max_len`` with ``k = min(len(text), self.max_len - 2)``.
        Longer text is truncated so the start and end ids always fit.

        Args:
            text (str): Text of one paragraph.

        Returns:
            list[int]: Input ids of length ``self.max_len``.
        """
        # Lower-case character by character (not ``text.lower()``) so exactly
        # one id is produced per input character. A few characters expand
        # when lower-cased (e.g. 'İ'), which would shift all later ids
        # relative to the entity offsets used by ``convert_entity2label``.
        ids = [
            self.word2ids.get(char.lower(), self.unknown_id) for char in text
        ]
        # Text that exceeds the maximum length is truncated; two slots are
        # reserved for the start and end ids.
        valid_len = min(len(ids), self.max_len - 2)
        input_ids = [0] * self.max_len
        input_ids[0] = self.start_id
        input_ids[1:valid_len + 1] = ids[:valid_len]
        input_ids[valid_len + 1] = self.end_id
        return input_ids

    def convert_entity2label(self, label, text_len):
        """Convert labeled entities to label ids.

        Slot 0 of the output belongs to the start token, so text index ``p``
        is labeled at slot ``p + 1``. The first character of an entity gets
        its category's B- id and the remaining characters its I- id. Other
        slots covering the start token, the text and the end token hold
        ``self.ignore_id``, and padding stays ``0``.

        Args:
            label (dict): Labels of entities, shaped like
                ``{category: {key: [[start, end], ...]}}`` with inclusive
                character offsets. Presumably ``key`` is the entity text;
                it is not used here.
            text_len (int): The length of input text.

        Returns:
            list[int]: Label ids of length ``self.max_len``.
        """
        labels = [0] * self.max_len
        for j in range(min(text_len + 2, self.max_len)):
            labels[j] = self.ignore_id
        # Text occupies slots 1..max_len-2 at most; slot max_len-1 is left
        # for the end id after truncation (see convert_text2id), so entity
        # labels never go there or beyond.
        limit = self.max_len - 1
        for category, entities in label.items():
            for places in entities.values():
                for place in places:
                    begin_id, inside_id = self.label2id_dict[category]
                    start, end = place[0] + 1, place[1] + 1
                    if start < limit:
                        labels[start] = begin_id
                    for pos in range(start + 1, min(end + 1, limit)):
                        labels[pos] = inside_id
        return labels

    def convert_pred2entities(self, preds, masks):
        """Get entities from predicted label ids.

        Position 0 (the start token) of every sequence is skipped, so entity
        offsets are character indices into the text. Masked positions become
        id 0 (``'X'``), which closes any open entity.

        Args:
            preds (list): Per-sample sequences of predicted label ids.
            masks (Tensor | array-like): The valid part is 1 and the invalid
                part is 0. Objects with ``detach`` (torch tensors) are moved
                to CPU; anything else goes through ``np.asarray``.

        Returns:
            list: For each sample, a list of ``[entity_type, entity_start,
            entity_end]`` with inclusive character offsets.
        """
        if hasattr(masks, 'detach'):
            masks = masks.detach().cpu().numpy()
        else:
            masks = np.asarray(masks)
        assert isinstance(preds, list)
        pred_entities = []
        for sample_idx, pred in enumerate(preds):
            # Drop the start-token slot and zero out masked positions.
            tags = (masks[sample_idx][1:] * np.array(pred[1:])).tolist()
            if self.annotation_type == 'bio':
                pred_entities.append(self._decode_bio(tags))
            else:
                raise NotImplementedError(
                    'The data format is not supported yet!')
        return pred_entities

    def _decode_bio(self, tags):
        """Group a BIO tag sequence into ``[type, start, end]`` entities.

        Entities are emitted only when ``start < end``, so single-character
        entities are never reported. An I- tag of a different type leaves
        the open entity unchanged.
        """
        entities = []
        entity = [-1, -1, -1]
        last = len(tags) - 1
        for pos, tag in enumerate(tags):
            if not isinstance(tag, str):
                tag = self.id2label[tag]
            if tag.startswith('B-'):
                # A new entity begins: flush the previous one first. The new
                # entity has start == end, so it cannot be emitted here even
                # at the last position.
                if entity[2] != -1 and entity[1] < entity[2]:
                    entities.append(entity)
                # ``tag[2:]`` keeps category names that contain '-' intact.
                entity = [tag[2:], pos, pos]
            elif tag.startswith('I-') and entity[1] != -1:
                if tag[2:] == entity[0]:
                    entity[2] = pos
                if pos == last and entity[1] < entity[2]:
                    entities.append(entity)
            else:
                # 'O', 'X', special tokens or a dangling I- tag end the entity.
                if entity[2] != -1 and entity[1] < entity[2]:
                    entities.append(entity)
                entity = [-1, -1, -1]
        return entities