# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.models.builder import ENCODERS
from .base_encoder import BaseEncoder


@ENCODERS.register_module()
class SAREncoder(BaseEncoder):
    """Implementation of encoder module in `SAR
    <https://arxiv.org/abs/1811.00751>`_.

    The encoder vertically max-pools the backbone feature map to a
    width-long sequence, runs it through a 2-layer RNN, and returns a
    single holistic feature vector per image.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_do_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim :math:`D_i` of channels from backbone.
        d_enc (int): Dim :math:`D_m` of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to Xavier init for Conv2d and Uniform init for
            BatchNorm2d layers.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_do_rnn=0.0,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 init_cfg=None,
                 **kwargs):
        # Build the default init_cfg lazily instead of using a mutable
        # default argument (shared across calls otherwise).
        if init_cfg is None:
            init_cfg = [
                dict(type='Xavier', layer='Conv2d'),
                dict(type='Uniform', layer='BatchNorm2d')
            ]
        super().__init__(init_cfg=init_cfg)

        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_do_rnn, (int, float))
        assert 0 <= enc_do_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_do_rnn = enc_do_rnn
        self.mask = mask

        # RNN encoder over the width dimension of the pooled feature map.
        # NOTE: renamed from `kwargs` to avoid shadowing the **kwargs
        # parameter above (extra caller kwargs are intentionally ignored,
        # matching the original behavior).
        rnn_kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            batch_first=True,
            dropout=enc_do_rnn,
            bidirectional=enc_bi_rnn)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**rnn_kwargs)

        # Global feature transformation; output size doubles when the
        # RNN is bidirectional.
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            img_metas (dict): A dict that contains meta information of
                input images. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: A tensor of shape :math:`(N, D_m)`.
        """
        if img_metas is not None:
            assert utils.is_type_list(img_metas, dict)
            assert len(img_metas) == feat.size(0)

        # valid_ratio marks the fraction of each image's width that holds
        # real (non-padded) content; only used when masking is enabled.
        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ] if self.mask else None

        # Collapse the height dimension via vertical max-pooling:
        # (N, C, H, W) -> (N, C, 1, W) -> (N, W, C).
        h_feat = feat.size(2)
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = feat_v.permute(0, 2, 1).contiguous()  # bsz * W * C

        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            # For padded images, take the hidden state at the last *valid*
            # time step rather than the final (padding) step.
            valid_hf = []
            T = holistic_feat.size(1)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_step = min(T, math.ceil(T * valid_ratio)) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = torch.stack(valid_hf, dim=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C

        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat