File size: 2,051 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.datasets.builder import PIPELINES

from mmocr.models.builder import build_convertor


@PIPELINES.register_module()
class NerTransform:
    """Convert NER text and entity annotations into model inputs.

    The text is converted to token IDs and the ground-truth entities to label
    IDs with the built label convertor. An attention mask and token type IDs,
    each exactly ``max_len`` long, are generated at the same time. These
    fields are used as input to the model.

    Args:
        label_convertor (dict): Config passed to ``build_convertor`` to build
            the convertor that maps text to IDs and entities to label IDs.
        max_len (int): Limited maximum input length. The attention mask and
            token type IDs produced by this transform always have this length.
    """

    def __init__(self, label_convertor, max_len):
        self.label_convertor = build_convertor(label_convertor)
        self.max_len = max_len

    def __call__(self, results):
        """Build the model input fields for one sample.

        Args:
            results (dict): Sample containing at least ``'text'`` (the raw
                text) and ``'label'`` (the entity annotation consumed by the
                convertor's ``convert_entity2label``).

        Returns:
            dict: A new dict with the keys ``labels``, ``texts``,
            ``input_ids``, ``attention_mask`` and ``token_type_ids``. Other
            keys of ``results`` are not carried over.
        """
        texts = results['text']
        # NOTE(review): input_ids / labels are returned exactly as the
        # convertor produces them. This transform does not truncate them, so
        # confirm the convertor keeps them consistent with max_len.
        input_ids = self.label_convertor.convert_text2id(texts)
        labels = self.label_convertor.convert_entity2label(
            results['label'], len(texts))

        # The beginning and end IDs are added to the ID, so the number of
        # attended positions is the text length plus 2. Clamp it to max_len
        # so texts longer than max_len - 2 cannot index past the mask's end.
        num_valid = min(len(texts) + 2, self.max_len)
        attention_mask = [1] * num_valid + [0] * (self.max_len - num_valid)
        token_type_ids = [0] * self.max_len
        return dict(
            labels=labels,
            texts=texts,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids)


@PIPELINES.register_module()
class ToTensorNER:
    """Convert the ``list``-typed NER fields into tensors.

    The tensors are packed into ``img_metas`` alongside an empty ``img``
    list. Every other key of the incoming ``results`` is discarded.
    """

    def __call__(self, results):
        """Tensorize the NER model inputs held in ``results``.

        Args:
            results (dict): Must provide ``'input_ids'``, ``'labels'``,
                ``'attention_mask'`` and ``'token_type_ids'``.

        Returns:
            dict: ``{'img': [], 'img_metas': {...}}`` where ``img_metas`` maps
            ``input_ids``, ``attention_masks`` (plural key),
            ``labels`` and ``token_type_ids`` to tensors.
        """
        # Look up and convert in a fixed order so a missing key fails on the
        # same field every time.
        source_keys = ('input_ids', 'labels', 'attention_mask',
                       'token_type_ids')
        tensors = {key: torch.tensor(results[key]) for key in source_keys}

        img_metas = {
            'input_ids': tensors['input_ids'],
            'attention_masks': tensors['attention_mask'],
            'labels': tensors['labels'],
            'token_type_ids': tensors['token_type_ids'],
        }
        return {'img': [], 'img_metas': img_metas}