|
import sys |
|
import pandas |
|
import gradio |
|
import pathlib |
|
|
|
sys.path.append("lib") |
|
|
|
import torch |
|
|
|
from roberta2 import RobertaForSequenceClassification |
|
from transformers import AutoTokenizer |
|
|
|
from gradient_rollout import GradientRolloutExplainer |
|
from rollout import RolloutExplainer |
|
from integrated_gradients import IntegratedGradientsExplainer |
|
|
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") |
|
model = RobertaForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2").to(device) |
|
tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2") |
|
|
|
ig_explainer = IntegratedGradientsExplainer(model, tokenizer) |
|
gr_explainer = GradientRolloutExplainer(model, tokenizer) |
|
ro_explainer = RolloutExplainer(model, tokenizer) |
|
|
|
def run(sent, gradient, rollout, ig, ig_baseline): |
|
a = gr_explainer(sent, gradient) |
|
b = ro_explainer(sent, rollout) |
|
c = ig_explainer(sent, ig, ig_baseline) |
|
return a, b, c |
|
|
|
examples = pandas.read_csv("examples.csv").to_numpy().tolist() |
|
|
|
with gradio.Blocks(title="Explanations with attention rollout") as iface: |
|
gradio.Markdown(pathlib.Path("description.md").read_text) |
|
with gradio.Row(equal_height=True): |
|
with gradio.Column(scale=4): |
|
sent = gradio.Textbox(label="Input sentence") |
|
with gradio.Column(scale=1): |
|
but = gradio.Button("Submit") |
|
with gradio.Row(equal_height=True): |
|
with gradio.Column(): |
|
rollout_layer = gradio.Slider( |
|
minimum=1, |
|
maximum=12, |
|
value=1, |
|
step=1, |
|
label="Select rollout start layer" |
|
) |
|
with gradio.Column(): |
|
gradient_layer = gradio.Slider( |
|
minimum=1, |
|
maximum=12, |
|
value=8, |
|
step=1, |
|
label="Select gradient rollout start layer" |
|
) |
|
with gradio.Column(): |
|
ig_layer = gradio.Slider( |
|
minimum=0, |
|
maximum=12, |
|
value=0, |
|
step=1, |
|
label="Select IG layer" |
|
) |
|
ig_baseline = gradio.Dropdown( |
|
label="Baseline token", |
|
choices=['Unknown', 'Padding'], value="Unknown" |
|
) |
|
with gradio.Row(equal_height=True): |
|
with gradio.Column(): |
|
gradio.Markdown("### Attention Rollout") |
|
rollout_result = gradio.HTML() |
|
with gradio.Column(): |
|
gradio.Markdown("### Gradient-weighted Attention Rollout") |
|
gradient_result = gradio.HTML() |
|
with gradio.Column(): |
|
gradio.Markdown("### Layer-Integrated Gradients") |
|
ig_result = gradio.HTML() |
|
gradio.Examples(examples, [sent]) |
|
with gradio.Accordion("Some more details"): |
|
gradio.Markdown(pathlib.Path("notice.md").read_text) |
|
|
|
gradient_layer.change(gr_explainer, [sent, gradient_layer], gradient_result) |
|
rollout_layer.change(ro_explainer, [sent, rollout_layer], rollout_result) |
|
ig_layer.change(ig_explainer, [sent, ig_layer, ig_baseline], ig_result) |
|
but.click(run, |
|
inputs=[sent, gradient_layer, rollout_layer, ig_layer, ig_baseline], |
|
outputs=[gradient_result, rollout_result, ig_result] |
|
) |
|
|
|
|
|
iface.launch() |
|
|