File size: 6,706 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from mmocr.models.builder import CONVERTORS
from mmocr.utils import list_from_file


@CONVERTORS.register_module()
class NerConvertor:
    """Convert between text, index and tensor for NER pipeline.

    Args:
        annotation_type (str): Tagging scheme, either 'bio' (B-begin,
            I-inside, O-outside) or 'bioes' (B-begin, I-inside, O-outside,
            E-end, S-single). Only 'bio' is currently implemented.
        vocab_file (str): File to convert words to ids.
        categories (list[str]): All entity categories supported by the model.
        max_len (int): The maximum length of the input text.
        unknown_id (int): For words that do not appear in vocab.txt.
        start_id (int): ID placed at the beginning of every input sequence.
        end_id (int): ID placed right after the last character ID of every
            input sequence.
    """

    def __init__(self,
                 annotation_type='bio',
                 vocab_file=None,
                 categories=None,
                 max_len=None,
                 unknown_id=100,
                 start_id=101,
                 end_id=102):
        self.annotation_type = annotation_type
        self.categories = categories
        self.word2ids = {}
        self.max_len = max_len
        self.unknown_id = unknown_id
        self.start_id = start_id
        self.end_id = end_id
        assert self.max_len > 2
        assert self.annotation_type in ['bio', 'bioes']

        vocabs = list_from_file(vocab_file)
        self.vocab_size = len(vocabs)
        for idx, vocab in enumerate(vocabs):
            self.word2ids.update({vocab: idx})

        if self.annotation_type == 'bio':
            self.label2id_dict, self.id2label, self.ignore_id = \
                self._generate_labelid_dict()
        elif self.annotation_type == 'bioes':
            raise NotImplementedError('Bioes format is not supported yet!')

        assert self.ignore_id is not None
        assert self.id2label is not None
        self.num_labels = len(self.id2label)

    def _generate_labelid_dict(self):
        """Generate a dictionary that maps input to ID and ID to output."""
        num_classes = len(self.categories)
        label2id_dict = {}
        ignore_id = 2 * num_classes + 1
        id2label_dict = {
            0: 'X',
            ignore_id: 'O',
            2 * num_classes + 2: '[START]',
            2 * num_classes + 3: '[END]'
        }

        for index, category in enumerate(self.categories):
            start_label = index + 1
            end_label = index + 1 + num_classes
            label2id_dict.update({category: [start_label, end_label]})
            id2label_dict.update({start_label: 'B-' + category})
            id2label_dict.update({end_label: 'I-' + category})

        return label2id_dict, id2label_dict, ignore_id

    def convert_text2id(self, text):
        """Convert characters to ids.

        If the input is uppercase,
            convert to lowercase first.
        Args:
            text (list[char]): Annotations of one paragraph.
        Returns:
            input_ids (list): Corresponding IDs after conversion.
        """
        ids = []
        for word in text.lower():
            if word in self.word2ids:
                ids.append(self.word2ids[word])
            else:
                ids.append(self.unknown_id)
        # Text that exceeds the maximum length is truncated.
        valid_len = min(len(text), self.max_len)
        input_ids = [0] * self.max_len
        input_ids[0] = self.start_id
        for i in range(1, valid_len + 1):
            input_ids[i] = ids[i - 1]
        input_ids[i + 1] = self.end_id

        return input_ids

    def convert_entity2label(self, label, text_len):
        """Convert labeled entities to ids.

        Args:
            label (dict): Labels of entities.
            text_len (int): The length of input text.
        Returns:
            labels (list): Label ids of an input text.
        """
        labels = [0] * self.max_len
        for j in range(min(text_len + 2, self.max_len)):
            labels[j] = self.ignore_id
        categories = label
        for key in categories:
            for text in categories[key]:
                for place in categories[key][text]:
                    # Remove the label position beyond the maximum length.
                    if place[0] + 1 < len(labels):
                        labels[place[0] + 1] = self.label2id_dict[key][0]
                        for i in range(place[0] + 1, place[1] + 1):
                            if i + 1 < len(labels):
                                labels[i + 1] = self.label2id_dict[key][1]
        return labels

    def convert_pred2entities(self, preds, masks):
        """Gets entities from preds.

        Args:
            preds (list): Sequence of preds.
            masks (tensor): The valid part is 1 and the invalid part is 0.
        Returns:
            pred_entities (list): List of [[[entity_type,
                                entity_start, entity_end]]].
        """

        masks = masks.detach().cpu().numpy()
        pred_entities = []
        assert isinstance(preds, list)
        for index, pred in enumerate(preds):
            entities = []
            entity = [-1, -1, -1]
            results = (masks[index][1:] * np.array(pred[1:])).tolist()
            for index, tag in enumerate(results):
                if not isinstance(tag, str):
                    tag = self.id2label[tag]
                if self.annotation_type == 'bio':
                    if tag.startswith('B-'):
                        if entity[2] != -1 and entity[1] < entity[2]:
                            entities.append(entity)
                        entity = [-1, -1, -1]
                        entity[1] = index
                        entity[0] = tag.split('-')[1]
                        entity[2] = index
                        if index == len(results) - 1 and entity[1] < entity[2]:
                            entities.append(entity)
                    elif tag.startswith('I-') and entity[1] != -1:
                        _type = tag.split('-')[1]
                        if _type == entity[0]:
                            entity[2] = index

                        if index == len(results) - 1 and entity[1] < entity[2]:
                            entities.append(entity)
                    else:
                        if entity[2] != -1 and entity[1] < entity[2]:
                            entities.append(entity)
                        entity = [-1, -1, -1]
                else:
                    raise NotImplementedError(
                        'The data format is not supported yet!')
            pred_entities.append(entities)
        return pred_entities