Martijn van Beers
commited on
Commit
•
f59e918
1
Parent(s):
64ac833
Remove files that shouldn't have been committed
Browse files- lib/BERTalt.py +0 -551
- lib/roberta2.py.rej +0 -63
lib/BERTalt.py
DELETED
@@ -1,551 +0,0 @@
|
|
1 |
-
from __future__ import absolute_import
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
import torch.nn.functional as F
|
6 |
-
import math
|
7 |
-
from BERT_explainability.modules.layers_ours import *
|
8 |
-
|
9 |
-
import transformers
|
10 |
-
|
11 |
-
from transformers import BertConfig
|
12 |
-
from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput
|
13 |
-
from transformers import (
|
14 |
-
BertPreTrainedModel,
|
15 |
-
PreTrainedModel,
|
16 |
-
)
|
17 |
-
|
18 |
-
|
19 |
-
ACT2FN = {
|
20 |
-
"relu": ReLU,
|
21 |
-
"tanh": Tanh,
|
22 |
-
"gelu": GELU,
|
23 |
-
}
|
24 |
-
|
25 |
-
|
26 |
-
def get_activation(activation_string):
|
27 |
-
if activation_string in ACT2FN:
|
28 |
-
return ACT2FN[activation_string]
|
29 |
-
else:
|
30 |
-
raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
|
31 |
-
|
32 |
-
def compute_rollout_attention(all_layer_matrices, start_layer=0):
|
33 |
-
# adding residual consideration
|
34 |
-
num_tokens = all_layer_matrices[0].shape[1]
|
35 |
-
batch_size = all_layer_matrices[0].shape[0]
|
36 |
-
eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
|
37 |
-
all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
|
38 |
-
all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
|
39 |
-
for i in range(len(all_layer_matrices))]
|
40 |
-
joint_attention = all_layer_matrices[start_layer]
|
41 |
-
for i in range(start_layer+1, len(all_layer_matrices)):
|
42 |
-
joint_attention = all_layer_matrices[i].bmm(joint_attention)
|
43 |
-
return joint_attention
|
44 |
-
|
45 |
-
class RPBertEmbeddings(BertEmbeddings):
|
46 |
-
def __init__(self, config):
|
47 |
-
super().__init__()
|
48 |
-
|
49 |
-
self.add1 = Add()
|
50 |
-
self.add2 = Add()
|
51 |
-
|
52 |
-
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
53 |
-
if input_ids is not None:
|
54 |
-
input_shape = input_ids.size()
|
55 |
-
else:
|
56 |
-
input_shape = inputs_embeds.size()[:-1]
|
57 |
-
|
58 |
-
seq_length = input_shape[1]
|
59 |
-
|
60 |
-
if position_ids is None:
|
61 |
-
position_ids = self.position_ids[:, :seq_length]
|
62 |
-
|
63 |
-
if token_type_ids is None:
|
64 |
-
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
65 |
-
|
66 |
-
if inputs_embeds is None:
|
67 |
-
inputs_embeds = self.word_embeddings(input_ids)
|
68 |
-
position_embeddings = self.position_embeddings(position_ids)
|
69 |
-
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
70 |
-
|
71 |
-
# embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
72 |
-
embeddings = self.add1([token_type_embeddings, position_embeddings])
|
73 |
-
embeddings = self.add2([embeddings, inputs_embeds])
|
74 |
-
embeddings = self.LayerNorm(embeddings)
|
75 |
-
embeddings = self.dropout(embeddings)
|
76 |
-
return embeddings
|
77 |
-
|
78 |
-
def relprop(self, cam, **kwargs):
|
79 |
-
cam = self.dropout.relprop(cam, **kwargs)
|
80 |
-
cam = self.LayerNorm.relprop(cam, **kwargs)
|
81 |
-
|
82 |
-
# [inputs_embeds, position_embeddings, token_type_embeddings]
|
83 |
-
(cam) = self.add2.relprop(cam, **kwargs)
|
84 |
-
|
85 |
-
return cam
|
86 |
-
|
87 |
-
class RPBertEncoder(transformers.modeling_bert.BertEncoder):
|
88 |
-
def __init__(self, config):
|
89 |
-
super().__init__()
|
90 |
-
self.config = config
|
91 |
-
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
92 |
-
|
93 |
-
def relprop(self, cam, **kwargs):
|
94 |
-
# assuming output_hidden_states is False
|
95 |
-
for layer_module in reversed(self.layer):
|
96 |
-
cam = layer_module.relprop(cam, **kwargs)
|
97 |
-
return cam
|
98 |
-
|
99 |
-
|
100 |
-
# not adding relprop since this is only pooling at the end of the network, does not impact tokens importance
|
101 |
-
class RPBertPooler(transformers.modeling_bert.BertPooler):
|
102 |
-
def __init__(self, config):
|
103 |
-
super().__init__()
|
104 |
-
self.pool = IndexSelect()
|
105 |
-
|
106 |
-
def forward(self, hidden_states):
|
107 |
-
# We "pool" the model by simply taking the hidden state corresponding
|
108 |
-
# to the first token.
|
109 |
-
self._seq_size = hidden_states.shape[1]
|
110 |
-
|
111 |
-
# first_token_tensor = hidden_states[:, 0]
|
112 |
-
first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device))
|
113 |
-
first_token_tensor = first_token_tensor.squeeze(1)
|
114 |
-
pooled_output = self.dense(first_token_tensor)
|
115 |
-
pooled_output = self.activation(pooled_output)
|
116 |
-
return pooled_output
|
117 |
-
|
118 |
-
def relprop(self, cam, **kwargs):
|
119 |
-
cam = self.activation.relprop(cam, **kwargs)
|
120 |
-
#print(cam.sum())
|
121 |
-
cam = self.dense.relprop(cam, **kwargs)
|
122 |
-
#print(cam.sum())
|
123 |
-
cam = cam.unsqueeze(1)
|
124 |
-
cam = self.pool.relprop(cam, **kwargs)
|
125 |
-
#print(cam.sum())
|
126 |
-
|
127 |
-
return cam
|
128 |
-
|
129 |
-
class BertAttention(transformers.modeling_bert.BertAttention):
|
130 |
-
def __init__(self, config):
|
131 |
-
super().__init__()
|
132 |
-
self.clone = Clone()
|
133 |
-
|
134 |
-
def forward(
|
135 |
-
self,
|
136 |
-
hidden_states,
|
137 |
-
attention_mask=None,
|
138 |
-
head_mask=None,
|
139 |
-
encoder_hidden_states=None,
|
140 |
-
encoder_attention_mask=None,
|
141 |
-
output_attentions=False,
|
142 |
-
):
|
143 |
-
h1, h2 = self.clone(hidden_states, 2)
|
144 |
-
self_outputs = self.self(
|
145 |
-
h1,
|
146 |
-
attention_mask,
|
147 |
-
head_mask,
|
148 |
-
encoder_hidden_states,
|
149 |
-
encoder_attention_mask,
|
150 |
-
output_attentions,
|
151 |
-
)
|
152 |
-
attention_output = self.output(self_outputs[0], h2)
|
153 |
-
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
154 |
-
return outputs
|
155 |
-
|
156 |
-
def relprop(self, cam, **kwargs):
|
157 |
-
# assuming that we don't ouput the attentions (outputs = (attention_output,)), self_outputs=(context_layer,)
|
158 |
-
(cam1, cam2) = self.output.relprop(cam, **kwargs)
|
159 |
-
#print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
|
160 |
-
cam1 = self.self.relprop(cam1, **kwargs)
|
161 |
-
#print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
|
162 |
-
|
163 |
-
return self.clone.relprop((cam1, cam2), **kwargs)
|
164 |
-
|
165 |
-
class BertSelfAttention(transformers.modeling_bert.BertSelfAttention):
|
166 |
-
def __init__(self, config):
|
167 |
-
super().__init__()
|
168 |
-
|
169 |
-
self.matmul1 = MatMul()
|
170 |
-
self.matmul2 = MatMul()
|
171 |
-
self.softmax = Softmax(dim=-1)
|
172 |
-
self.add = Add()
|
173 |
-
self.mul = Mul()
|
174 |
-
self.head_mask = None
|
175 |
-
self.attention_mask = None
|
176 |
-
self.clone = Clone()
|
177 |
-
|
178 |
-
self.attn_cam = None
|
179 |
-
self.attn = None
|
180 |
-
self.attn_gradients = None
|
181 |
-
|
182 |
-
def get_attn(self):
|
183 |
-
return self.attn
|
184 |
-
|
185 |
-
def save_attn(self, attn):
|
186 |
-
self.attn = attn
|
187 |
-
|
188 |
-
def save_attn_cam(self, cam):
|
189 |
-
self.attn_cam = cam
|
190 |
-
|
191 |
-
def get_attn_cam(self):
|
192 |
-
return self.attn_cam
|
193 |
-
|
194 |
-
def save_attn_gradients(self, attn_gradients):
|
195 |
-
self.attn_gradients = attn_gradients
|
196 |
-
|
197 |
-
def get_attn_gradients(self):
|
198 |
-
return self.attn_gradients
|
199 |
-
|
200 |
-
def transpose_for_scores_relprop(self, x):
|
201 |
-
return x.permute(0, 2, 1, 3).flatten(2)
|
202 |
-
|
203 |
-
def forward(
|
204 |
-
self,
|
205 |
-
hidden_states,
|
206 |
-
attention_mask=None,
|
207 |
-
head_mask=None,
|
208 |
-
encoder_hidden_states=None,
|
209 |
-
encoder_attention_mask=None,
|
210 |
-
output_attentions=False,
|
211 |
-
):
|
212 |
-
self.head_mask = head_mask
|
213 |
-
self.attention_mask = attention_mask
|
214 |
-
|
215 |
-
h1, h2, h3 = self.clone(hidden_states, 3)
|
216 |
-
mixed_query_layer = self.query(h1)
|
217 |
-
|
218 |
-
# If this is instantiated as a cross-attention module, the keys
|
219 |
-
# and values come from an encoder; the attention mask needs to be
|
220 |
-
# such that the encoder's padding tokens are not attended to.
|
221 |
-
if encoder_hidden_states is not None:
|
222 |
-
mixed_key_layer = self.key(encoder_hidden_states)
|
223 |
-
mixed_value_layer = self.value(encoder_hidden_states)
|
224 |
-
attention_mask = encoder_attention_mask
|
225 |
-
else:
|
226 |
-
mixed_key_layer = self.key(h2)
|
227 |
-
mixed_value_layer = self.value(h3)
|
228 |
-
|
229 |
-
query_layer = self.transpose_for_scores(mixed_query_layer)
|
230 |
-
key_layer = self.transpose_for_scores(mixed_key_layer)
|
231 |
-
value_layer = self.transpose_for_scores(mixed_value_layer)
|
232 |
-
|
233 |
-
# Take the dot product between "query" and "key" to get the raw attention scores.
|
234 |
-
attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
|
235 |
-
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
236 |
-
if attention_mask is not None:
|
237 |
-
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
238 |
-
attention_scores = self.add([attention_scores, attention_mask])
|
239 |
-
|
240 |
-
# Normalize the attention scores to probabilities.
|
241 |
-
attention_probs = self.softmax(attention_scores)
|
242 |
-
|
243 |
-
self.save_attn(attention_probs)
|
244 |
-
attention_probs.register_hook(self.save_attn_gradients)
|
245 |
-
|
246 |
-
# This is actually dropping out entire tokens to attend to, which might
|
247 |
-
# seem a bit unusual, but is taken from the original Transformer paper.
|
248 |
-
attention_probs = self.dropout(attention_probs)
|
249 |
-
|
250 |
-
# Mask heads if we want to
|
251 |
-
if head_mask is not None:
|
252 |
-
attention_probs = attention_probs * head_mask
|
253 |
-
|
254 |
-
context_layer = self.matmul2([attention_probs, value_layer])
|
255 |
-
|
256 |
-
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
257 |
-
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
258 |
-
context_layer = context_layer.view(*new_context_layer_shape)
|
259 |
-
|
260 |
-
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
261 |
-
return outputs
|
262 |
-
|
263 |
-
def relprop(self, cam, **kwargs):
|
264 |
-
# Assume output_attentions == False
|
265 |
-
cam = self.transpose_for_scores(cam)
|
266 |
-
|
267 |
-
# [attention_probs, value_layer]
|
268 |
-
(cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
|
269 |
-
cam1 /= 2
|
270 |
-
cam2 /= 2
|
271 |
-
if self.head_mask is not None:
|
272 |
-
# [attention_probs, head_mask]
|
273 |
-
(cam1, _)= self.mul.relprop(cam1, **kwargs)
|
274 |
-
|
275 |
-
|
276 |
-
self.save_attn_cam(cam1)
|
277 |
-
|
278 |
-
cam1 = self.dropout.relprop(cam1, **kwargs)
|
279 |
-
|
280 |
-
cam1 = self.softmax.relprop(cam1, **kwargs)
|
281 |
-
|
282 |
-
if self.attention_mask is not None:
|
283 |
-
# [attention_scores, attention_mask]
|
284 |
-
(cam1, _) = self.add.relprop(cam1, **kwargs)
|
285 |
-
|
286 |
-
# [query_layer, key_layer.transpose(-1, -2)]
|
287 |
-
(cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
|
288 |
-
cam1_1 /= 2
|
289 |
-
cam1_2 /= 2
|
290 |
-
|
291 |
-
# query
|
292 |
-
cam1_1 = self.transpose_for_scores_relprop(cam1_1)
|
293 |
-
cam1_1 = self.query.relprop(cam1_1, **kwargs)
|
294 |
-
|
295 |
-
# key
|
296 |
-
cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
|
297 |
-
cam1_2 = self.key.relprop(cam1_2, **kwargs)
|
298 |
-
|
299 |
-
# value
|
300 |
-
cam2 = self.transpose_for_scores_relprop(cam2)
|
301 |
-
cam2 = self.value.relprop(cam2, **kwargs)
|
302 |
-
|
303 |
-
cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)
|
304 |
-
|
305 |
-
return cam
|
306 |
-
|
307 |
-
|
308 |
-
class BertSelfOutput(transformers.modeling_bert.BertSelfOutput):
|
309 |
-
def __init__(self, config):
|
310 |
-
super().__init__()
|
311 |
-
self.add = Add()
|
312 |
-
|
313 |
-
def forward(self, hidden_states, input_tensor):
|
314 |
-
hidden_states = self.dense(hidden_states)
|
315 |
-
hidden_states = self.dropout(hidden_states)
|
316 |
-
add = self.add([hidden_states, input_tensor])
|
317 |
-
hidden_states = self.LayerNorm(add)
|
318 |
-
return hidden_states
|
319 |
-
|
320 |
-
def relprop(self, cam, **kwargs):
|
321 |
-
cam = self.LayerNorm.relprop(cam, **kwargs)
|
322 |
-
# [hidden_states, input_tensor]
|
323 |
-
(cam1, cam2) = self.add.relprop(cam, **kwargs)
|
324 |
-
cam1 = self.dropout.relprop(cam1, **kwargs)
|
325 |
-
cam1 = self.dense.relprop(cam1, **kwargs)
|
326 |
-
|
327 |
-
return (cam1, cam2)
|
328 |
-
|
329 |
-
|
330 |
-
class BertIntermediate(transformers.modeling_bert.BertIntermediate):
|
331 |
-
def relprop(self, cam, **kwargs):
|
332 |
-
cam = self.intermediate_act_fn.relprop(cam, **kwargs) # FIXME only ReLU
|
333 |
-
#print(cam.sum())
|
334 |
-
cam = self.dense.relprop(cam, **kwargs)
|
335 |
-
#print(cam.sum())
|
336 |
-
return cam
|
337 |
-
|
338 |
-
|
339 |
-
class BertOutput(transformers.modeling_bert.BertOutput):
|
340 |
-
def __init__(self, config):
|
341 |
-
super().__init__()
|
342 |
-
self.add = Add()
|
343 |
-
|
344 |
-
def forward(self, hidden_states, input_tensor):
|
345 |
-
hidden_states = self.dense(hidden_states)
|
346 |
-
hidden_states = self.dropout(hidden_states)
|
347 |
-
add = self.add([hidden_states, input_tensor])
|
348 |
-
hidden_states = self.LayerNorm(add)
|
349 |
-
return hidden_states
|
350 |
-
|
351 |
-
def relprop(self, cam, **kwargs):
|
352 |
-
# print("in", cam.sum())
|
353 |
-
cam = self.LayerNorm.relprop(cam, **kwargs)
|
354 |
-
#print(cam.sum())
|
355 |
-
# [hidden_states, input_tensor]
|
356 |
-
(cam1, cam2)= self.add.relprop(cam, **kwargs)
|
357 |
-
# print("add", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
|
358 |
-
cam1 = self.dropout.relprop(cam1, **kwargs)
|
359 |
-
#print(cam1.sum())
|
360 |
-
cam1 = self.dense.relprop(cam1, **kwargs)
|
361 |
-
# print("dense", cam1.sum())
|
362 |
-
|
363 |
-
# print("out", cam1.sum() + cam2.sum(), cam1.sum(), cam2.sum())
|
364 |
-
return (cam1, cam2)
|
365 |
-
|
366 |
-
|
367 |
-
class RPBertLayer(nn.Module):
|
368 |
-
def __init__(self, config):
|
369 |
-
super().__init__()
|
370 |
-
self.attention = BertAttention(config)
|
371 |
-
self.intermediate = BertIntermediate(config)
|
372 |
-
self.output = BertOutput(config)
|
373 |
-
self.clone = Clone()
|
374 |
-
|
375 |
-
def forward(
|
376 |
-
self,
|
377 |
-
hidden_states,
|
378 |
-
attention_mask=None,
|
379 |
-
head_mask=None,
|
380 |
-
output_attentions=False,
|
381 |
-
):
|
382 |
-
self_attention_outputs = self.attention(
|
383 |
-
hidden_states,
|
384 |
-
attention_mask,
|
385 |
-
head_mask,
|
386 |
-
output_attentions=output_attentions,
|
387 |
-
)
|
388 |
-
attention_output = self_attention_outputs[0]
|
389 |
-
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
390 |
-
|
391 |
-
ao1, ao2 = self.clone(attention_output, 2)
|
392 |
-
intermediate_output = self.intermediate(ao1)
|
393 |
-
layer_output = self.output(intermediate_output, ao2)
|
394 |
-
|
395 |
-
outputs = (layer_output,) + outputs
|
396 |
-
return outputs
|
397 |
-
|
398 |
-
def relprop(self, cam, **kwargs):
|
399 |
-
(cam1, cam2) = self.output.relprop(cam, **kwargs)
|
400 |
-
# print("output", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
|
401 |
-
cam1 = self.intermediate.relprop(cam1, **kwargs)
|
402 |
-
# print("intermediate", cam1.sum())
|
403 |
-
cam = self.clone.relprop((cam1, cam2), **kwargs)
|
404 |
-
# print("clone", cam.sum())
|
405 |
-
cam = self.attention.relprop(cam, **kwargs)
|
406 |
-
# print("attention", cam.sum())
|
407 |
-
return cam
|
408 |
-
|
409 |
-
|
410 |
-
class BertModel(BertPreTrainedModel):
|
411 |
-
def __init__(self, config):
|
412 |
-
super().__init__(config)
|
413 |
-
self.config = config
|
414 |
-
|
415 |
-
self.embeddings = BertEmbeddings(config)
|
416 |
-
self.encoder = BertEncoder(config)
|
417 |
-
self.pooler = BertPooler(config)
|
418 |
-
|
419 |
-
self.init_weights()
|
420 |
-
|
421 |
-
def get_input_embeddings(self):
|
422 |
-
return self.embeddings.word_embeddings
|
423 |
-
|
424 |
-
def set_input_embeddings(self, value):
|
425 |
-
self.embeddings.word_embeddings = value
|
426 |
-
|
427 |
-
def forward(
|
428 |
-
self,
|
429 |
-
input_ids=None,
|
430 |
-
attention_mask=None,
|
431 |
-
token_type_ids=None,
|
432 |
-
position_ids=None,
|
433 |
-
head_mask=None,
|
434 |
-
inputs_embeds=None,
|
435 |
-
encoder_hidden_states=None,
|
436 |
-
encoder_attention_mask=None,
|
437 |
-
output_attentions=None,
|
438 |
-
output_hidden_states=None,
|
439 |
-
return_dict=None,
|
440 |
-
):
|
441 |
-
r"""
|
442 |
-
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
443 |
-
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
444 |
-
if the model is configured as a decoder.
|
445 |
-
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
446 |
-
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
447 |
-
is used in the cross-attention if the model is configured as a decoder.
|
448 |
-
Mask values selected in ``[0, 1]``:
|
449 |
-
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
450 |
-
"""
|
451 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
452 |
-
output_hidden_states = (
|
453 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
454 |
-
)
|
455 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
456 |
-
|
457 |
-
if input_ids is not None and inputs_embeds is not None:
|
458 |
-
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
459 |
-
elif input_ids is not None:
|
460 |
-
input_shape = input_ids.size()
|
461 |
-
elif inputs_embeds is not None:
|
462 |
-
input_shape = inputs_embeds.size()[:-1]
|
463 |
-
else:
|
464 |
-
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
465 |
-
|
466 |
-
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
467 |
-
|
468 |
-
if attention_mask is None:
|
469 |
-
attention_mask = torch.ones(input_shape, device=device)
|
470 |
-
if token_type_ids is None:
|
471 |
-
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
472 |
-
|
473 |
-
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
474 |
-
# ourselves in which case we just need to make it broadcastable to all heads.
|
475 |
-
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
476 |
-
|
477 |
-
# If a 2D or 3D attention mask is provided for the cross-attention
|
478 |
-
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
479 |
-
if self.config.is_decoder and encoder_hidden_states is not None:
|
480 |
-
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
481 |
-
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
482 |
-
if encoder_attention_mask is None:
|
483 |
-
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
484 |
-
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
485 |
-
else:
|
486 |
-
encoder_extended_attention_mask = None
|
487 |
-
|
488 |
-
# Prepare head mask if needed
|
489 |
-
# 1.0 in head_mask indicate we keep the head
|
490 |
-
# attention_probs has shape bsz x n_heads x N x N
|
491 |
-
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
492 |
-
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
493 |
-
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
494 |
-
|
495 |
-
embedding_output = self.embeddings(
|
496 |
-
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
497 |
-
)
|
498 |
-
|
499 |
-
encoder_outputs = self.encoder(
|
500 |
-
embedding_output,
|
501 |
-
attention_mask=extended_attention_mask,
|
502 |
-
head_mask=head_mask,
|
503 |
-
encoder_hidden_states=encoder_hidden_states,
|
504 |
-
encoder_attention_mask=encoder_extended_attention_mask,
|
505 |
-
output_attentions=output_attentions,
|
506 |
-
output_hidden_states=output_hidden_states,
|
507 |
-
return_dict=return_dict,
|
508 |
-
)
|
509 |
-
sequence_output = encoder_outputs[0]
|
510 |
-
pooled_output = self.pooler(sequence_output)
|
511 |
-
|
512 |
-
if not return_dict:
|
513 |
-
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
514 |
-
|
515 |
-
return BaseModelOutputWithPooling(
|
516 |
-
last_hidden_state=sequence_output,
|
517 |
-
pooler_output=pooled_output,
|
518 |
-
hidden_states=encoder_outputs.hidden_states,
|
519 |
-
attentions=encoder_outputs.attentions,
|
520 |
-
)
|
521 |
-
|
522 |
-
def relprop(self, cam, **kwargs):
|
523 |
-
cam = self.pooler.relprop(cam, **kwargs)
|
524 |
-
# print("111111111111",cam.sum())
|
525 |
-
cam = self.encoder.relprop(cam, **kwargs)
|
526 |
-
# print("222222222222222", cam.sum())
|
527 |
-
# print("conservation: ", cam.sum())
|
528 |
-
return cam
|
529 |
-
|
530 |
-
|
531 |
-
transformers.modeling_bert.BertEmbeddings = RPBertEmbeddings
|
532 |
-
transformers.modeling_bert.BertEncoder = RPBertEncoder
|
533 |
-
|
534 |
-
if __name__ == '__main__':
|
535 |
-
class Config:
|
536 |
-
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
|
537 |
-
self.hidden_size = hidden_size
|
538 |
-
self.num_attention_heads = num_attention_heads
|
539 |
-
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
540 |
-
|
541 |
-
model = BertSelfAttention(Config(1024, 4, 0.1))
|
542 |
-
x = torch.rand(2, 20, 1024)
|
543 |
-
x.requires_grad_()
|
544 |
-
|
545 |
-
model.eval()
|
546 |
-
|
547 |
-
y = model.forward(x)
|
548 |
-
|
549 |
-
relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))
|
550 |
-
|
551 |
-
print(relprop[1][0].shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lib/roberta2.py.rej
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
--- modeling_roberta.py 2022-06-28 11:59:19.974278244 +0200
|
2 |
-
+++ roberta2.py 2022-06-28 14:13:05.765050058 +0200
|
3 |
-
@@ -23,14 +23,14 @@
|
4 |
-
from torch import nn
|
5 |
-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
6 |
-
|
7 |
-
-from ...activations import ACT2FN, gelu
|
8 |
-
-from ...file_utils import (
|
9 |
-
+from transformers.activations import ACT2FN, gelu
|
10 |
-
+from transformers.file_utils import (
|
11 |
-
add_code_sample_docstrings,
|
12 |
-
add_start_docstrings,
|
13 |
-
add_start_docstrings_to_model_forward,
|
14 |
-
replace_return_docstrings,
|
15 |
-
)
|
16 |
-
-from ...modeling_outputs import (
|
17 |
-
+from transformers.modeling_outputs import (
|
18 |
-
BaseModelOutputWithPastAndCrossAttentions,
|
19 |
-
BaseModelOutputWithPoolingAndCrossAttentions,
|
20 |
-
CausalLMOutputWithCrossAttentions,
|
21 |
-
@@ -40,14 +40,14 @@
|
22 |
-
SequenceClassifierOutput,
|
23 |
-
TokenClassifierOutput,
|
24 |
-
)
|
25 |
-
-from ...modeling_utils import (
|
26 |
-
+from transformers.modeling_utils import (
|
27 |
-
PreTrainedModel,
|
28 |
-
apply_chunking_to_forward,
|
29 |
-
find_pruneable_heads_and_indices,
|
30 |
-
prune_linear_layer,
|
31 |
-
)
|
32 |
-
-from ...utils import logging
|
33 |
-
-from .configuration_roberta import RobertaConfig
|
34 |
-
+from transformers.utils import logging
|
35 |
-
+from transformers.models.roberta.configuration_roberta import RobertaConfig
|
36 |
-
|
37 |
-
|
38 |
-
logger = logging.get_logger(__name__)
|
39 |
-
@@ -183,6 +183,24 @@
|
40 |
-
|
41 |
-
self.is_decoder = config.is_decoder
|
42 |
-
|
43 |
-
+ def get_attn(self):
|
44 |
-
+ return self.attn
|
45 |
-
+
|
46 |
-
+ def save_attn(self, attn):
|
47 |
-
+ self.attn = attn
|
48 |
-
+
|
49 |
-
+ def save_attn_cam(self, cam):
|
50 |
-
+ self.attn_cam = cam
|
51 |
-
+
|
52 |
-
+ def get_attn_cam(self):
|
53 |
-
+ return self.attn_cam
|
54 |
-
+
|
55 |
-
+ def save_attn_gradients(self, attn_gradients):
|
56 |
-
+ self.attn_gradients = attn_gradients
|
57 |
-
+
|
58 |
-
+ def get_attn_gradients(self):
|
59 |
-
+ return self.attn_gradients
|
60 |
-
+
|
61 |
-
def transpose_for_scores(self, x):
|
62 |
-
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
63 |
-
x = x.view(*new_x_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|