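"""Layer Integrated Gradients explainer for a RoBERTa sequence classifier.

Wraps Captum's LayerIntegratedGradients so that, given an input sentence, a
target layer and a baseline token, it produces a token-level attribution
visualization via ``visualize_text``.
"""
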
import torch

from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

from captum.attr import LayerIntegratedGradients
from captum.attr import visualization

from roberta2 import RobertaForSequenceClassification
from ExplanationGenerator import Generator
from util import visualize_text

# Output labels of the binary sentiment classifier, ordered by logit index.
classifications = ["NEGATIVE", "POSITIVE"]

class IntegratedGradientsExplainer:
    def __init__(self, model, tokenizer):
        self.model = model
        self.device = model.device
        self.tokenizer = tokenizer
        # Tokens that can serve as the integration baseline ("reference" input).
        self.baseline_map = {
            'Unknown': self.tokenizer.unk_token_id,
            'Padding': self.tokenizer.pad_token_id,
        }

    def tokens_from_ids(self, ids):
        # Convert ids back to tokens and strip the leading "Ġ" that marks a
        # word boundary in RoBERTa's byte-level BPE vocabulary.
        tokens = self.tokenizer.convert_ids_to_tokens(ids)
        return [t[1:] if t.startswith("Ġ") else t for t in tokens]

    def custom_forward(self, inputs, attention_mask=None):
        # Forward wrapper used by Captum: return only the classification logits.
        result = self.model(inputs, attention_mask=attention_mask, return_dict=True)
        return result.logits

    @staticmethod
    def summarize_attributions(attributions):
        # Collapse the embedding dimension and L2-normalize to get one score per token.
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        return attributions

    def run_attribution_model(self, input_ids, attention_mask, baseline=None, index=None, layer=None, steps=20):
        # Resolve the baseline ("reference") token used for the integration path.
        if baseline is None:
            baseline = self.tokenizer.unk_token_id
        else:
            baseline = self.baseline_map[baseline]

        # Forward pass for the raw logits; the predicted class is reported alongside the attributions.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if index is None:
            index = output.argmax(dim=-1).item()

        ablator = LayerIntegratedGradients(self.custom_forward, layer)
        attributions = ablator.attribute(
            inputs=input_ids,
            baselines=baseline,
            additional_forward_args=(attention_mask,),
            target=1,  # attribute with respect to the POSITIVE logit
            n_steps=steps,
        )
        return self.summarize_attributions(attributions).unsqueeze_(0), output, index

    def build_visualization(self, input_ids, attention_mask, **kwargs):
        vis_data_records = []
        attributions, output, index = self.run_attribution_model(input_ids, attention_mask, **kwargs)
        for record in range(input_ids.size(0)):
            classification = output[record].argmax(dim=-1).item()
            class_name = classifications[classification]
            attr = attributions[record]
            # Drop the <s> token at the front and </s> plus any padding at the end,
            # so the displayed tokens line up with the attribution scores.
            tokens = self.tokens_from_ids(input_ids[record].flatten())[
                1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
            ]
            vis_data_records.append(
                visualization.VisualizationDataRecord(
                    attr,
                    output[record][classification],
                    class_name,
                    class_name,
                    index,
                    1,
                    tokens,
                    1,
                )
            )
        return visualize_text(vis_data_records)

    def __call__(self, input_text, layer, baseline):
        text_batch = [input_text]
        encoding = self.tokenizer(text_batch, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        # Layer 0 attributes against the embedding layer; layer N (1-based)
        # attributes against encoder block N.
        layer = int(layer)
        if layer == 0:
            layer = self.model.roberta.embeddings
        else:
            layer = self.model.roberta.encoder.layer[layer - 1]

        return self.build_visualization(input_ids, attention_mask, layer=layer, baseline=baseline)
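
# Usage sketch (illustrative, not part of the original module): assumes a
# fine-tuned binary RoBERTa sentiment checkpoint; "textattack/roberta-base-SST-2"
# is only an example name, and the local roberta2.RobertaForSequenceClassification
# is assumed to load standard Hugging Face weights via from_pretrained.
if __name__ == "__main__":
    checkpoint = "textattack/roberta-base-SST-2"  # assumed checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = RobertaForSequenceClassification.from_pretrained(checkpoint)
    model.eval()

    explainer = IntegratedGradientsExplainer(model, tokenizer)
    # Attribute against encoder layer 1, using the unknown token as the baseline.
    html = explainer("This movie was surprisingly good.", layer=1, baseline="Unknown")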