File size: 1,899 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import (DETECTORS, build_convertor, build_decoder,
                                  build_encoder, build_loss)
from mmocr.models.textrecog.recognizer.base import BaseRecognizer


@DETECTORS.register_module()
class NerClassifier(BaseRecognizer):
    """Base class for NER classifier.

    The model chains an encoder, a decoder and a loss module, each built from
    a config mapping. The number of labels reported by the label convertor
    is injected into the decoder and loss configs before they are built.

    Args:
        encoder (dict): Config passed to ``build_encoder``.
        decoder (dict): Config passed to ``build_decoder`` once
            ``num_labels`` has been added. The caller's mapping is not
            modified.
        loss (dict): Config passed to ``build_loss`` once ``num_labels`` has
            been added. The caller's mapping is not modified.
        label_convertor (dict): Config passed to ``build_convertor``. The
            built object must provide ``num_labels`` and
            ``convert_pred2entities``.
        train_cfg: Accepted but not used by this class.
        test_cfg: Accepted but not used by this class.
        init_cfg: Forwarded to the parent class constructor.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 loss,
                 label_convertor,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.label_convertor = build_convertor(label_convertor)
        num_labels = self.label_convertor.num_labels

        self.encoder = build_encoder(encoder)

        # Build from shallow copies carrying ``num_labels`` instead of
        # calling ``.update()`` on the arguments: the config mappings are
        # owned by the caller and may be shared or reused, so they must not
        # be mutated as a side effect of constructing the model.
        self.decoder = build_decoder({**decoder, 'num_labels': num_labels})
        self.loss = build_loss({**loss, 'num_labels': num_labels})

    def extract_feat(self, imgs):
        """Extract features from images.

        Raises:
            NotImplementedError: Always; this class does not implement it.
        """
        raise NotImplementedError(
            'Extract feature module is not implemented yet.')

    def forward_train(self, imgs, img_metas, **kwargs):
        """Run the training forward pass and return the loss.

        Args:
            imgs: Not used by this method; the encoder is fed ``img_metas``.
            img_metas: Input handed to the encoder and, together with the
                decoder logits, to the loss module.
            **kwargs: Ignored.

        Returns:
            Whatever the loss module returns for ``(logits, img_metas)``.
        """
        encode_out = self.encoder(img_metas)
        # The decoder returns a pair; only the logits are needed for the loss.
        logits, _ = self.decoder(encode_out)
        return self.loss(logits, img_metas)

    def forward_test(self, imgs, img_metas, **kwargs):
        """Run the inference forward pass and return predicted entities.

        Args:
            imgs: Not used by this method; the encoder is fed ``img_metas``.
            img_metas: Input handed to the encoder. It must also support
                ``img_metas['attention_masks']``, which is passed to the
                label convertor.
            **kwargs: Ignored.

        Returns:
            The result of ``label_convertor.convert_pred2entities``.
        """
        encode_out = self.encoder(img_metas)
        # The decoder returns a pair; only the predictions are needed here.
        _, preds = self.decoder(encode_out)
        return self.label_convertor.convert_pred2entities(
            preds, img_metas['attention_masks'])

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentation.

        Raises:
            NotImplementedError: Always; this class does not implement it.
        """
        raise NotImplementedError('Augmentation test is not implemented yet.')

    def simple_test(self, img, img_metas, **kwargs):
        """Test without augmentation.

        Raises:
            NotImplementedError: Always; this class does not implement it.
        """
        raise NotImplementedError('Simple test is not implemented yet.')