|
import sys |
|
import pandas |
|
import gradio |
|
|
|
sys.path.append("lib") |
|
|
|
import torch |
|
|
|
from transformers import AutoModelForSequenceClassification |
|
from BERT_explainability.ExplanationGenerator import Generator |
|
from BERT_explainability.roberta2 import RobertaForSequenceClassification |
|
from transformers import AutoTokenizer |
|
from captum.attr import LayerIntegratedGradients |
|
from captum.attr import visualization |
|
import torch |
|
|
|
|
|
class PyTMinMaxScalerVectorized(object): |
|
""" |
|
Transforms each channel to the range [0, 1]. |
|
""" |
|
|
|
def __init__(self, dimension=-1): |
|
self.d = dimension |
|
|
|
def __call__(self, tensor): |
|
d = self.d |
|
scale = 1.0 / ( |
|
tensor.max(dim=d, keepdim=True)[0] - tensor.min(dim=d, keepdim=True)[0] |
|
) |
|
tensor.mul_(scale).sub_(tensor.min(dim=d, keepdim=True)[0]) |
|
return tensor |
|
|
|
|
|
if torch.cuda.is_available(): |
|
device = torch.device("cuda") |
|
else: |
|
device = torch.device("cpu") |
|
|
|
model = RobertaForSequenceClassification.from_pretrained( |
|
"textattack/roberta-base-SST-2" |
|
).to(device) |
|
model.eval() |
|
model2 = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2") |
|
tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2") |
|
|
|
explanations = Generator(model, "roberta") |
|
|
|
classifications = ["NEGATIVE", "POSITIVE"] |
|
|
|
|
|
def avg_heads(cam, grad): |
|
cam = (grad * cam).clamp(min=0).mean(dim=-3) |
|
|
|
|
|
return cam |
|
|
|
|
|
|
|
def apply_self_attention_rules(R_ss, cam_ss): |
|
R_ss_addition = torch.matmul(cam_ss, R_ss) |
|
return R_ss_addition |
|
|
|
|
|
def generate_relevance(model, input_ids, attention_mask, index=None, start_layer=0): |
|
output = model(input_ids=input_ids, attention_mask=attention_mask)[0] |
|
if index == None: |
|
|
|
|
|
index = output.argmax(axis=-1).detach().cpu().numpy() |
|
|
|
|
|
one_hot = ( |
|
torch.nn.functional.one_hot( |
|
torch.tensor(index, dtype=torch.int64), num_classes=output.size(-1) |
|
) |
|
.to(torch.float) |
|
.requires_grad_(True) |
|
).to(device) |
|
one_hot = torch.sum(one_hot * output) |
|
model.zero_grad() |
|
|
|
one_hot.backward(retain_graph=True) |
|
|
|
num_tokens = model.roberta.encoder.layer[0].attention.self.get_attn().shape[-1] |
|
R = torch.eye(num_tokens).expand(output.size(0), -1, -1).clone().to(device) |
|
|
|
for i, blk in enumerate(model.roberta.encoder.layer): |
|
if i < start_layer: |
|
continue |
|
grad = blk.attention.self.get_attn_gradients() |
|
cam = blk.attention.self.get_attn() |
|
cam = avg_heads(cam, grad) |
|
joint = apply_self_attention_rules(R, cam) |
|
R += joint |
|
return output, R[:, 0, 1:-1] |
|
|
|
|
|
def visualize_text(datarecords, legend=True): |
|
dom = ["<table width: 100%>"] |
|
rows = [ |
|
"<tr><th>True Label</th>" |
|
"<th>Predicted Label</th>" |
|
"<th>Attribution Label</th>" |
|
"<th>Attribution Score</th>" |
|
"<th>Word Importance</th>" |
|
] |
|
for datarecord in datarecords: |
|
rows.append( |
|
"".join( |
|
[ |
|
"<tr>", |
|
visualization.format_classname(datarecord.true_class), |
|
visualization.format_classname( |
|
"{0} ({1:.2f})".format( |
|
datarecord.pred_class, datarecord.pred_prob |
|
) |
|
), |
|
visualization.format_classname(datarecord.attr_class), |
|
visualization.format_classname( |
|
"{0:.2f}".format(datarecord.attr_score) |
|
), |
|
visualization.format_word_importances( |
|
datarecord.raw_input_ids, datarecord.word_attributions |
|
), |
|
"<tr>", |
|
] |
|
) |
|
) |
|
|
|
if legend: |
|
dom.append( |
|
'<div style="border-top: 1px solid; margin-top: 5px; \ |
|
padding-top: 5px; display: inline-block">' |
|
) |
|
dom.append("<b>Legend: </b>") |
|
|
|
for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]): |
|
dom.append( |
|
'<span style="display: inline-block; width: 10px; height: 10px; \ |
|
border: 1px solid; background-color: \ |
|
{value}"></span> {label} '.format( |
|
value=visualization._get_color(value), label=label |
|
) |
|
) |
|
dom.append("</div>") |
|
|
|
dom.append("".join(rows)) |
|
dom.append("</table>") |
|
html = "".join(dom) |
|
|
|
return html |
|
|
|
|
|
def show_explanation(model, input_ids, attention_mask, index=None, start_layer=8): |
|
|
|
output, expl = generate_relevance( |
|
model, input_ids, attention_mask, index=index, start_layer=start_layer |
|
) |
|
|
|
scaler = PyTMinMaxScalerVectorized() |
|
|
|
norm = scaler(expl) |
|
|
|
output = torch.nn.functional.softmax(output, dim=-1) |
|
|
|
vis_data_records = [] |
|
for record in range(input_ids.size(0)): |
|
classification = output[record].argmax(dim=-1).item() |
|
class_name = classifications[classification] |
|
nrm = norm[record] |
|
|
|
|
|
|
|
if class_name == "NEGATIVE": |
|
nrm *= -1 |
|
tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[ |
|
1 : 0 - ((attention_mask[record] == 0).sum().item() + 1) |
|
] |
|
|
|
vis_data_records.append( |
|
visualization.VisualizationDataRecord( |
|
nrm, |
|
output[record][classification], |
|
classification, |
|
classification, |
|
index, |
|
1, |
|
tokens, |
|
1, |
|
) |
|
) |
|
return visualize_text(vis_data_records) |
|
|
|
def custom_forward(inputs, attention_mask=None, pos=0): |
|
result = model2(inputs, attention_mask=attention_mask, return_dict=True) |
|
preds = result.logits |
|
return preds |
|
|
|
def summarize_attributions(attributions): |
|
attributions = attributions.sum(dim=-1).squeeze(0) |
|
attributions = attributions / torch.norm(attributions) |
|
return attributions |
|
|
|
|
|
def run_attribution_model(input_ids, attention_mask, ref_token_id=tokenizer.unk_token_id, layer=None, steps=20): |
|
try: |
|
output = model2(input_ids=input_ids, attention_mask=attention_mask)[0] |
|
index = output.argmax(axis=-1).detach().cpu().numpy() |
|
|
|
ablator = LayerIntegratedGradients(custom_forward, layer) |
|
input_tensor = input_ids |
|
attention_mask = attention_mask |
|
attributions = ablator.attribute( |
|
inputs=input_ids, |
|
baselines=ref_token_id, |
|
additional_forward_args=(attention_mask), |
|
target=1, |
|
n_steps=steps, |
|
) |
|
attributions = summarize_attributions(attributions).unsqueeze_(0) |
|
finally: |
|
pass |
|
vis_data_records = [] |
|
for record in range(input_ids.size(0)): |
|
classification = output[record].argmax(dim=-1).item() |
|
class_name = classifications[classification] |
|
attr = attributions[record] |
|
tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[ |
|
1 : 0 - ((attention_mask[record] == 0).sum().item() + 1) |
|
] |
|
vis_data_records.append( |
|
visualization.VisualizationDataRecord( |
|
attr, |
|
output[record][classification], |
|
classification, |
|
classification, |
|
index, |
|
1, |
|
tokens, |
|
1, |
|
) |
|
) |
|
return visualize_text(vis_data_records) |
|
|
|
def sentence_sentiment(input_text, layer): |
|
text_batch = [input_text] |
|
encoding = tokenizer(text_batch, return_tensors="pt") |
|
input_ids = encoding["input_ids"].to(device) |
|
attention_mask = encoding["attention_mask"].to(device) |
|
layer = int(layer) |
|
if layer == 0: |
|
layer = model2.roberta.embeddings |
|
else: |
|
layer = getattr(model2.roberta.encoder.layer, str(layer-1)) |
|
|
|
output = run_attribution_model(input_ids, attention_mask, layer=layer) |
|
return output |
|
|
|
def sentiment_explanation_hila(input_text, layer): |
|
text_batch = [input_text] |
|
encoding = tokenizer(text_batch, return_tensors="pt") |
|
input_ids = encoding["input_ids"].to(device) |
|
attention_mask = encoding["attention_mask"].to(device) |
|
|
|
|
|
true_class = 1 |
|
|
|
return show_explanation(model, input_ids, attention_mask, start_layer=int(layer)) |
|
|
|
layer_slider = gradio.Slider(minimum=0, maximum=12, value=8, step=1, label="Select layer") |
|
hila = gradio.Interface( |
|
fn=sentiment_explanation_hila, |
|
inputs=["text", layer_slider], |
|
outputs="html", |
|
) |
|
|
|
lig = gradio.Interface( |
|
fn=sentence_sentiment, |
|
inputs=["text", layer_slider], |
|
outputs="html", |
|
) |
|
|
|
with open("description.md", "r") as fh: |
|
description = fh.read() |
|
|
|
examples = pandas.read_csv("examples.csv").to_numpy().tolist() |
|
|
|
iface = gradio.Parallel(hila, lig, title="RoBERTa Explainability", description=description, examples=examples) |
|
|
|
iface.launch() |
|
|