Spaces:
Runtime error
Runtime error
File size: 1,706 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
# Copyright (c) OpenMMLab. All rights reserved.
from mmocr.models.builder import ENCODERS, build_decoder, build_encoder
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
class ABIVisionModel(BaseEncoder):
    """A wrapper of visual feature encoder and language token decoder that
    converts visual features into text tokens.

    Implementation of VisionEncoder in
    `ABINet <https://arxiv.org/abs/2103.06495>`_.

    Args:
        encoder (dict): Config for image feature encoder. Built with
            ``build_encoder``.
        decoder (dict): Config for language token decoder. Built with
            ``build_decoder``.
        init_cfg (dict): Specifies the initialization method for model layers.
            Forwarded unchanged to ``BaseEncoder.__init__``.
        **kwargs: Accepted but ignored, so extra config keys do not raise.
    """

    def __init__(self,
                 encoder=dict(type='TransformerEncoder'),
                 decoder=dict(type='ABIVisionDecoder'),
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        # Both sub-modules are instantiated from their configs through the
        # MMOCR registries.
        self.encoder = build_encoder(encoder)
        self.decoder = build_decoder(decoder)

    def forward(self, feat, img_metas=None):
        """Encode a visual feature map and decode it into character outputs.

        Args:
            feat (Tensor): Visual feature map of shape (N, E, H, W). This is
                presumably the backbone output (TODO confirm against the
                detector config); it is not the raw input image.
            img_metas: Unused. Accepted only so the signature matches other
                encoders.

        Returns:
            dict: The decoder's output, a dict with keys ``feature``,
            ``logits`` and ``attn_scores``.

            - | feature (Tensor): Shape (N, T, E). Raw visual features for
                language decoder.
            - | logits (Tensor): Shape (N, T, C). The raw logits for
                characters. C is the number of characters.
            - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
                for vision-language aligner.
        """
        feat = self.encoder(feat)
        # The decoder receives the encoded map via ``feat``. No separate
        # encoder output is supplied, hence ``out_enc=None``.
        return self.decoder(feat=feat, out_enc=None)
|