# DecompX / app.py
# Last updated by mohsenfayyaz
# Commit f34a8cd: "Update app.py"
import functools

import gradio as gr
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer

from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
# Apply matplotlib's "ggplot" style sheet to every figure this app draws.
plt.style.use("ggplot")

# Hugging Face Hub checkpoints offered in the model dropdown. run_decompx picks
# the model class by checking whether the name contains "roberta" or "bert".
MODELS = [
    "TehranNLP-org/bert-base-uncased-cls-sst2",
    "TehranNLP-org/bert-large-sst2",
    "WillHeld/roberta-base-sst2",
]
def plot_clf(tokens, logits, label_names, title="", file_name=None):
    """Draw a horizontal bar chart of per-token scores on a new pyplot figure.

    Non-negative scores are drawn green and negative scores red. Each y-axis
    token label (shown on the right) takes the colour of its bar. The two x-axis
    ticks sit at the left and right plot limits. They are labelled with
    ``label_names[0]`` (left, red) and ``label_names[1]`` (right, green).

    Args:
        tokens: Token strings, one per bar, listed top to bottom.
        logits: Scores aligned with ``tokens``; a list or 1-D array.
        label_names: Two tick labels for the negative and positive ends.
        title: Figure title.
        file_name: Accepted for backward compatibility only. It is not used,
            and nothing is written to disk.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn. As before, it is also
        left as pyplot's current figure.
    """
    scores = np.asarray(logits, dtype=float).reshape(-1)
    fig = plt.figure(figsize=(4.5, 5))
    ax = fig.gca()

    colors = ["#019875" if s >= 0 else "#B8293D" for s in scores]
    positions = range(len(tokens))
    ax.barh(positions, scores, color=colors)
    ax.axvline(0, color='black', ls='-', lw=2, alpha=0.2)
    ax.invert_yaxis()  # first token at the top

    # Leave a 0.2 margin past the largest magnitude (0 when there are no scores).
    max_limit = (np.max(np.abs(scores)) if scores.size else 0.0) + 0.2
    # With only positive scores there are no left-hand bars, so keep just a
    # sliver of negative axis instead of mirroring the positive range.
    min_limit = -0.01 if scores.size and np.min(scores) > 0 else -max_limit
    ax.set_xlim(min_limit, max_limit)
    ax.set_xticks([min_limit, max_limit])
    ax.set_xticklabels(label_names, fontsize=14, fontweight="bold")

    ax.set_yticks(positions)
    ax.set_yticklabels(tokens)
    ax.yaxis.tick_right()
    for tick_label, color in zip(ax.get_yticklabels(), colors):
        tick_label.set_color(color)
        tick_label.set_fontweight("bold")
        tick_label.set_verticalalignment("center")
    for tick_label, color in zip(ax.get_xticklabels(), ["#B8293D", "#019875"]):
        tick_label.set_color(color)

    ax.set_title(title)
    fig.tight_layout()
    return fig
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """Render tokens as HTML, shading each token's background by its importance.

    Non-negative scores take their background from matplotlib's ``Greens``
    colormap and negative scores from ``Reds``, indexed by the score's
    magnitude.

    Args:
        importance: 1-D scores, one per token (sent_len).
        tokenized_text: Token strings aligned with ``importance``.
        discrete: If True, shade tokens by rank instead of by magnitude.
        prefix: Text placed before the first token, e.g. a row label.
        no_cls_sep: If True, drop the first and last entries before rendering.
            These are presumably the special start/end tokens.

    Returns:
        An HTML string wrapped in a ``<pre>`` element.
    """
    from html import escape

    scores = np.asarray(importance, dtype=float)
    tokens = list(tokenized_text)
    if no_cls_sep:
        scores = scores[1:-1]
        tokens = tokens[1:-1]

    # Scale so the largest magnitude maps to 1/1.5, which keeps colours away
    # from the darkest end of the colormaps. Skip scaling for empty or
    # all-zero input, where dividing would produce NaN colours.
    peak = np.abs(scores).max() if scores.size else 0.0
    if peak > 0:
        scores = scores / peak / 1.5
    if discrete:
        # Replace each score by its rank (0 .. n-1), scaled into [0, 1/1.6).
        scores = np.argsort(np.argsort(scores)) / len(scores) / 1.6

    greens = matplotlib.colormaps.get_cmap('Greens')
    reds = matplotlib.colormaps.get_cmap('Reds')
    parts = ["<pre style='color:black; padding: 3px;'>" + prefix]
    for score, token in zip(scores, tokens):
        rgba = (greens if score >= 0 else reds)(abs(score))
        style = (
            f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); "
            "color:black; border-radius: 5px; padding: 3px;"
            "font-weight: 800;"
        )
        if abs(score) > 0.9:
            # Must come after "color:black" for the CSS override to take effect.
            # NOTE(review): with the 1/1.5 and 1/1.6 scaling above, |score|
            # stays below 0.9, so this branch is currently unreachable.
            style += "color: rgba(255, 255, 255, 1.0);"
        # Show angle brackets (e.g. RoBERTa's <s>) as square brackets, then
        # escape the rest, since tokens come from user-supplied text.
        text = escape(token.replace('<', "[").replace(">", "]"))
        parts.append(f"<span style='{style}'>{text}</span> ")
    parts.append("</pre>")
    return "".join(parts)
