from torch import nn

# Local feed-forward (MLP) block used inside each encoder layer.
from .mlp_loca import MLP
class TransformerEncoder(nn.Module):
    """Stack of TransformerEncoderLayer modules with an optional final LayerNorm."""

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation
            ) for _ in range(num_layers)
        ])
        # Optional final normalization; nn.Identity keeps the interface uniform when disabled.
        self.norm = nn.LayerNorm(emb_dim, eps=layer_norm_eps) if norm else nn.Identity()

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        output = src
        for layer in self.layers:
            output = layer(output, pos_emb, src_mask, src_key_padding_mask)
        return self.norm(output)
class TransformerEncoderLayer(nn.Module):
    """Single encoder layer: multi-head self-attention followed by an MLP block.

    Supports both pre-norm (norm_first=True) and post-norm residual arrangements.
    The positional embedding is added to the queries and keys only, not to the values.
    """

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
    ):
        super().__init__()

        self.norm_first = norm_first

        self.norm1 = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.self_attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout
        )
        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        # Add an optional embedding to x; passes x through unchanged when emb is None.
        return x if emb is None else x + emb

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        if self.norm_first:
            # Pre-norm: normalize, attend, then add the residual.
            src_norm = self.norm1(src)
            q = k = src_norm + pos_emb
            src = src + self.dropout1(self.self_attn(
                query=q,
                key=k,
                value=src_norm,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask
            )[0])
            src_norm = self.norm2(src)
            src = src + self.dropout2(self.mlp(src_norm))
        else:
            # Post-norm: attend, add the residual, then normalize.
            q = k = src + pos_emb
            src = self.norm1(src + self.dropout1(self.self_attn(
                query=q,
                key=k,
                value=src,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask
            )[0]))
            src = self.norm2(src + self.dropout2(self.mlp(src)))
        return src
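

# Minimal usage sketch (not part of the original module): builds a small pre-norm encoder
# and runs a forward pass on dummy tensors. Shapes follow nn.MultiheadAttention's default
# sequence-first layout (seq_len, batch, emb_dim). The hyperparameter values and the
# assumption that `activation` is passed as a class (e.g. nn.GELU) that MLP instantiates
# are illustrative guesses, not values taken from this repository. Because of the relative
# import above, run this as a module, e.g. `python -m <package>.transformer`.
if __name__ == "__main__":
    import torch

    encoder = TransformerEncoder(
        num_layers=3,
        emb_dim=256,
        num_heads=8,
        dropout=0.1,
        layer_norm_eps=1e-5,
        mlp_factor=8,
        norm_first=True,
        activation=nn.GELU,  # assumed convention: MLP receives the activation class
        norm=True,
    )
    src = torch.randn(64, 2, 256)      # (sequence length, batch size, embedding dim)
    pos_emb = torch.randn(64, 2, 256)  # positional embedding, same shape as src
    out = encoder(src, pos_emb, None, None)  # no attention mask, no key padding mask
    print(out.shape)  # expected: torch.Size([64, 2, 256])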