# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch.nn as nn
from mmcv.runner import ModuleList

from mmocr.models.builder import ENCODERS
from mmocr.models.common import TFEncoderLayer
from .base_encoder import BaseEncoder


@ENCODERS.register_module()
class NRTREncoder(BaseEncoder):
    """Self-attention based transformer encoder block for NRTR.

    Args:
        n_layers (int): Number of stacked encoder layers (default=6).
        n_head (int): Number of attention heads (default=8).
        d_k (int): Total number of features in key.
        d_v (int): Total number of features in value.
        d_model (int): Number of expected features in the encoder inputs
            (default=512).
        d_inner (int): Hidden dimension of the position-wise feedforward
            network (default=256).
        dropout (float): Dropout probability on attn_output_weights.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 n_layers=6,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 dropout=0.1,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.d_model = d_model
        # Stack of identical transformer encoder layers.
        self.layer_stack = ModuleList([
            TFEncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ])
        # Final normalization applied after the last encoder layer.
        self.layer_norm = nn.LayerNorm(d_model)

    def _get_mask(self, logit, img_metas):
        """Build a (N, T) padding mask from per-image ``valid_ratio``.

        Positions up to ``ceil(T * valid_ratio)`` (clamped to ``T``) are
        marked 1; the remainder stay 0. Returns ``None`` when ``img_metas``
        is ``None``.
        """
        if img_metas is None:
            return None
        # Missing 'valid_ratio' falls back to 1.0 (whole width valid).
        ratios = [meta.get('valid_ratio', 1.0) for meta in img_metas]
        batch_size, seq_len, _ = logit.size()
        mask = logit.new_zeros((batch_size, seq_len))
        for idx, ratio in enumerate(ratios):
            valid_width = min(seq_len, math.ceil(seq_len * ratio))
            mask[idx, :valid_width] = 1
        return mask

    def forward(self, feat, img_metas=None):
        r"""
        Args:
            feat (Tensor): Backbone output of shape :math:`(N, C, H, W)`.
            img_metas (dict): A dict that contains meta information of input
                images. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: The encoder output tensor. Shape :math:`(N, T, C)`.
        """
        n, c, h, w = feat.size()
        # Flatten the spatial map into a sequence: (N, C, H, W) -> (N, H*W, C)
        output = feat.view(n, c, h * w).permute(0, 2, 1).contiguous()
        mask = self._get_mask(output, img_metas)
        for encoder_layer in self.layer_stack:
            output = encoder_layer(output, mask)
        return self.layer_norm(output)