import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from fairseq.modules.multihead_attention import MultiheadAttention
from transformers import PreTrainedModel

from .extra_fns import ACT2FN

@dataclass
class AbRepOutput:
    """
    Dataclass used to store AbRep output.

    last_hidden_state: output of the final EncoderBlock.
    all_hidden_states: hidden states after each EncoderBlock (only populated when output_hidden_states=True).
    attentions: attention weights of each EncoderBlock (only populated when output_attentions=True).
    """
    last_hidden_state: torch.FloatTensor
    all_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

class EncoderBlocks(PreTrainedModel):
    """
    Wrapper for a stack of one or more EncoderBlocks.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.Layers = nn.ModuleList([EncoderBlock(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for a_EncoderBlock in self.Layers:
            hidden_states, attentions = a_EncoderBlock(hidden_states, attention_mask, output_attentions)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)  # Collect the hidden states after each EncoderBlock
            if output_attentions:
                all_self_attentions = all_self_attentions + (attentions,)  # Collect the attention weights of each EncoderBlock for analysis
        return AbRepOutput(last_hidden_state=hidden_states, all_hidden_states=all_hidden_states, attentions=all_self_attentions)
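
# Shape conventions (assumptions inferred from the batch-first transposes in the
# attention wrappers below):
#   hidden_states:  (batch, seq_len, hidden_size)
#   attention_mask: (batch, seq_len) key padding mask in fairseq's convention,
#                   with padded positions marked to be ignored.
# Illustrative call:
#   out = EncoderBlocks(config)(hidden_states, output_hidden_states=True)
#   out.last_hidden_state          # (batch, seq_len, hidden_size)
#   len(out.all_hidden_states)     # config.num_hidden_layers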

class EncoderBlock(PreTrainedModel):
    """
    Single EncoderBlock.
    An EncoderBlock consists of a MultiHeadAttention and an IntermediateLayer.
    """

    def __init__(self, config):
        super().__init__(config)
        self.MultiHeadAttention = ThirdMultiHeadAttention(config)
        self.MHADropout = nn.Dropout(config.hidden_dropout_prob)
        self.MHALayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.IntermediateLayer = IntermediateLayer(config)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        MHAoutput, attentions = self.MultiHeadAttention(hidden_states, attention_mask, output_attentions=output_attentions)
        output = self.MHADropout(MHAoutput)
        output = self.MHALayerNorm(output + hidden_states)  # hidden_states are added back for the residual connection
        output = self.IntermediateLayer(output)  # IntermediateLayer applies its own residual connection internally
        return output, attentions

class ThirdMultiHeadAttention(PreTrainedModel):
    """
    MultiHeadAttention (fairseq) which can return the unaveraged weights of the individual heads.
    """

    def __init__(self, config):
        super().__init__(config)
        self.Attention = MultiheadAttention(config.hidden_size, config.num_attention_heads, dropout=config.attention_probs_dropout_prob, self_attention=True)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        hidden_states = torch.transpose(hidden_states, 0, 1)  # fairseq expects (seq_len, batch, hidden_size)
        # static_kv=True works around a fairseq bug: the head weights are only returned unaveraged when it is set
        attn_output, attn_weights = self.Attention(hidden_states, hidden_states, hidden_states, key_padding_mask=attention_mask, static_kv=True,
                                                   need_weights=output_attentions, need_head_weights=output_attentions)
        return torch.transpose(attn_output, 0, 1), attn_weights  # back to (batch, seq_len, hidden_size)
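
# Shape note (an assumption based on fairseq's MultiheadAttention API, not
# verified here): with need_head_weights=True the returned attn_weights keeps
# one entry per head, shaped (num_heads, batch, seq_len, seq_len); when
# output_attentions=False, attn_weights is None.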

class OldMultiHeadAttention(PreTrainedModel):
    """
    MultiHeadAttention (torch.nn) containing a Scaled Dot Product Attention and a Linear Layer.
    Superseded by ThirdMultiHeadAttention, which EncoderBlock uses; this version only returns head-averaged attention weights.
    """

    def __init__(self, config):
        super().__init__(config)
        self.Attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, dropout=config.attention_probs_dropout_prob)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        hidden_states = torch.transpose(hidden_states, 0, 1)  # torch.nn.MultiheadAttention expects (seq_len, batch, hidden_size) by default
        output, attentions = self.Attention(hidden_states, hidden_states, hidden_states, key_padding_mask=attention_mask, need_weights=output_attentions)
        attention_output = torch.transpose(output, 0, 1)  # back to (batch, seq_len, hidden_size)
        return attention_output, attentions

class IntermediateLayer(PreTrainedModel):
    """
    Feed-forward layer which expands to the intermediate size and projects back to the
    hidden size, functioning as a residual block that ends with a drop-norm step.
    """

    def __init__(self, config):
        super().__init__(config)
        self.expand_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act]
        self.dense_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        output = self.expand_dense(hidden_states)        # hidden_size -> intermediate_size
        output = self.intermediate_act_fn(output)
        output = self.dense_dense(output)                # intermediate_size -> hidden_size
        output = self.dropout(output)
        output = self.LayerNorm(output + hidden_states)  # residual connection followed by layer norm
        return output
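
# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the model code). It
# assumes transformers and fairseq are installed and that the module is run as
# part of its package (the relative import of ACT2FN above requires this). The
# config values are placeholder assumptions, not the original model's settings,
# and "gelu" is assumed to be a valid key of ACT2FN.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import PretrainedConfig

    config = PretrainedConfig(
        hidden_size=768,
        num_hidden_layers=2,
        num_attention_heads=12,
        attention_probs_dropout_prob=0.1,
        hidden_dropout_prob=0.1,
        layer_norm_eps=1e-12,
        intermediate_size=3072,
        hidden_act="gelu",
    )
    encoder = EncoderBlocks(config)
    hidden = torch.rand(2, 160, config.hidden_size)  # (batch, seq_len, hidden_size)
    out = encoder(hidden, output_attentions=True, output_hidden_states=True)
    print(out.last_hidden_state.shape)  # expected: torch.Size([2, 160, 768])
    print(len(out.all_hidden_states))   # expected: config.num_hidden_layers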