import math

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from fairseq.modules.multihead_attention import MultiheadAttention
from transformers import PreTrainedModel

from .extra_fns import ACT2FN

@dataclass
class AbRepOutput:
    """
    Dataclass used to store AbRep output.

    last_hidden_state: hidden states from the final EncoderBlock.
    all_hidden_states: hidden states from every EncoderBlock, only populated when output_hidden_states=True.
    attentions: attention weights from every EncoderBlock, only populated when output_attentions=True.
    """

    last_hidden_state: torch.FloatTensor
    all_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class EncoderBlocks(PreTrainedModel):
    """
    Wrapper around a stack of EncoderBlocks (or a single one).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.Layers = nn.ModuleList([EncoderBlock(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for a_EncoderBlock in self.Layers:
            # Each block returns its updated hidden states and, optionally, its attention weights.
            hidden_states, attentions = a_EncoderBlock(hidden_states, attention_mask, output_attentions)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            if output_attentions:
                all_self_attentions = all_self_attentions + (attentions,)
        return AbRepOutput(
            last_hidden_state=hidden_states,
            all_hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

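
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model). This is a hypothetical
# example: it assumes a plain transformers.PretrainedConfig carrying the
# ad-hoc attributes read by the classes in this module, and that ACT2FN
# contains "gelu".
# ---------------------------------------------------------------------------
def _demo_encoder_blocks():
    from transformers import PretrainedConfig

    config = PretrainedConfig(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        hidden_act="gelu",                 # assumed to be a key in ACT2FN
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        layer_norm_eps=1e-12,
    )
    model = EncoderBlocks(config).eval()

    batch_size, seq_len = 2, 10
    hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)
    output = model(hidden_states, output_attentions=True, output_hidden_states=True)

    print(output.last_hidden_state.shape)   # (batch_size, seq_len, hidden_size)
    print(len(output.all_hidden_states))    # one entry per EncoderBlock (the input itself is not included)
    print(len(output.attentions))           # one entry per EncoderBlock
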
class EncoderBlock(PreTrainedModel):
    """
    Single EncoderBlock.

    An EncoderBlock consists of a MultiHeadAttention and an IntermediateLayer.
    """

    def __init__(self, config):
        super().__init__(config)
        self.MultiHeadAttention = ThirdMultiHeadAttention(config)
        self.MHADropout = nn.Dropout(config.hidden_dropout_prob)
        self.MHALayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.IntermediateLayer = IntermediateLayer(config)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        MHAoutput, attentions = self.MultiHeadAttention(hidden_states, attention_mask, output_attentions=output_attentions)
        output = self.MHADropout(MHAoutput)
        # Residual connection around the attention sub-layer, followed by layer normalization.
        output = self.MHALayerNorm(output + hidden_states)
        output = self.IntermediateLayer(output)
        return output, attentions

class ThirdMultiHeadAttention(PreTrainedModel):
    """
    New MultiHeadAttention (fairseq-based) which can return the attention weights of the individual heads.
    """

    def __init__(self, config):
        super().__init__(config)
        self.Attention = MultiheadAttention(config.hidden_size, config.num_attention_heads,
                                            dropout=config.attention_probs_dropout_prob, self_attention=True)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # fairseq expects (seq_len, batch, hidden_size), so transpose from batch-first.
        hidden_states = torch.transpose(hidden_states, 0, 1)

        # key_padding_mask has shape (batch, seq_len) with padded positions marked as True/1.
        # need_head_weights=True makes fairseq return the weights of each head separately.
        attn_output, attn_weights = self.Attention(hidden_states, hidden_states, hidden_states,
                                                   key_padding_mask=attention_mask, static_kv=True,
                                                   need_weights=output_attentions, need_head_weights=output_attentions)
        # Transpose back to (batch, seq_len, hidden_size).
        return torch.transpose(attn_output, 0, 1), attn_weights

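
# ---------------------------------------------------------------------------
# Hypothetical shape sketch (not part of the model): compares the attention
# weights returned by ThirdMultiHeadAttention (per head, via fairseq's
# need_head_weights) with those of OldMultiHeadAttention below (averaged over
# heads by torch.nn.MultiheadAttention). The config attributes are the same
# ad-hoc assumptions as in _demo_encoder_blocks above.
# ---------------------------------------------------------------------------
def _demo_attention_weights():
    from transformers import PretrainedConfig

    config = PretrainedConfig(hidden_size=64, num_attention_heads=4, attention_probs_dropout_prob=0.0)
    hidden_states = torch.randn(2, 10, config.hidden_size)  # (batch, seq_len, hidden_size)

    _, per_head = ThirdMultiHeadAttention(config).eval()(hidden_states, output_attentions=True)
    _, averaged = OldMultiHeadAttention(config).eval()(hidden_states, output_attentions=True)

    print(per_head.shape)   # expected: (num_heads, batch, seq_len, seq_len)
    print(averaged.shape)   # expected: (batch, seq_len, seq_len)
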
class OldMultiHeadAttention(PreTrainedModel):
    """
    MultiHeadAttention containing a scaled dot-product attention and a linear output layer
    (torch.nn.MultiheadAttention); returns attention weights averaged over heads.
    """

    def __init__(self, config):
        super().__init__(config)
        self.Attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, config.attention_probs_dropout_prob)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # torch.nn.MultiheadAttention expects (seq_len, batch, hidden_size), so transpose from batch-first.
        hidden_states = torch.transpose(hidden_states, 0, 1)
        output, attentions = self.Attention(hidden_states, hidden_states, hidden_states,
                                            key_padding_mask=attention_mask, need_weights=output_attentions)
        # Transpose back to (batch, seq_len, hidden_size).
        attention_output = torch.transpose(output, 0, 1)
        return attention_output, attentions

class IntermediateLayer(PreTrainedModel):
    """
    Position-wise feed-forward layer which expands the hidden size to the intermediate size and
    projects it back, functioning as a residual block that ends with dropout and layer normalization.
    """

    def __init__(self, config):
        super().__init__(config)
        self.expand_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act]

        self.dense_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        output = self.expand_dense(hidden_states)
        output = self.intermediate_act_fn(output)
        output = self.dense_dense(output)
        output = self.dropout(output)
        # Residual connection around the feed-forward sub-layer, followed by layer normalization.
        output = self.LayerNorm(output + hidden_states)
        return output