Spaces: Running on T4
from .mlp_loca import MLP | |
from torch import nn | |
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerEncoderLayer`` modules with an optional final LayerNorm.

    Args:
        num_layers: Number of encoder layers to stack.
        emb_dim: Embedding dimension shared by every layer.
        num_heads: Number of self-attention heads per layer.
        dropout: Dropout probability handed to every layer.
        layer_norm_eps: ``eps`` for every ``nn.LayerNorm`` in the stack,
            including the optional final one.
        mlp_factor: Each layer's MLP hidden width is ``mlp_factor * emb_dim``.
        norm_first: Forwarded to every layer; True selects pre-norm residual
            blocks, False selects post-norm.
        activation: Activation handed to every layer's MLP. The same object is
            passed to all layers.
        norm: If True, apply a final ``nn.LayerNorm`` to the stack output;
            otherwise the last layer's output is returned unchanged.
    """

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    emb_dim,
                    num_heads,
                    dropout,
                    layer_norm_eps,
                    mlp_factor,
                    norm_first,
                    activation,
                )
                for _ in range(num_layers)
            ]
        )
        # nn.Identity keeps forward() branch-free when no final norm is wanted.
        self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(self, src, pos_emb, src_mask=None, src_key_padding_mask=None):
        """Run ``src`` through every layer in order, then the final norm.

        Args:
            src: Input sequence. The layers build ``nn.MultiheadAttention``
                with its default ``batch_first=False``, so this is presumably
                (seq_len, batch, emb_dim) — confirm against callers.
            pos_emb: Positional embedding forwarded unchanged to every layer.
            src_mask: Optional attention mask forwarded to every layer
                (``None`` means no mask).
            src_key_padding_mask: Optional key-padding mask forwarded to
                every layer (``None`` means no padding mask).

        Returns:
            The output of the last layer, passed through ``self.norm``.
        """
        output = src
        for layer in self.layers:
            output = layer(output, pos_emb, src_mask, src_key_padding_mask)
        return self.norm(output)
class TransformerEncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention then an MLP.

    Each of the two sub-blocks is wrapped in a residual connection with a
    LayerNorm. The positional embedding is added to the attention queries
    and keys only; the attention values are the un-embedded input.

    Args:
        emb_dim: Embedding dimension of the input and output.
        num_heads: Number of self-attention heads.
        dropout: Dropout probability for the attention module, the two
            residual dropouts, and (as an argument) the MLP.
        layer_norm_eps: ``eps`` for both ``nn.LayerNorm`` modules.
        mlp_factor: MLP hidden width is ``mlp_factor * emb_dim``.
        norm_first: If True (pre-norm), LayerNorm is applied to the input of
            each sub-block and the residual adds the un-normalized stream.
            If False (post-norm), LayerNorm is applied after each residual
            addition.
        activation: Activation handed to ``MLP``.
    """

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
    ):
        super().__init__()
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # batch_first is left at the PyTorch default (False).
        self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
        # NOTE(review): MLP comes from .mlp_loca; its signature is assumed to be
        # (in_dim, hidden_dim, dropout, activation) — confirm there.
        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        """Return ``x + emb``, or ``x`` unchanged when ``emb`` is None."""
        return x if emb is None else x + emb

    def _self_attention(self, x, pos_emb, src_mask, src_key_padding_mask):
        """Self-attention on ``x`` (pos-embedded queries/keys), then dropout1."""
        q = k = self.with_emb(x, pos_emb)
        attn_out, _ = self.self_attn(
            query=q,
            key=k,
            value=x,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        return self.dropout1(attn_out)

    def forward(self, src, pos_emb=None, src_mask=None, src_key_padding_mask=None):
        """Apply the attention and MLP sub-blocks to ``src``.

        Args:
            src: Input sequence. Presumably (seq_len, batch, emb_dim) because
                ``self_attn`` uses ``batch_first=False`` — confirm against callers.
            pos_emb: Positional embedding added to queries and keys. It must
                be broadcast-compatible with ``src``. ``None`` skips the
                addition.
            src_mask: Passed to ``self_attn`` as ``attn_mask``.
            src_key_padding_mask: Passed to ``self_attn`` as ``key_padding_mask``.

        Returns:
            The updated residual stream.
        """
        if self.norm_first:
            # Pre-norm: normalize the sub-block input; add to the raw stream.
            src = src + self._self_attention(
                self.norm1(src), pos_emb, src_mask, src_key_padding_mask
            )
            src = src + self.dropout2(self.mlp(self.norm2(src)))
        else:
            # Post-norm: add the sub-block output, then normalize the sum.
            src = self.norm1(
                src + self._self_attention(src, pos_emb, src_mask, src_key_padding_mask)
            )
            src = self.norm2(src + self.dropout2(self.mlp(src)))
        return src