|
from __future__ import absolute_import
|
|
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
from BERT_explainability.modules.layers_ours import *
|
|
|
|
import transformers
|
|
|
|
from transformers import BertConfig
|
|
from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput
|
|
from transformers import (
|
|
BertPreTrainedModel,
|
|
PreTrainedModel,
|
|
)
|
|
|
|
|
|
# Lookup table from an activation name to the object that implements it.
# ReLU, Tanh and GELU are not defined in this file; presumably they are
# provided by the star import of BERT_explainability.modules.layers_ours
# -- TODO confirm.
ACT2FN = dict(relu=ReLU, tanh=Tanh, gelu=GELU)
|
|
|
|
|
|
def get_activation(activation_string):
    """Return the ACT2FN entry registered under ``activation_string``.

    The mapped object is returned as stored; it is not called here.

    Raises:
        KeyError: if ``activation_string`` is not a key of ACT2FN. The message
            lists the keys that are available.
    """
    if activation_string not in ACT2FN:
        known_names = list(ACT2FN)
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, known_names))
    return ACT2FN[activation_string]
|
|
|
|
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Fold a list of per-layer matrices into one joint ("rollout") matrix.

    Every matrix first gets the identity added and is then divided by its sum
    over the last dimension. Starting from the matrix at ``start_layer``, each
    later matrix is batch-multiplied onto the running product from the left.

    Args:
        all_layer_matrices: non-empty sequence of 3-D tensors (``bmm`` needs
            3-D operands). ``shape[0]`` of the first tensor is used as the
            batch size and ``shape[1]`` as the number of tokens.
        start_layer: index of the first matrix that takes part in the product.

    Returns:
        The accumulated product as a tensor.

    Raises:
        IndexError: if ``all_layer_matrices`` is empty or ``start_layer`` is
            out of range.
    """
    first = all_layer_matrices[0]
    batch_size, num_tokens = first.shape[0], first.shape[1]
    # Build the identity directly on the target device rather than on the CPU
    # followed by a copy.
    eye = torch.eye(num_tokens, device=first.device).expand(batch_size, num_tokens, num_tokens)

    normalized = []
    for matrix in all_layer_matrices:
        with_identity = matrix + eye
        normalized.append(with_identity / with_identity.sum(dim=-1, keepdim=True))

    joint_attention = normalized[start_layer]
    # A range is kept here on purpose: the loop starts at an offset.
    for layer_idx in range(start_layer + 1, len(normalized)):
        joint_attention = normalized[layer_idx].bmm(joint_attention)
    return joint_attention
|
|
|
|
class RPBertEmbeddings(transformers.modeling_bert.BertEmbeddings):
    """BERT embedding layer with an added ``relprop`` backward pass.

    The embedding tables, ``LayerNorm``, ``dropout`` and ``position_ids`` are
    whatever the parent constructor creates; this subclass only adds two
    ``Add`` modules and routes the sums in ``forward`` through them.
    """

    def __init__(self, config):
        # The parent constructor needs the model config; calling it with no
        # argument raises TypeError.
        super().__init__(config)

        # Explicit Add modules replace the ``+`` operator so that each sum is
        # a module relprop can be called on.
        self.add1 = Add()
        self.add2 = Add()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Embed the input and return the normalised, dropped-out sum.

        Either ``input_ids`` or ``inputs_embeds`` must be given; when
        ``input_ids`` is present it determines the input shape.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            # Drop the trailing embedding dimension.
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # (token_type + position) + word embeddings
        embeddings = self.add1([token_type_embeddings, position_embeddings])
        embeddings = self.add2([embeddings, inputs_embeds])
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through dropout, LayerNorm and ``add2``.

        Whatever ``add2.relprop`` returns is handed back unchanged; ``add1``
        is not visited.
        """
        # NOTE(review): assumes self.dropout and self.LayerNorm expose
        # ``relprop``; they are created by the parent constructor, which is
        # not visible in this file -- confirm.
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.LayerNorm.relprop(cam, **kwargs)

        (cam) = self.add2.relprop(cam, **kwargs)

        return cam
|
|
|
|
class RPBertEncoder(transformers.modeling_bert.BertEncoder):
    """BERT encoder whose layer stack can propagate relevance backwards.

    ``forward`` is inherited from the parent class.
    """

    def __init__(self, config):
        # The parent constructor needs the model config; calling it with no
        # argument raises TypeError.
        super().__init__(config)
        self.config = config
        # Replace the parent's layer stack with the relprop-capable layer
        # defined in this module (``relprop`` below calls it on every layer).
        # NOTE(review): the inherited forward is not visible here -- confirm
        # it calls each layer with arguments RPBertLayer.forward accepts.
        self.layer = nn.ModuleList([RPBertLayer(config) for _ in range(config.num_hidden_layers)])

    def relprop(self, cam, **kwargs):
        """Run ``relprop`` through the layers from last to first."""
        for layer_module in reversed(self.layer):
            cam = layer_module.relprop(cam, **kwargs)
        return cam
