File size: 1,871 Bytes
2e1a3f8
5a1ac3e
e8c51f1
9d1fa85
e8c51f1
64ac833
e8c51f1
2e1a3f8
 
9d1fa85
 
 
ab7830f
2e1a3f8
ab7830f
9e7d7f8
9d1fa85
2e1a3f8
 
9d1fa85
 
9e7d7f8
9d1fa85
 
 
 
cd3f110
5a1ac3e
 
9d1fa85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd3f110
a9179d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import sys
import pandas
import gradio
import pathlib

sys.path.append("lib")

import torch

from roberta2 import RobertaForSequenceClassification
from gradient_rollout import GradientRolloutExplainer
from integrated_gradients import IntegratedGradientsExplainer
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization
import util
import torch

# Both explainers are instantiated once, at module import time, and the same
# instances are reused by run() and by the slider event handlers below.
# NOTE(review): presumably the constructors in lib/ load the model/tokenizer,
# which would be why they live at module level rather than per request — confirm.
ig_explainer = IntegratedGradientsExplainer()
gr_explainer = GradientRolloutExplainer()

def run(sent, rollout, ig):
    """Produce both explanations of *sent* for the Submit button.

    Args:
        sent: The input sentence from the textbox.
        rollout: Start layer selected on the rollout slider.
        ig: Layer selected on the integrated-gradients slider.

    Returns:
        A 2-tuple ``(rollout_output, ig_output)`` with the gradient-rollout
        result first, then the integrated-gradients result. Gradio routes
        these to the two HTML panels in that order.
    """
    # Tuple items evaluate left to right, so rollout still runs before IG.
    explanations = (
        gr_explainer(sent, rollout),
        ig_explainer(sent, ig),
    )
    return explanations

# Rows of examples.csv, converted to a list of per-row value lists for the
# gradio.Examples widget below. Only the sentence textbox is bound to the
# examples, so presumably the CSV has exactly one column — confirm.
examples = pandas.read_csv("examples.csv").to_numpy().tolist()

# Build the UI. Gradio places components in the order they are created inside
# the nested Row/Column contexts, so the statement order here defines the layout.
with gradio.Blocks(title="Explanations with attention rollout") as iface:
    # Intro text; util.Markdown is a project helper that presumably renders
    # the contents of the given markdown file — see util.py.
    util.Markdown(pathlib.Path("description.md"))
    # Top row: wide sentence input next to a narrow Submit button.
    with gradio.Row(equal_height=True):
        with gradio.Column(scale=4):
            sent = gradio.Textbox(label="Input sentence")
        with gradio.Column(scale=1):
            but = gradio.Button("Submit")
    # Second row: one column per method. Each column has an integer layer
    # slider (0-12, default 8) above the HTML panel that shows its output.
    with gradio.Row(equal_height=True):
        with gradio.Column():
            rollout_layer = gradio.Slider(minimum=0, maximum=12, value=8, step=1, label="Select rollout start layer")
            rollout_result = gradio.HTML()
        with gradio.Column():
            ig_layer = gradio.Slider(minimum=0, maximum=12, value=8, step=1, label="Select IG layer")
            ig_result = gradio.HTML()
    # Clickable examples that fill the sentence textbox.
    gradio.Examples(examples, [sent])
    with gradio.Accordion("A note about explainability models"):
        util.Markdown(pathlib.Path("notice.md"))

    # Event wiring: moving a slider recomputes only that method's panel for
    # the current sentence. Submit recomputes both panels through run().
    rollout_layer.change(gr_explainer, [sent, rollout_layer], rollout_result)
    ig_layer.change(ig_explainer, [sent, ig_layer], ig_result)
    but.click(run, [sent, rollout_layer, ig_layer], [rollout_result, ig_result])


# Launch the app. This runs unconditionally at import time (no __main__ guard).
iface.launch()