# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import BaseModule

from mmocr.models.common.modules import (MultiHeadAttention,
                                         PositionwiseFeedForward)


class TFEncoderLayer(BaseModule):
    """Transformer Encoder Layer.

    Args:
        d_model (int): The number of expected features
            in the encoder inputs (default=512).
        d_inner (int): The dimension of the feedforward
            network model (default=256).
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_k (int): Key feature size handed to ``MultiHeadAttention``
            (default=64).
        d_v (int): Value feature size handed to ``MultiHeadAttention``
            (default=64).
        dropout (float): Dropout rate handed to both the attention module
            and the feedforward module (default=0.1).
        qkv_bias (bool): Add bias in projection layer. Default: False.
        act_cfg (dict): Activation cfg for feedforward module.
        operation_order (tuple[str] | list[str] | None): The execution
            order of operations in the layer. Either
            ('self_attn', 'norm', 'ffn', 'norm') or
            ('norm', 'self_attn', 'norm', 'ffn'). A list with the same
            items is accepted and stored as a tuple. ``None`` selects
            ('norm', 'self_attn', 'norm', 'ffn'). Default: None.

    Raises:
        AssertionError: If ``operation_order`` is not one of the two
            supported orders.
    """

    # The only two supported layouts: normalise before or after each sublayer.
    _PRE_NORM = ('norm', 'self_attn', 'norm', 'ffn')
    _POST_NORM = ('self_attn', 'norm', 'ffn', 'norm')

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 act_cfg=dict(type='mmcv.GELU'),
                 operation_order=None):
        super().__init__()
        # Registration order is kept as-is so state_dict / parameter order
        # stays compatible with existing checkpoints.
        self.attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.mlp = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, act_cfg=act_cfg)
        self.norm2 = nn.LayerNorm(d_model)

        if operation_order is None:
            operation_order = self._PRE_NORM
        elif isinstance(operation_order, list):
            # Configs that went through JSON/YAML deliver lists; a list never
            # compares equal to a tuple, so normalise before validating.
            operation_order = tuple(operation_order)
        if operation_order not in (self._PRE_NORM, self._POST_NORM):
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type is the
            # same one the previous assert produced.
            raise AssertionError(
                f'operation_order must be {self._PRE_NORM} or '
                f'{self._POST_NORM}, but got {operation_order!r}')
        self.operation_order = operation_order

    def forward(self, x, mask=None):
        """Apply self-attention and the feedforward network to ``x``.

        Args:
            x (Tensor): Input features whose last dimension is ``d_model``
                (required by the LayerNorm layers).
            mask (Tensor | None): Forwarded unchanged as the fourth argument
                of the attention module. Default: None.

        Returns:
            Tensor: The transformed features.

        Raises:
            AssertionError: If ``self.operation_order`` was changed to an
                unsupported value after construction.
        """
        if self.operation_order == self._POST_NORM:
            # Post-norm: sublayer -> residual add -> LayerNorm.
            x = self.norm1(x + self.attn(x, x, x, mask))
            x = self.norm2(x + self.mlp(x))
        elif self.operation_order == self._PRE_NORM:
            # Pre-norm: LayerNorm -> sublayer -> residual add.
            normed = self.norm1(x)
            x = x + self.attn(normed, normed, normed, mask)
            x = x + self.mlp(self.norm2(x))
        else:
            # Never silently return the untouched input.
            raise AssertionError(
                f'Unsupported operation_order: {self.operation_order!r}')
        return x