|
|
|
|
|
|
|
|
class RPBertPooler(transformers.modeling_bert.BertPooler):
    """BERT pooler that selects the first token through an ``IndexSelect``
    module so the selection can be reversed in ``relprop``."""

    def __init__(self, config):
        # The parent constructor needs the model config; calling it with no
        # argument raises TypeError.
        super().__init__(config)
        self.pool = IndexSelect()

    def forward(self, hidden_states):
        """Pool ``hidden_states`` by passing index 0 along dim 1 through
        ``dense`` and ``activation``."""
        # Remember the size of dim 1 of the input (not read elsewhere in this
        # class).
        self._seq_size = hidden_states.shape[1]

        # Select index 0 along dim 1, then drop that dimension.
        first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device))
        first_token_tensor = first_token_tensor.squeeze(1)
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

    def relprop(self, cam, **kwargs):
        """Reverse ``forward``: activation, dense, unsqueeze, index select."""
        # NOTE(review): assumes self.activation and self.dense expose
        # ``relprop``; they are created by the parent constructor, which is
        # not visible in this file -- confirm.
        cam = self.activation.relprop(cam, **kwargs)

        cam = self.dense.relprop(cam, **kwargs)

        # Restore the dimension removed by ``squeeze(1)`` in forward.
        cam = cam.unsqueeze(1)
        cam = self.pool.relprop(cam, **kwargs)

        return cam
|
|
|
|
class BertAttention(transformers.modeling_bert.BertAttention):
    """Attention block (self-attention plus output projection) with relprop.

    The input is duplicated through a ``Clone`` module: one copy feeds the
    self-attention, the other is the second input of the output sub-module.
    """

    def __init__(self, config):
        # The parent constructor needs the model config; calling it with no
        # argument raises TypeError.
        super().__init__(config)
        self.clone = Clone()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Return ``(attention_output,)`` followed by any extra outputs of the
        self-attention sub-module."""
        h1, h2 = self.clone(hidden_states, 2)
        self_outputs = self.self(
            h1,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], h2)
        outputs = (attention_output,) + self_outputs[1:]
        return outputs

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through output, self-attention and clone."""
        # NOTE(review): assumes self.output.relprop returns a pair and that
        # self.self exposes ``relprop``; both sub-modules are created by the
        # parent constructor, which is not visible in this file -- confirm
        # they are the relprop-capable classes.
        (cam1, cam2) = self.output.relprop(cam, **kwargs)

        cam1 = self.self.relprop(cam1, **kwargs)

        return self.clone.relprop((cam1, cam2), **kwargs)
