# attention-rollout/lib/integrated_gradients.py
import torch
from captum.attr import LayerIntegratedGradients, visualization
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from util import visualize_text

# Output classes of the SST-2 sentiment model, indexed by logit position.
classifications = ["NEGATIVE", "POSITIVE"]

class IntegratedGradientsExplainer:
    def __init__(self):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "textattack/roberta-base-SST-2"
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
        # The <unk> token serves as the integrated-gradients baseline ("reference") input.
        self.ref_token_id = self.tokenizer.unk_token_id

    def tokens_from_ids(self, ids):
        # Strip RoBERTa's "Ġ" word-boundary marker so tokens display as plain words.
        return [s[1:] if s.startswith("Ġ") else s for s in self.tokenizer.convert_ids_to_tokens(ids)]

    def custom_forward(self, inputs, attention_mask=None):
        # Captum calls this wrapper with the inputs plus any
        # additional_forward_args; it must return the classification logits.
        result = self.model(inputs, attention_mask=attention_mask, return_dict=True)
        return result.logits

    @staticmethod
    def summarize_attributions(attributions):
        # Collapse the embedding dimension to a single score per token, then
        # L2-normalise so the scores are comparable across inputs.
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        return attributions
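
    # Background for run_attribution_model below: integrated gradients
    # attributes prediction F along the straight path from a baseline x' to
    # the input x,
    #     IG_i(x) = (x_i - x'_i) * integral_0^1 dF(x' + a*(x - x'))/dx_i da,
    # which Captum approximates with an n_steps-point Riemann sum. Here the
    # baseline replaces every input token with <unk> (self.ref_token_id).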

    def run_attribution_model(self, input_ids, attention_mask, index=None, layer=None, steps=20):
        # If no target class is given, explain the model's own prediction.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if index is None:
            index = output.argmax(dim=-1).item()
        ablator = LayerIntegratedGradients(self.custom_forward, layer)
        attributions = ablator.attribute(
            inputs=input_ids,
            baselines=self.ref_token_id,
            additional_forward_args=(attention_mask,),
            target=index,
            n_steps=steps,
        )
        return self.summarize_attributions(attributions).unsqueeze(0), output, index

    def build_visualization(self, input_ids, attention_mask, **kwargs):
        vis_data_records = []
        attributions, output, index = self.run_attribution_model(input_ids, attention_mask, **kwargs)
        for record in range(input_ids.size(0)):
            classification = output[record].argmax(dim=-1).item()
            class_name = classifications[classification]
            # Drop the BOS token and the EOS/padding tail from both the tokens
            # and the attribution scores so the two stay aligned one-to-one.
            keep = slice(1, -((attention_mask[record] == 0).sum().item() + 1))
            attr = attributions[record][keep]
            tokens = self.tokens_from_ids(input_ids[record].flatten())[keep]
            vis_data_records.append(
                visualization.VisualizationDataRecord(
                    attr,  # word_attributions
                    torch.softmax(output[record], dim=-1)[classification],  # pred_prob
                    class_name,  # pred_class
                    class_name,  # true_class (gold label is not available here)
                    index,  # attr_class
                    1,  # attr_score (placeholder)
                    tokens,  # raw_input_ids
                    1,  # convergence_score (placeholder)
                )
            )
        return visualize_text(vis_data_records)

    def __call__(self, input_text, layer):
        text_batch = [input_text]
        encoding = self.tokenizer(text_batch, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        # layer 0 attributes against the embedding layer; layer N (1-based)
        # attributes against the N-th transformer encoder block.
        layer = int(layer)
        if layer == 0:
            layer = self.model.roberta.embeddings
        else:
            layer = self.model.roberta.encoder.layer[layer - 1]
        return self.build_visualization(input_ids, attention_mask, layer=layer)
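

# A minimal usage sketch, assuming the pretrained weights can be downloaded
# and that util.visualize_text returns a renderable (e.g. HTML) object; the
# sentence and layer choice are illustrative only.
if __name__ == "__main__":
    explainer = IntegratedGradientsExplainer()
    # layer=0 attributes against the embeddings; 1..12 pick an encoder block.
    print(explainer("A remarkably moving and well-acted film.", layer=0))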