Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import math | |
import torch.nn as nn | |
from mmcv.runner import ModuleList | |
from mmocr.models.builder import ENCODERS | |
from mmocr.models.common import TFEncoderLayer | |
from .base_encoder import BaseEncoder | |
class NRTREncoder(BaseEncoder):
    """Transformer encoder made of stacked self-attention layers.

    The encoder flattens a 4-D backbone feature map into a sequence, runs it
    through ``n_layers`` ``TFEncoderLayer`` blocks and applies a final
    ``LayerNorm``.

    Args:
        n_layers (int): The number of sub-encoder-layers in the encoder.
            Defaults to 6.
        n_head (int): The number of heads in the multi-head attention
            models. Defaults to 8.
        d_k (int): Total number of features in key. Defaults to 64.
        d_v (int): Total number of features in value. Defaults to 64.
        d_model (int): The number of expected features in the encoder
            inputs; also the normalized shape of the final ``LayerNorm``.
            Defaults to 512.
        d_inner (int): The dimension of the feedforward network model.
            Defaults to 256.
        dropout (float): Dropout value passed to every encoder layer.
            Defaults to 0.1.
        init_cfg (dict or list[dict], optional): Initialization configs.
        **kwargs: Extra keyword arguments forwarded unchanged to every
            ``TFEncoderLayer``.
    """

    def __init__(self,
                 n_layers=6,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 dropout=0.1,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.d_model = d_model
        self.layer_stack = ModuleList([
            TFEncoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)

    def _get_mask(self, logit, img_metas):
        """Build a per-sample validity mask over the sequence axis.

        Args:
            logit (Tensor): Sequence tensor of shape :math:`(N, T, C)`. Only
                its size is read; the mask is allocated from it with
                ``new_zeros``.
            img_metas (list[dict] | None): Per-image meta information. The
                key ``valid_ratio`` is read from each entry and defaults to
                1.0 when absent.

        Returns:
            Tensor | None: ``None`` if ``img_metas`` is ``None``. Otherwise a
            tensor of shape :math:`(N, T)` whose row ``i`` holds 1 in the
            first ``ceil(T * valid_ratio_i)`` positions (capped at ``T``)
            and 0 elsewhere.
        """
        num_samples, seq_len, _ = logit.size()
        if img_metas is None:
            return None

        mask = logit.new_zeros((num_samples, seq_len))
        for i, img_meta in enumerate(img_metas):
            valid_ratio = img_meta.get('valid_ratio', 1.0)
            # Cap at seq_len so a ratio above 1.0 cannot overrun the row.
            valid_width = min(seq_len, math.ceil(seq_len * valid_ratio))
            mask[i, :valid_width] = 1
        return mask

    def forward(self, feat, img_metas=None):
        r"""Encode a backbone feature map into a feature sequence.

        Args:
            feat (Tensor): Backbone output of shape :math:`(N, C, H, W)`.
            img_metas (list[dict], optional): Meta information of the input
                images, one dict per image. Preferably with the key
                ``valid_ratio``. Defaults to None, in which case no mask is
                passed to the encoder layers.

        Returns:
            Tensor: The encoder output tensor. Shape :math:`(N, T, C)` with
            :math:`T = H \times W`.
        """
        n, c, h, w = feat.size()
        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
        # inputs, while ``reshape`` falls back to a copy and yields the same
        # values.
        feat = feat.reshape(n, c, h * w).permute(0, 2, 1).contiguous()

        # NOTE(review): the mask is computed over the flattened H*W axis as a
        # single sequence, so for H > 1 ``valid_ratio`` selects the leading
        # flattened positions rather than the leading columns of each row --
        # confirm this is intended for multi-row feature maps.
        mask = self._get_mask(feat, img_metas)

        output = feat
        for enc_layer in self.layer_stack:
            output = enc_layer(output, mask)
        output = self.layer_norm(output)

        return output