|
|
|
|
class BertSelfAttention(transformers.modeling_bert.BertSelfAttention):
    """Multi-head self-attention with relprop support.

    Matrix products, softmax and the mask addition go through dedicated
    modules (``MatMul``, ``Softmax``, ``Add``) so they can be reversed in
    ``relprop``. The attention probabilities, their gradients and their
    relevance map are stored on the instance for later inspection.
    """

    def __init__(self, config):
        # The parent constructor needs the model config; calling it with no
        # argument raises TypeError.
        super().__init__(config)

        self.matmul1 = MatMul()
        self.matmul2 = MatMul()
        self.softmax = Softmax(dim=-1)
        self.add = Add()
        self.mul = Mul()
        # Masks seen in the most recent forward call; relprop branches on them.
        self.head_mask = None
        self.attention_mask = None
        self.clone = Clone()

        # Filled in by forward (attn), by the backward hook (attn_gradients)
        # and by relprop (attn_cam).
        self.attn_cam = None
        self.attn = None
        self.attn_gradients = None

    def get_attn(self):
        """Return the attention probabilities saved by the last forward."""
        return self.attn

    def save_attn(self, attn):
        """Store the attention probabilities."""
        self.attn = attn

    def save_attn_cam(self, cam):
        """Store the relevance map of the attention probabilities."""
        self.attn_cam = cam

    def get_attn_cam(self):
        """Return the relevance map saved by the last relprop."""
        return self.attn_cam

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention probabilities (used as a
        tensor hook in forward)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradients saved by the backward hook."""
        return self.attn_gradients

    def transpose_for_scores_relprop(self, x):
        """Swap dims 1 and 2 of a 4-D tensor and flatten from dim 2 onward."""
        return x.permute(0, 2, 1, 3).flatten(2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Compute attention and return ``(context_layer,)`` or
        ``(context_layer, attention_probs)`` when ``output_attentions``."""
        self.head_mask = head_mask
        self.attention_mask = attention_mask

        # One copy of the input each for query, key and value.
        h1, h2, h3 = self.clone(hidden_states, 3)
        mixed_query_layer = self.query(h1)

        # With encoder states given, keys and values are computed from them
        # and the encoder mask replaces the attention mask.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(h2)
            mixed_value_layer = self.value(h3)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Raw scores: query times transposed key, scaled by sqrt(head size).
        attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = self.add([attention_scores, attention_mask])

        attention_probs = self.softmax(attention_scores)

        self.save_attn(attention_probs)
        # Capture the gradient of the probabilities during backward.
        # register_hook only works on tensors that require grad.
        attention_probs.register_hook(self.save_attn_gradients)

        attention_probs = self.dropout(attention_probs)

        # NOTE(review): the head mask is applied with ``*`` here, while
        # relprop calls ``self.mul.relprop`` when a head mask was given;
        # ``self.mul`` is never run in forward -- confirm which is intended.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = self.matmul2([attention_probs, value_layer])

        # Swap dims 1 and 2, then collapse the last two dims into
        # ``all_head_size``.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back to the input of the attention block.

        NOTE(review): assumes self.dropout, self.query, self.key and
        self.value expose ``relprop``; they are created by the parent
        constructor, which is not visible in this file -- confirm.
        """
        # Split the relevance into heads, mirroring the end of forward.
        cam = self.transpose_for_scores(cam)

        # Back through probs @ value; each operand's relevance is halved.
        (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
        cam1 /= 2
        cam2 /= 2
        if self.head_mask is not None:
            (cam1, _)= self.mul.relprop(cam1, **kwargs)

        self.save_attn_cam(cam1)

        cam1 = self.dropout.relprop(cam1, **kwargs)

        cam1 = self.softmax.relprop(cam1, **kwargs)

        if self.attention_mask is not None:
            (cam1, _) = self.add.relprop(cam1, **kwargs)

        # Back through query @ key^T; each operand's relevance is halved.
        (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
        cam1_1 /= 2
        cam1_2 /= 2

        # Query branch.
        cam1_1 = self.transpose_for_scores_relprop(cam1_1)
        cam1_1 = self.query.relprop(cam1_1, **kwargs)

        # Key branch: undo the transpose applied to the key in forward first.
        cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
        cam1_2 = self.key.relprop(cam1_2, **kwargs)

        # Value branch.
        cam2 = self.transpose_for_scores_relprop(cam2)
        cam2 = self.value.relprop(cam2, **kwargs)

        # Merge the three branches at the clone made in forward.
        cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)

        return cam
|
|
|
|
|
|
class BertSelfOutput(transformers.modeling_bert.BertSelfOutput):
    """Projection, dropout, residual add and LayerNorm after self-attention,
    with the residual sum routed through an ``Add`` module for relprop."""

    def __init__(self, config):
        # The parent constructor needs the model config; calling it with no
        # argument raises TypeError.
        super().__init__(config)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        add = self.add([hidden_states, input_tensor])
        hidden_states = self.LayerNorm(add)
        return hidden_states

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` backwards and return a pair: relevance for
        ``hidden_states`` and relevance for ``input_tensor``."""
        # NOTE(review): assumes self.LayerNorm, self.dropout and self.dense
        # expose ``relprop``; they are created by the parent constructor,
        # which is not visible in this file -- confirm.
        cam = self.LayerNorm.relprop(cam, **kwargs)

        (cam1, cam2) = self.add.relprop(cam, **kwargs)
        # Only the first operand of the sum went through dropout and dense.
        cam1 = self.dropout.relprop(cam1, **kwargs)
        cam1 = self.dense.relprop(cam1, **kwargs)

        return (cam1, cam2)
