"""Transformer encoder/decoder building blocks.

Layers accept either a plain tensor or an ``(x, stage_embedding)`` tuple; in the
tuple case the embedding is threaded through to the (optionally adaptive) norm
layers. Self-attention can use rotary position embeddings (``use_rope``) and
``TransformerEncoder.infer`` supports key/value caching for incremental decoding.
"""
import copy
import numbers
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from .activation import MultiheadAttention
from .scaling import ActivationBalancer, BalancedDoubleSwish
from .scaling import BasicNorm as _BasicNorm
from .rotary_embedding import RotaryEmbedding
from .conv import ConvolutionModule, MultiLayeredConv1d

_shape_t = Union[int, List[int], torch.Size]


class LayerNorm(nn.Module):
    """Layer normalization that can pass a conditioning embedding through.

    Same math as ``F.layer_norm``; additionally ``forward`` accepts an
    ``(input, embedding)`` tuple, normalizes ``input`` and returns the tuple
    ``(normalized, embedding)`` with the embedding untouched. With a plain
    tensor input, ``embedding`` must be ``None``.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the affine weight to ones and the bias to zeros."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(
        self, input: Tensor, embedding: Any = None
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        if isinstance(input, tuple):
            input, embedding = input
            return (
                F.layer_norm(
                    input,
                    self.normalized_shape,
                    self.weight,
                    self.bias,
                    self.eps,
                ),
                embedding,
            )

        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Computes ``weight * norm(input) + bias`` where ``weight`` and ``bias`` are
    the two halves (each of size ``d_model`` on the last dim) of a linear
    projection of ``embedding``.
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(
        self, input: Tensor, embedding: Optional[Tensor] = None
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Normalize ``input`` conditioned on ``embedding``.

        ``input`` may be an ``(input, embedding)`` tuple, in which case the
        result is returned as ``(output, embedding)``.
        """
        is_tuple = isinstance(input, tuple)
        if is_tuple:
            input, embedding = input

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        output = weight * self.norm(input) + bias

        if is_tuple:
            return (output, embedding)
        return output


class BasicNorm(_BasicNorm):
    """``scaling.BasicNorm`` with optional ``(input, embedding)`` pass-through.

    ``device`` and ``dtype`` are accepted for signature compatibility with the
    other norm classes but are not forwarded to the base class.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BasicNorm, self).__init__(d_model, eps=eps)

    def forward(
        self, input: Tensor, embedding: Any = None
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        if isinstance(input, tuple):
            input, embedding = input
            return (
                super(BasicNorm, self).forward(input),
                embedding,
            )

        assert embedding is None
        return super(BasicNorm, self).forward(input)


class BalancedBasicNorm(nn.Module):
    """An ``ActivationBalancer`` followed by :class:`BasicNorm`."""

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BalancedBasicNorm, self).__init__()
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(
        self, input: Tensor, embedding: Any = None
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        if isinstance(input, tuple):
            input, embedding = input
            return self.norm((self.balancer(input), embedding))

        assert embedding is None
        return self.norm(self.balancer(input))


class IdentityNorm(nn.Module):
    """No-op norm; returns its input (tensor or tuple) unchanged."""

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            return input

        assert embedding is None
        return input


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Unlike the other norms in this file, ``forward`` only accepts a plain
    tensor (no ``(input, embedding)`` tuple).
    """

    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled).
            NOTE(review): a ``p`` with ``int(d * p) == 0`` makes the RMS
            divisor ``0 ** -0.5`` and fails; use a ``p`` covering >= 1 feature.
        :param eps: epsilon value, default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by default because
            RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Assigning an nn.Parameter attribute registers it with the module,
        # so no explicit register_parameter() call is needed.
        self.scale = nn.Parameter(torch.ones(d))
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            # Partial RMSNorm: estimate the RMS from the first d*p features only.
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(
                x, [partial_size, self.d - partial_size], dim=-1
            )
            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed


class TransformerEncoderLayer(nn.Module):
    """Self-attention + feed-forward encoder layer.

    Optional extras: a convolution module between attention and feed-forward
    (``use_conv_module``), a depth-wise convolutional feed-forward
    (``use_depth_wise_conv``), adaptive layer norms conditioned on a stage
    embedding (``adaptive_layer_norm``) and rotary position embeddings
    (``use_rope`` at call time).
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
        use_conv_module: bool = False,
        use_depth_wise_conv: bool = False,
        conv_ignore_prefix_len: int = 0,
        cross_attention: bool = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Always defined so callers can test it regardless of configuration.
        self.has_cross_attention = bool(cross_attention)
        if cross_attention:
            # NOTE(review): dropout is hard-coded to 0.1 here, independent of
            # the `dropout` argument; these modules are not used in forward().
            self.cross_attn = nn.MultiheadAttention(
                d_model, nhead, 0.1, batch_first=True
            )
            self.norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        # Implementation of Feedforward model
        self.use_depth_wise_conv = use_depth_wise_conv
        self.use_conv_module = use_conv_module
        if not use_depth_wise_conv:
            self.linear1 = linear1_feedforward_cls(
                d_model, dim_feedforward, **factory_kwargs
            )
            self.dropout = nn.Dropout(dropout)
            self.linear2 = linear2_feedforward_cls(
                dim_feedforward, d_model, **factory_kwargs
            )
        else:
            self.dw_ffn = MultiLayeredConv1d(
                in_chans=d_model,
                hidden_chans=dim_feedforward,
                kernel_size=5,
                dropout_rate=dropout,
            )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)

        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            norm2 = BalancedBasicNorm(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
        else:
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

        self.rotary_emb = RotaryEmbedding(dim=d_model // nhead)

        if use_conv_module:
            self.conv_module = ConvolutionModule(
                d_model,
                kernel_size=31,
                activation=activation,
                ignore_prefix_len=conv_ignore_prefix_len,
            )
            self.norm_conv = LayerNorm(d_model)  # for the CNN module
            if adaptive_layer_norm:
                self.norm_conv = AdaptiveLayerNorm(d_model, self.norm_conv)
        else:
            self.conv_module = None

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    @staticmethod
    def _validate_key_padding_mask(key_padding_mask: Optional[Tensor]) -> None:
        """Raise if ``key_padding_mask`` is neither bool nor floating point."""
        if key_padding_mask is None:
            return
        if key_padding_mask.dtype != torch.bool and not torch.is_floating_point(
            key_padding_mask
        ):
            raise AssertionError(
                "only bool and floating types of key_padding_mask are supported"
            )

    def forward(
        self,
        src: Tensor,
        context: Optional[Tensor] = None,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        use_rope: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required), or an
                ``(x, stage_embedding)`` tuple.
            context: currently unused (cross-attention is not applied).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            use_rope: apply rotary position embeddings in self-attention.

        Returns:
            The layer output; ``(output, stage_embedding)`` if ``src`` was a tuple.

        Shape:
            see the docs in Transformer class.
        """
        is_src_tuple = isinstance(src, tuple)
        if is_src_tuple:
            x, stage_embedding = src
        else:
            x, stage_embedding = src, None

        self._validate_key_padding_mask(src_key_padding_mask)

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
                use_rope=use_rope,
            )
            if self.conv_module is not None:
                residual = x
                x = self.norm_conv(x, stage_embedding)
                x = residual + self.dropout1(self.conv_module(x))
            # NOTE: cross-attention over `context` (self.cross_attn/self.norm3)
            # is built when cross_attention=True but is not applied here.
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            x = self.norm1(
                x + self._sa_block(
                    x, src_mask, src_key_padding_mask, use_rope=use_rope
                ),
                stage_embedding,
            )
            if self.conv_module is not None:
                residual = x
                # dropout1 rather than self.dropout: the latter only exists
                # when use_depth_wise_conv is False.
                x = residual + self.dropout1(self.conv_module(x))
                x = self.norm_conv(x, stage_embedding)
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    def infer(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
        use_rope: bool = False,
    ):
        """Incremental (cached) forward pass; pre-norm layers only.

        Args:
            src: input tensor or ``(x, stage_embedding)`` tuple.
            src_mask: attention mask (optional).
            src_key_padding_mask: key padding mask (optional).
            past_kv: this layer's cache from a previous call (optional).
            use_cache: forwarded to ``self_attn.infer``.
            use_rope: apply rotary position embeddings.

        Returns:
            ``(output, kv)``, where ``output`` is a tensor, or an
            ``(x, stage_embedding)`` tuple when ``src`` was a tuple.

        Raises:
            NotImplementedError: if the layer was built with ``norm_first=False``.
        """
        is_src_tuple = isinstance(src, tuple)
        if is_src_tuple:
            x, stage_embedding = src
        else:
            x, stage_embedding = src, None

        self._validate_key_padding_mask(src_key_padding_mask)

        if not self.norm_first:
            raise NotImplementedError(
                "TransformerEncoderLayer.infer only supports norm_first=True"
            )

        x_attn_out, kv = self.self_attn.infer(
            self.norm1(x, stage_embedding),
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=False,
            past_kv=past_kv,
            use_cache=use_cache,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )
        x = x + x_attn_out
        # NOTE(review): the convolution module (if configured) is not applied
        # here, so infer() and forward() differ when use_conv_module=True.
        x = x + self._ff_block(self.norm2(x, stage_embedding))

        if is_src_tuple:
            # Keep the cache: TransformerEncoder.infer unpacks (output, kv).
            return ((x, stage_embedding), kv)
        return (x, kv)

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        if self.use_depth_wise_conv:
            x = self.dw_ffn(x)
        else:
            x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        use_rope: bool = False,
    ):
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required), or an
                ``(x, stage_embedding)`` tuple.
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: return layers' state (optional).
            use_rope: apply rotary position embeddings in self-attention.

        Returns:
            The final output, or ``(layer_states, output)`` when
            ``return_layer_states`` is True. ``layer_states`` holds each
            layer's output tensor (stage embedding stripped).

        Shape:
            see the docs in Transformer class.
        """
        layer_states = []  # layers' output
        output = src
        for mod in self.layers:
            output = mod(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                use_rope=use_rope,
            )
            if return_layer_states:
                # Layers return a tuple only for tuple inputs; indexing a plain
                # tensor with [0] would slice its first dimension instead.
                layer_states.append(
                    output[0] if isinstance(output, tuple) else output
                )

        if self.norm is not None:
            output = self.norm(output)

        if return_layer_states:
            return layer_states, output
        return output

    def infer(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
        use_rope: bool = False,
    ):
        """Incremental forward through every layer's ``infer``.

        ``past_kv`` holds one entry per layer, as returned in ``new_kv`` by a
        previous call with ``use_cache=True``. ``return_layer_states`` is
        accepted for signature parity with ``forward`` but is ignored.

        Returns:
            ``(output, new_kv)``; ``new_kv`` is None unless ``use_cache``.
        """
        if past_kv is None:
            past_kv = tuple([None] * self.num_layers)

        new_kv = () if use_cache else None
        output = src
        for mod, past_layer_kv in zip(self.layers, past_kv):
            output, kv = mod.infer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                past_kv=past_layer_kv,
                use_cache=use_cache,
                use_rope=use_rope,
            )
            if use_cache:
                new_kv = new_kv + (kv,)

        if self.norm is not None:
            output = self.norm(output)

        return output, new_kv


class TransformerDecoderLayer(nn.Module):
    """Self-attention + cross-attention (over ``memory``) + feed-forward layer."""

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            self.activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            self.activation = BalancedDoubleSwish(d_model)
        else:
            self.activation = activation

        if adaptive_layer_norm:
            norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
            self.norm3 = AdaptiveLayerNorm(d_model, norm3)
        else:
            self.norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            if layer_norm_cls == IdentityNorm:
                self.norm3 = BalancedBasicNorm(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )
            else:
                self.norm3 = layer_norm_cls(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )

        self.rotary_emb = RotaryEmbedding(dim=d_model // nhead)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        use_rope: bool = False,
    ):
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required), or an
                ``(x, stage_embedding)`` tuple.
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            use_rope: apply rotary position embeddings in attention.

        Returns:
            ``(output, attn_map)``; ``output`` is an ``(x, stage_embedding)``
            tuple when ``tgt`` was a tuple.

        Shape:
            see the docs in Transformer class.
        """
        tgt_is_tuple = isinstance(tgt, tuple)
        if tgt_is_tuple:
            x, stage_embedding = tgt
        else:
            x, stage_embedding = tgt, None

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                tgt_mask,
                tgt_key_padding_mask,
                use_rope=use_rope,
            )
            x_mha_out, attn_map = self._mha_block(
                self.norm2(x, stage_embedding),
                memory,
                memory_mask,
                memory_key_padding_mask,
                use_rope=use_rope,
            )
            x = x + x_mha_out
            x = x + self._ff_block(self.norm3(x, stage_embedding))
        else:
            x = self.norm1(
                x + self._sa_block(
                    x, tgt_mask, tgt_key_padding_mask, use_rope=use_rope
                ),
                stage_embedding,
            )
            # _mha_block returns (output, attn_map); unpack before the residual.
            x_mha_out, attn_map = self._mha_block(
                x,
                memory,
                memory_mask,
                memory_key_padding_mask,
                use_rope=use_rope,
            )
            x = self.norm2(x + x_mha_out, stage_embedding)
            x = self.norm3(x + self._ff_block(x), stage_embedding)

        if tgt_is_tuple:
            # Keep attn_map: TransformerDecoder.forward unpacks (output, attn_map).
            return ((x, stage_embedding), attn_map)
        return x, attn_map

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        use_rope: bool = False,
    ) -> Tuple[Tensor, Any]:
        """Cross-attention over ``mem``; returns ``(dropout(output), attn_map)``.

        NOTE(review): the unpacking below depends on the return structure of
        the project's MultiheadAttention (element [0] is itself a pair whose
        first item is indexed again) — confirm against activation.py.
        """
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            use_rope=use_rope,
            rope=self.rotary_emb,
        )[0]
        x, attn_map = x
        return self.dropout2(x[0]), attn_map

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class TransformerDecoder(nn.Module):
    r"""TransformerDecoder is a stack of N decoder layers.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
        >>> tgt = torch.rand(10, 32, 512)
        >>> memory = torch.rand(20, 32, 512)
        >>> out, attn_maps = transformer_decoder(tgt, memory)
    """

    __constants__ = ["norm"]

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        return_attn: bool = False,
        use_rope: bool = False,
    ):
        r"""Pass the inputs (and mask) through the decoder layers in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            return_attn: return cross attention maps of each layer (optional).
            use_rope: apply rotary position embeddings in attention.

        Returns:
            ``(output, attn_maps)``; ``attn_maps`` is empty unless ``return_attn``.

        Shape:
            see the docs in Transformer class.
        """
        attn_maps = []
        output = tgt
        for mod in self.layers:
            output, attn_map = mod(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                use_rope=use_rope,
            )
            if return_attn:
                attn_maps.append(attn_map)

        if self.norm is not None:
            output = self.norm(output)

        return output, attn_maps


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map an activation name ("relu" or "gelu") to its functional form."""
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError(f"activation should be relu/gelu, not {activation}")