MMOCR / mmocr /models /ner /classifiers /ner_classifier.py
tomofi's picture
Add application file
2366e36
raw
history blame
No virus
1.9 kB
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import (DETECTORS, build_convertor, build_decoder,
build_encoder, build_loss)
from mmocr.models.textrecog.recognizer.base import BaseRecognizer
@DETECTORS.register_module()
class NerClassifier(BaseRecognizer):
    """Base class for NER classifier.

    The model is assembled from four configurable components: a label
    convertor, an encoder, a decoder and a loss module. The number of labels
    reported by the label convertor is injected into the decoder and loss
    configs before those components are built.

    Args:
        encoder (dict): Config passed to ``build_encoder``.
        decoder (dict): Config passed to ``build_decoder`` after
            ``num_labels`` has been added to a copy of it.
        loss (dict): Config passed to ``build_loss`` after ``num_labels``
            has been added to a copy of it.
        label_convertor (dict): Config passed to ``build_convertor``. The
            built object must expose ``num_labels`` and
            ``convert_pred2entities``.
        train_cfg: Accepted for interface compatibility; not used here.
        test_cfg: Accepted for interface compatibility; not used here.
        init_cfg: Forwarded unchanged to the parent class constructor.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 loss,
                 label_convertor,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.label_convertor = build_convertor(label_convertor)
        self.encoder = build_encoder(encoder)

        # Inject ``num_labels`` into shallow copies so the caller's config
        # objects are not modified as a side effect of building the model.
        decoder = decoder.copy()
        decoder.update(num_labels=self.label_convertor.num_labels)
        self.decoder = build_decoder(decoder)

        loss = loss.copy()
        loss.update(num_labels=self.label_convertor.num_labels)
        self.loss = build_loss(loss)

    def extract_feat(self, imgs):
        """Extract features from images.

        Raises:
            NotImplementedError: Always; this model does not implement it.
        """
        raise NotImplementedError(
            'Extract feature module is not implemented yet.')

    def forward_train(self, imgs, img_metas, **kwargs):
        """Run the training forward pass and return the loss output.

        Args:
            imgs: Unused by this method.
            img_metas: Input handed to the encoder and, together with the
                decoder logits, to the loss module.
            **kwargs: Ignored.

        Returns:
            Whatever the configured loss module returns.
        """
        encode_out = self.encoder(img_metas)
        # The decoder returns a pair; only its first element (the logits)
        # is needed to compute the loss.
        logits, _ = self.decoder(encode_out)
        loss = self.loss(logits, img_metas)
        return loss

    def forward_test(self, imgs, img_metas, **kwargs):
        """Run the inference forward pass and return predicted entities.

        Args:
            imgs: Unused by this method.
            img_metas: Input handed to the encoder. It is also indexed with
                the key ``'attention_masks'``, so it must support string
                key lookup (presumably a dict of batched inputs; confirm
                against the data pipeline).
            **kwargs: Ignored.

        Returns:
            The result of ``label_convertor.convert_pred2entities``.
        """
        encode_out = self.encoder(img_metas)
        # The decoder returns a pair; only its second element (the
        # predictions) is needed at test time.
        _, preds = self.decoder(encode_out)
        pred_entities = self.label_convertor.convert_pred2entities(
            preds, img_metas['attention_masks'])
        return pred_entities

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test-time augmentation entry point; not supported.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('Augmentation test is not implemented yet.')

    def simple_test(self, img, img_metas, **kwargs):
        """Single-sample test entry point; not supported.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('Simple test is not implemented yet.')