class TFDecoderLayer(nn.Module):
    """Transformer Decoder Layer.

    Args:
        d_model (int): The number of expected features
            in the decoder inputs (default=512).
        d_inner (int): The dimension of the feedforward
            network model (default=256).
        n_head (int): The number of heads in the
            multiheadattention models (default=8).
        d_k (int): Key feature size handed to ``MultiHeadAttention``
            (default=64).
        d_v (int): Value feature size handed to ``MultiHeadAttention``
            (default=64).
        dropout (float): Dropout rate handed to both attention modules
            and the feedforward module (default=0.1).
        qkv_bias (bool): Add bias in projection layer. Default: False.
        act_cfg (dict): Activation cfg for feedforward module.
        operation_order (tuple[str] | list[str] | None): The execution
            order of operations in the layer. Either
            ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm') or
            ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn').
            A list with the same items is accepted and stored as a tuple.
            ``None`` selects the second (pre-norm) order. Default: None.

    Raises:
        AssertionError: If ``operation_order`` is not one of the two
            supported orders.
    """

    # The only two supported layouts: normalise before or after each sublayer.
    _PRE_NORM = ('norm', 'self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn')
    _POST_NORM = ('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', 'norm')

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 act_cfg=dict(type='mmcv.GELU'),
                 operation_order=None):
        super().__init__()
        # Registration order is kept as-is so state_dict / parameter order
        # stays compatible with existing checkpoints.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.self_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias)
        self.mlp = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout, act_cfg=act_cfg)

        if operation_order is None:
            operation_order = self._PRE_NORM
        elif isinstance(operation_order, list):
            # Configs that went through JSON/YAML deliver lists; a list never
            # compares equal to a tuple, so normalise before validating.
            operation_order = tuple(operation_order)
        if operation_order not in (self._PRE_NORM, self._POST_NORM):
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type is the
            # same one the previous assert produced.
            raise AssertionError(
                f'operation_order must be {self._PRE_NORM} or '
                f'{self._POST_NORM}, but got {operation_order!r}')
        self.operation_order = operation_order

    def forward(self,
                dec_input,
                enc_output,
                self_attn_mask=None,
                dec_enc_attn_mask=None):
        """Run self-attention, encoder-decoder attention and the FFN.

        Args:
            dec_input (Tensor): Decoder features whose last dimension is
                ``d_model`` (required by the LayerNorm layers).
            enc_output (Tensor): Encoder features, used as key and value of
                the encoder-decoder attention.
            self_attn_mask (Tensor | None): Forwarded unchanged as the
                fourth argument of the self-attention module. Default: None.
            dec_enc_attn_mask (Tensor | None): Forwarded unchanged as the
                fourth argument of the encoder-decoder attention module.
                Default: None.

        Returns:
            Tensor: The transformed decoder features.

        Raises:
            AssertionError: If ``self.operation_order`` was changed to an
                unsupported value after construction.
        """
        # Residual sums are computed out-of-place (as in TFEncoderLayer) so
        # that tensors returned by the sub-modules are never mutated.
        if self.operation_order == self._POST_NORM:
            # Post-norm: sublayer -> residual add -> LayerNorm.
            dec_attn_out = self.self_attn(dec_input, dec_input, dec_input,
                                          self_attn_mask)
            dec_attn_out = self.norm1(dec_attn_out + dec_input)

            enc_dec_attn_out = self.enc_attn(dec_attn_out, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out = self.norm2(enc_dec_attn_out + dec_attn_out)

            mlp_out = self.mlp(enc_dec_attn_out)
            mlp_out = self.norm3(mlp_out + enc_dec_attn_out)
        elif self.operation_order == self._PRE_NORM:
            # Pre-norm: LayerNorm -> sublayer -> residual add.
            dec_input_norm = self.norm1(dec_input)
            dec_attn_out = self.self_attn(dec_input_norm, dec_input_norm,
                                          dec_input_norm, self_attn_mask)
            dec_attn_out = dec_attn_out + dec_input

            enc_dec_attn_in = self.norm2(dec_attn_out)
            enc_dec_attn_out = self.enc_attn(enc_dec_attn_in, enc_output,
                                             enc_output, dec_enc_attn_mask)
            enc_dec_attn_out = enc_dec_attn_out + dec_attn_out

            mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
            mlp_out = mlp_out + enc_dec_attn_out
        else:
            # Fail with a clear message instead of an UnboundLocalError.
            raise AssertionError(
                f'Unsupported operation_order: {self.operation_order!r}')

        return mlp_out