# Source: countgd/models/GroundingDINO/transformer_loca.py
# Uploaded by nikigoli via huggingface_hub (commit a277bb8, verified; 2.69 kB).
from .mlp_loca import MLP
from torch import nn
class TransformerEncoder(nn.Module):
    """Stack of identical ``TransformerEncoderLayer`` modules.

    The input is passed through each layer in order. When ``norm`` is True the
    result then goes through a final ``nn.LayerNorm``; otherwise it goes
    through ``nn.Identity``.

    Args:
        num_layers: Number of encoder layers to stack.
        emb_dim: Embedding (model) dimension.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability used inside each layer.
        layer_norm_eps: ``eps`` for every LayerNorm.
        mlp_factor: Multiplier on ``emb_dim`` forwarded to each layer's MLP.
        norm_first: Pre-norm (True) or post-norm (False) residual layout.
        activation: Activation forwarded to each layer's MLP.
        norm: Whether to apply a final LayerNorm after the last layer.
    """

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
    ):
        super().__init__()
        # Every layer is built from the same positional argument tuple.
        per_layer_args = (
            emb_dim, num_heads, dropout, layer_norm_eps,
            mlp_factor, norm_first, activation,
        )
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(*per_layer_args) for _ in range(num_layers)]
        )
        if norm:
            self.norm = nn.LayerNorm(emb_dim, layer_norm_eps)
        else:
            self.norm = nn.Identity()

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        """Run ``src`` through every layer, then through the final norm.

        The same ``pos_emb``, ``src_mask`` and ``src_key_padding_mask`` are
        handed unchanged to every layer.
        """
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, pos_emb, src_mask, src_key_padding_mask)
        return self.norm(hidden)
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then MLP, each with a residual + LayerNorm.

    With ``norm_first=True`` (pre-norm) the update is::

        x = x + dropout1(attn(norm1(x)))
        x = x + dropout2(mlp(norm2(x)))

    With ``norm_first=False`` (post-norm) the update is::

        x = norm1(x + dropout1(attn(x)))
        x = norm2(x + dropout2(mlp(x)))

    The positional embedding is added to the attention queries and keys
    only, never to the values.

    Args:
        emb_dim: Embedding (model) dimension.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the residual branches, the attention
            weights and the MLP.
        layer_norm_eps: ``eps`` for both LayerNorms.
        mlp_factor: ``mlp_factor * emb_dim`` is passed as the MLP's second
            argument (presumably its hidden width; see ``mlp_loca.MLP``).
        norm_first: Pre-norm (True) or post-norm (False) layout.
        activation: Forwarded to ``MLP``.
    """

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
    ):
        super(TransformerEncoderLayer, self).__init__()
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # batch_first is not passed, so PyTorch's default (False) applies:
        # inputs are expected sequence-first, (seq, batch, emb_dim).
        self.self_attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout
        )
        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        """Return ``x + emb``, or ``x`` unchanged when ``emb`` is None."""
        return x if emb is None else x + emb

    def _self_attention(self, x, pos_emb, attn_mask, key_padding_mask):
        """Self-attend over ``x``, adding ``pos_emb`` (if any) to queries/keys only."""
        q = k = self.with_emb(x, pos_emb)
        return self.self_attn(
            query=q,
            key=k,
            value=x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )[0]  # [0] drops the attention weights

    def forward(self, src, pos_emb=None, src_mask=None, src_key_padding_mask=None):
        """Apply the layer to ``src``.

        Args:
            src: Input tensor, same layout as ``nn.MultiheadAttention`` expects.
            pos_emb: Optional positional embedding added to queries and keys.
                ``None`` means no positional embedding. Previously ``None``
                raised a TypeError.
            src_mask: Optional ``attn_mask`` for ``nn.MultiheadAttention``.
            src_key_padding_mask: Optional ``key_padding_mask`` for
                ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``src``.
        """
        if self.norm_first:
            src_norm = self.norm1(src)
            src = src + self.dropout1(
                self._self_attention(src_norm, pos_emb, src_mask, src_key_padding_mask)
            )
            src = src + self.dropout2(self.mlp(self.norm2(src)))
        else:
            src = self.norm1(src + self.dropout1(
                self._self_attention(src, pos_emb, src_mask, src_key_padding_mask)
            ))
            src = self.norm2(src + self.dropout2(self.mlp(src)))
        return src