|
|
|
|
|
|
class BertIntermediate(transformers.modeling_bert.BertIntermediate):
    """Intermediate (feed-forward) sub-layer with a ``relprop`` pass added.

    The constructor and ``forward`` are inherited unchanged.
    """

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through the activation, then the dense layer."""
        # NOTE(review): assumes self.intermediate_act_fn and self.dense expose
        # ``relprop``; both come from the parent constructor, which is not
        # visible in this file -- confirm.
        after_activation = self.intermediate_act_fn.relprop(cam, **kwargs)
        return self.dense.relprop(after_activation, **kwargs)
|
|
|
|
|
|
class BertOutput(transformers.modeling_bert.BertOutput):
    """Projection, dropout, residual add and LayerNorm after the intermediate
    layer, with the residual sum routed through an ``Add`` module for relprop."""

    def __init__(self, config):
        # The parent constructor needs the model config; calling it with no
        # argument raises TypeError.
        super().__init__(config)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        add = self.add([hidden_states, input_tensor])
        hidden_states = self.LayerNorm(add)
        return hidden_states

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` backwards and return a pair: relevance for
        ``hidden_states`` and relevance for ``input_tensor``."""
        # NOTE(review): assumes self.LayerNorm, self.dropout and self.dense
        # expose ``relprop``; they are created by the parent constructor,
        # which is not visible in this file -- confirm.
        cam = self.LayerNorm.relprop(cam, **kwargs)

        (cam1, cam2)= self.add.relprop(cam, **kwargs)

        # Only the first operand of the sum went through dropout and dense.
        cam1 = self.dropout.relprop(cam1, **kwargs)

        cam1 = self.dense.relprop(cam1, **kwargs)

        return (cam1, cam2)
|
|
|
|
|
|
class RPBertLayer(nn.Module):
    """One transformer layer (attention, intermediate, output) with relprop.

    The attention output is duplicated through a ``Clone`` module: one copy
    goes through the intermediate layer, the other is the second input of
    the output sub-module.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.clone = Clone()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Return ``(layer_output,)`` followed by any extra attention outputs."""
        attention_result = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )

        # Two copies of the attention output: feed-forward path and residual.
        ffn_input, residual_input = self.clone(attention_result[0], 2)
        layer_output = self.output(self.intermediate(ffn_input), residual_input)

        # Everything after the first attention result is passed through.
        return (layer_output,) + attention_result[1:]

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through output, intermediate, clone and
        attention, in that order."""
        cam_ffn, cam_residual = self.output.relprop(cam, **kwargs)
        cam_ffn = self.intermediate.relprop(cam_ffn, **kwargs)
        merged = self.clone.relprop((cam_ffn, cam_residual), **kwargs)
        return self.attention.relprop(merged, **kwargs)
|
|
|
|
|
|
class BertModel(BertPreTrainedModel):
    """BERT model assembled from the relprop-capable embeddings, encoder and
    pooler defined in this module."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Use the classes defined in this module; the unprefixed names
        # BertEmbeddings / BertEncoder / BertPooler are not defined here.
        self.embeddings = RPBertEmbeddings(config)
        self.encoder = RPBertEncoder(config)
        self.pooler = RPBertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module of the embedding layer."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module of the embedding layer."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run embeddings, encoder and pooler.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
        ``output_attentions``, ``output_hidden_states`` and ``return_dict``
        fall back to the values in ``self.config`` when left as ``None``.

        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

        Returns a tuple ``(sequence_output, pooled_output, ...)`` when
        ``return_dict`` is false, otherwise a ``BaseModelOutputWithPooling``.

        Raises ``ValueError`` when both or neither of ``input_ids`` and
        ``inputs_embeds`` are supplied.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            # Drop the trailing embedding dimension.
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Defaults: attend to every position, all tokens in segment 0.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # An encoder mask is only prepared for a decoder that received
        # encoder states.
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through the pooler and then the encoder.

        The embedding layer's ``relprop`` is not called here.
        """
        cam = self.pooler.relprop(cam, **kwargs)

        cam = self.encoder.relprop(cam, **kwargs)

        return cam
|
|
|
|
|
|
# Import-time side effect: rebind two names in the transformers.modeling_bert
# module so that lookups of BertEmbeddings / BertEncoder in that module
# resolve to the relprop-capable subclasses defined above.
# NOTE(review): only these two names are rebound; the other classes defined
# in this file are not patched in -- confirm that is intended.
transformers.modeling_bert.BertEmbeddings = RPBertEmbeddings
transformers.modeling_bert.BertEncoder = RPBertEncoder
|
|
|
|
if __name__ == '__main__':
    # Ad-hoc smoke test: build one BertSelfAttention, run a forward pass on
    # random input, then a relprop pass.

    # Minimal stand-in config carrying only the three fields set below.
    class Config:
        def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob

    model = BertSelfAttention(Config(1024, 4, 0.1))
    x = torch.rand(2, 20, 1024)
    # forward registers a gradient hook on the attention probabilities, so
    # the input has to take part in autograd.
    x.requires_grad_()

    model.eval()

    y = model.forward(x)

    # NOTE(review): BertSelfAttention.relprop is declared as
    # relprop(self, cam, **kwargs); the second positional argument passed
    # here does not fit that signature -- confirm the intended call.
    relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))

    print(relprop[1][0].shape)
|
|
|