def print_preview(decompx_outputs_df, idx=0, discrete=False):
    """Build an HTML preview of the DecompX importances for one DataFrame row.

    Two optional columns are rendered, in this order, when present and not
    None for the row:

    * ``importance_last_layer_aggregated``: the first row of the per-row
      matrix is rendered on one line. NOTE(review): assumes a
      (seq_len, seq_len) matrix whose row 0 belongs to the first ([CLS])
      token; confirm against the DecompX output.
    * ``importance_last_layer_classifier``: a per-row array indexed as
      ``[:, label]``, as built by run_decompx. One line is rendered per label.

    Args:
        decompx_outputs_df: DataFrame with a ``tokens`` column plus any of the
            importance columns above.
        idx: Positional index of the row to show.
        discrete: Forwarded to print_importance for the aggregated line.
            Classifier lines are always drawn with ``discrete=False``.

    Returns:
        An HTML string: a white, scrollable ``<div>`` wrapping the rendered
        lines.
    """
    no_cls_sep = False  # keep the first/last (special) tokens in the preview
    df = decompx_outputs_df
    rows = []
    for col in ["importance_last_layer_aggregated", "importance_last_layer_classifier"]:
        if col not in df:
            continue
        values = df[col].iloc[idx]
        if values is None:
            continue
        tokens = df["tokens"].iloc[idx]
        kind = col.split('_')[-1]  # "aggregated" or "classifier"
        if "classifier" in col:
            for label in range(values.shape[-1]):
                rows.append(print_importance(
                    values[:, label],
                    tokens,
                    prefix=f"{kind} Label{label}:".ljust(20),
                    no_cls_sep=no_cls_sep,
                    discrete=False,
                ))
        else:
            rows.append(print_importance(
                values[0, :],
                tokens,
                prefix=f"{kind}:".ljust(20),
                no_cls_sep=no_cls_sep,
                discrete=discrete,
            ))
    return (
        "<div style='overflow:auto; background-color:white; padding: 10px;'>"
        + "".join(rows)
        + "</div>"
    )
def _decompx_config():
    """Return the DecompXConfig used for every explanation in this demo."""
    return DecompXConfig(
        include_biases=True,
        bias_decomp_type="absdot",
        include_LN1=True,
        include_FFN=True,
        FFN_approx_type="GeLU_ZO",
        include_LN2=True,
        aggregation="vector",
        include_classifier_w_pooler=True,
        tanh_approx_type="ZO",
        output_all_layers=True,
        output_attention=None,
        output_res1=None,
        output_LN1=None,
        output_FFN=None,
        output_res2=None,
        output_encoder=None,
        output_aggregated="norm",
        output_pooler="norm",
        output_classifier=True,
    )


