File size: 1,871 Bytes
2e1a3f8 5a1ac3e e8c51f1 9d1fa85 e8c51f1 64ac833 e8c51f1 2e1a3f8 9d1fa85 ab7830f 2e1a3f8 ab7830f 9e7d7f8 9d1fa85 2e1a3f8 9d1fa85 9e7d7f8 9d1fa85 cd3f110 5a1ac3e 9d1fa85 cd3f110 a9179d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import sys
import pandas
import gradio
import pathlib
sys.path.append("lib")
import torch
from roberta2 import RobertaForSequenceClassification
from gradient_rollout import GradientRolloutExplainer
from integrated_gradients import IntegratedGradientsExplainer
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization
import util
import torch
# Build both explainers once, at module import, so every Gradio callback below
# reuses the same instances. The right-hand side is evaluated left to right,
# so integrated gradients is still constructed before gradient rollout.
ig_explainer, gr_explainer = IntegratedGradientsExplainer(), GradientRolloutExplainer()
def run(sent, rollout, ig):
    """Compute both explanations for a single sentence (Submit-button handler).

    Args:
        sent: the sentence from the input textbox.
        rollout: value of the "rollout start layer" slider, forwarded to the
            gradient-rollout explainer.
        ig: value of the "IG layer" slider, forwarded to the
            integrated-gradients explainer.

    Returns:
        A 2-tuple ``(rollout_output, ig_output)`` in the order of the Submit
        button's outputs list. Both values are displayed in ``gradio.HTML``
        components; presumably they are HTML strings (TODO: confirm in the
        explainer classes).
    """
    # Tuple elements are evaluated left to right, so rollout still runs before IG.
    return gr_explainer(sent, rollout), ig_explainer(sent, ig)
def _layer_slider(label):
    """Create the integer layer-picker slider (0..12, step 1, default 8) with *label*."""
    # Gradio places a component in whichever layout context is active when it is
    # constructed, so calling this inside a ``with gradio.Column()`` nests it there.
    return gradio.Slider(minimum=0, maximum=12, value=8, step=1, label=label)


# Example sentences offered under the input box; each CSV row becomes one example.
examples = pandas.read_csv("examples.csv").to_numpy().tolist()

with gradio.Blocks(title="Explanations with attention rollout") as iface:
    util.Markdown(pathlib.Path("description.md"))

    # Input row: a wide sentence textbox next to a narrow Submit button.
    with gradio.Row(equal_height=True):
        with gradio.Column(scale=4):
            sent = gradio.Textbox(label="Input sentence")
        with gradio.Column(scale=1):
            but = gradio.Button("Submit")

    # Output row: one column per explanation method, each with its own layer slider.
    with gradio.Row(equal_height=True):
        with gradio.Column():
            rollout_layer = _layer_slider("Select rollout start layer")
            rollout_result = gradio.HTML()
        with gradio.Column():
            ig_layer = _layer_slider("Select IG layer")
            ig_result = gradio.HTML()

    gradio.Examples(examples=examples, inputs=[sent])

    with gradio.Accordion("A note about explainability models"):
        util.Markdown(pathlib.Path("notice.md"))

    # Moving a slider recomputes only that method's explanation...
    rollout_layer.change(fn=gr_explainer, inputs=[sent, rollout_layer], outputs=rollout_result)
    ig_layer.change(fn=ig_explainer, inputs=[sent, ig_layer], outputs=ig_result)
    # ...while Submit recomputes both at once.
    but.click(fn=run, inputs=[sent, rollout_layer, ig_layer], outputs=[rollout_result, ig_result])

iface.launch()
|