# Copyright (c) OpenMMLab. All rights reserved.
import copy

from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import BaseModule, ModuleList

from mmocr.models.builder import ENCODERS
from mmocr.models.common.modules import PositionalEncoding


@ENCODERS.register_module()
class TransformerEncoder(BaseModule):
    """Implement transformer encoder for text recognition, modified from
    `<https://github.com/FangShancheng/ABINet>`_.

    Args:
        n_layers (int): Number of attention layers.
        n_head (int): Number of parallel attention heads.
        d_model (int): Dimension :math:`D_m` of the input from the previous
            model.
        d_inner (int): Hidden dimension of feedforward layers.
        dropout (float): Dropout rate.
        max_len (int): Maximum output sequence length :math:`T`.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 n_layers=2,
                 n_head=8,
                 d_model=512,
                 d_inner=2048,
                 dropout=0.1,
                 max_len=8 * 32,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert d_model % n_head == 0, 'd_model must be divisible by n_head'
        # Sinusoidal table with max_len positions; a flattened input longer
        # than H * W = max_len cannot be encoded.
        self.pos_encoder = PositionalEncoding(d_model, n_position=max_len)
        # One post-norm encoder block:
        # self-attention -> LayerNorm -> FFN -> LayerNorm.
        encoder_layer = BaseTransformerLayer(
            operation_order=('self_attn', 'norm', 'ffn', 'norm'),
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=d_model,
                num_heads=n_head,
                attn_drop=dropout,
                dropout_layer=dict(type='Dropout', drop_prob=dropout),
            ),
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=d_model,
                feedforward_channels=d_inner,
                ffn_drop=dropout,
            ),
            norm_cfg=dict(type='LN'),
        )
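        # copy.deepcopy gives each of the n_layers blocks its own parameters;
        # reusing the same BaseTransformerLayer instance would weight-tie
        # every layer instead.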
        self.transformer = ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(n_layers)])

    def forward(self, feature):
        """
        Args:
            feature (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.

        Returns:
            Tensor: Features of shape :math:`(N, D_m, H, W)`.
        """
        n, c, h, w = feature.shape
        feature = feature.view(n, c, -1).transpose(1, 2)  # (n, h*w, c)
        feature = self.pos_encoder(feature)  # (n, h*w, c)
        feature = feature.transpose(0, 1)  # (h*w, n, c)
        for m in self.transformer:
            feature = m(feature)
        feature = feature.permute(1, 2, 0).view(n, c, h, w)  # (n, c, h, w)
        return feature
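

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream mmocr file): runs a dummy
# feature map through the encoder and checks that the (N, D_m, H, W) shape is
# preserved. The 8 x 32 spatial size is an assumption chosen to match the
# default max_len = 8 * 32 positional table; H * W must not exceed max_len.
if __name__ == '__main__':
    import torch

    encoder = TransformerEncoder(
        n_layers=2, n_head=8, d_model=512, d_inner=2048, dropout=0.1)
    encoder.eval()

    x = torch.randn(2, 512, 8, 32)  # (N, D_m, H, W)
    with torch.no_grad():
        out = encoder(x)
    assert out.shape == x.shape  # the encoder is shape-preserving

    # Because of @ENCODERS.register_module(), the same module can also be
    # built from a config dict, e.g.
    # ENCODERS.build(dict(type='TransformerEncoder', n_layers=3)).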