File size: 1,706 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import ENCODERS, build_decoder, build_encoder
from .base_encoder import BaseEncoder


@ENCODERS.register_module()
class ABIVisionModel(BaseEncoder):
    """A wrapper of a visual feature encoder and a language token decoder
    that converts visual features into text tokens.

    Implementation of VisionEncoder in
        `ABINet <https://arxiv.org/abs/2103.06495>`_.

    Args:
        encoder (dict): Config for the image feature encoder, built with
            ``build_encoder``.
        decoder (dict): Config for the language token decoder, built with
            ``build_decoder``.
        init_cfg (dict): Specifies the initialization method for model layers.
            Forwarded unchanged to ``BaseEncoder``.
        **kwargs: Accepted so extra config keys do not raise; they are
            ignored by this module.
    """

    def __init__(self,
                 encoder=dict(type='TransformerEncoder'),
                 decoder=dict(type='ABIVisionDecoder'),
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.encoder = build_encoder(encoder)
        self.decoder = build_decoder(decoder)

    def forward(self, feat, img_metas=None):
        """Encode a visual feature map and decode it into character
        predictions.

        Args:
            feat (Tensor): Visual feature map of shape (N, E, H, W).
            img_metas (optional): Not used by this module; accepted only so
                callers can pass it.

        Returns:
            dict: Whatever ``self.decoder`` returns when called as
            ``decoder(feat=<encoded feat>, out_enc=None)``. For the default
            ``ABIVisionDecoder`` the keys are ``feature``, ``logits`` and
            ``attn_scores``:

            - | feature (Tensor): Shape (N, T, E). Raw visual features for
                language decoder.
            - | logits (Tensor): Shape (N, T, C). The raw logits for
                characters. C is the number of characters.
            - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
                for vision-language aligner.
        """
        # Run the configured visual encoder on the incoming feature map.
        feat = self.encoder(feat)
        # The encoded feature map is the decoder's only input here; there is
        # no separate encoder output to hand over, hence ``out_enc=None``.
        return self.decoder(feat=feat, out_enc=None)