# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch.nn as nn
from mmcv.runner import ModuleList
from mmocr.models.builder import ENCODERS
from mmocr.models.common import TFEncoderLayer
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
class NRTREncoder(BaseEncoder):
    """Self-attention based transformer encoder used by NRTR.

    Args:
        n_layers (int): The number of sub-encoder-layers
            in the encoder (default=6).
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_k (int): Total number of features in key.
        d_v (int): Total number of features in value.
        d_model (int): The number of expected features
            in the decoder inputs (default=512).
        d_inner (int): The dimension of the feedforward
            network model (default=256).
        dropout (float): Dropout layer on attn_output_weights.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 n_layers=6,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 dropout=0.1,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.d_model = d_model
        # Build a stack of identical transformer encoder layers.
        layers = [
            TFEncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ]
        self.layer_stack = ModuleList(layers)
        self.layer_norm = nn.LayerNorm(d_model)

    def _get_mask(self, logit, img_metas):
        """Build a ``(N, T)`` padding mask from per-image valid ratios.

        Positions up to ``ceil(T * valid_ratio)`` (clamped to ``T``) are
        set to 1, the rest stay 0. Returns ``None`` when ``img_metas`` is
        ``None`` (no mask applied).
        """
        if img_metas is None:
            return None
        batch_size, seq_len, _ = logit.size()
        mask = logit.new_zeros((batch_size, seq_len))
        for idx, img_meta in enumerate(img_metas):
            # Images missing the key are treated as fully valid.
            ratio = img_meta.get('valid_ratio', 1.0)
            valid_width = min(seq_len, math.ceil(seq_len * ratio))
            mask[idx, :valid_width] = 1
        return mask

    def forward(self, feat, img_metas=None):
        r"""
        Args:
            feat (Tensor): Backbone output of shape :math:`(N, C, H, W)`.
            img_metas (dict): A dict that contains meta information of input
                images. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: The encoder output tensor. Shape :math:`(N, T, C)`.
        """
        n, c, h, w = feat.size()
        # (N, C, H, W) -> (N, H*W, C): one token per spatial location.
        tokens = feat.view(n, c, h * w).permute(0, 2, 1).contiguous()
        mask = self._get_mask(tokens, img_metas)
        output = tokens
        for enc_layer in self.layer_stack:
            output = enc_layer(output, mask)
        return self.layer_norm(output)