@functools.lru_cache(maxsize=8)
def _load_model_and_tokenizer(model_name):
    """Load the tokenizer and DecompX classifier for ``model_name``, once per name.

    Results are cached, so repeated requests reuse the already-loaded weights.
    Failed loads raise and are not cached.

    SECURITY NOTE(review): ``model_name`` arrives from the UI/API, and
    ``from_pretrained`` on an arbitrary repository may deserialize untrusted
    (pickled) weights. Consider restricting it to MODELS.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Check "roberta" first: every RoBERTa name also contains "bert".
    if "roberta" in model_name:
        model = RobertaForSequenceClassification.from_pretrained(model_name)
    elif "bert" in model_name:
        model = BertForSequenceClassification.from_pretrained(model_name)
    else:
        raise Exception(f"Not implented model: {model_name}")
    model.eval()
    return tokenizer, model


def run_decompx(text, model):
    """Provide DecompX Token Explanation of Model on Text.

    Args:
        text: Input sentence to explain.
        model: Hub name of the classifier to explain (see MODELS).

    Returns:
        ``(figure, html)``. ``figure`` is a bar chart of each token's DecompX
        importance toward the predicted label, scaled to [-1, 1]. ``html`` is
        a preview of the importance toward every label.

    Raises:
        ValueError: if ``text`` is empty or yields no regular tokens.
    """
    if not text or not text.strip():
        raise ValueError("Please enter some text to explain.")

    # NOTE(review): the fixed second sentence presumably keeps the batch
    # dimension at 2, so the .squeeze() calls below cannot collapse it; confirm.
    SENTENCES = [text, "nothing"]
    tokenizer, decompx_model = _load_model_and_tokenizer(model)
    tokenized_sentence = tokenizer(SENTENCES, return_tensors="pt", padding=True)
    batch_lengths = tokenized_sentence['attention_mask'].sum(dim=-1)  # real (unpadded) lengths

    # RUN DECOMPX
    with torch.no_grad():
        logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = decompx_model(
            **tokenized_sentence,
            output_attentions=False,
            return_dict=False,
            output_hidden_states=True,
            decompx_config=_decompx_config(),
        )

    decompx_outputs = {
        "tokens": [
            tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]])
            for i in range(len(SENTENCES))
        ],
        "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
        "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),  # last layer, first token only -> (batch, emb_dim)
    }
    # decompx_last_layer_outputs.classifier, example shape noted by the author: (8, 55, 2)
    importance = np.array(
        [g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]
    ).squeeze()  # (batch, seq_len, classes)
    decompx_outputs["importance_last_layer_classifier"] = [
        importance[j][:batch_lengths[j], :] for j in range(len(importance))
    ]
    # decompx_all_layers_outputs.aggregated, example shape noted by the author: (12, 8, 55, 55)
    importance = np.array(
        [g.squeeze().cpu().detach().numpy() for g in decompx_all_layers_outputs.aggregated]
    )  # (layers, batch, seq_len, seq_len)
    importance = np.einsum('lbij->blij', importance)  # (batch, layers, seq_len, seq_len)
    decompx_outputs["importance_all_layers_aggregated"] = [
        importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))
    ]
    decompx_outputs_df = pd.DataFrame(decompx_outputs)

    # Plot the user's sentence (row 0) without its first and last tokens.
    row = decompx_outputs_df.iloc[0]
    pred_label = np.argmax(row["logits"], axis=-1)
    label = row["importance_last_layer_classifier"][:, pred_label][1:-1]
    tokens = row["tokens"][1:-1]
    if label.size == 0:
        raise ValueError("Please enter some text to explain.")
    peak = np.max(np.abs(label))
    if peak > 0:  # avoid NaNs when every importance is zero
        label = label / peak
    plot_clf(tokens, label, ['-', '+'], title=f"DecompX for Predicted Label: {pred_label}", file_name="example_sst2_our_method")
    # plot_clf draws on a fresh pyplot figure, which is now current. Return the
    # Figure itself and drop pyplot's reference, so figures do not pile up
    # across requests.
    fig = plt.gcf()
    plt.close(fig)
    return fig, print_preview(decompx_outputs_df)
# Gradio UI: free text plus a model picked from MODELS. Outputs are the
# importance bar chart and the HTML token preview returned by run_decompx.
demo = gr.Interface(
    fn=run_decompx,
    inputs=[
        gr.components.Textbox(label="Text"),
        gr.components.Dropdown(label="Model", choices=MODELS),
    ],
    outputs=["plot", "html"],
    examples=[
        ["a good piece of work more often than not.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
        ["a good piece of work more often than not.", "TehranNLP-org/bert-large-sst2"],
        ["a good piece of work more often than not.", "WillHeld/roberta-base-sst2"],
        ["A deep and meaningful film.", "TehranNLP-org/bert-base-uncased-cls-sst2"],
    ],
    cache_examples=True,  # precompute example outputs so they display instantly
    title="DecompX Demo",
    description="This is a demo for the ACL 2023 paper [DecompX](https://github.com/mohsenfayyaz/DecompX/)"
)

# Launch only when run as a script, so importing this module (tests, gradio's
# reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()