import gradio as gr
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer
from DecompX.src.globenc_utils import GlobencConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification

plt.style.use("ggplot")
MODELS = ["WillHeld/roberta-base-sst2"]


def plot_clf(tokens, logits, label_names, title="", file_name=None):
    """Horizontal bar chart of per-token contributions, green for positive and red for negative."""
    print(tokens)
    plt.figure(figsize=(4.5, 5))
    colors = ["#019875" if l else "#B8293D" for l in (logits >= 0)]
    plt.barh(range(len(tokens)), logits, color=colors)
    plt.axvline(0, color='black', ls='-', lw=2, alpha=0.2)
    plt.gca().invert_yaxis()
    max_limit = np.max(np.abs(logits)) + 0.2
    min_limit = -0.01 if np.min(logits) > 0 else -max_limit
    plt.xlim(min_limit, max_limit)
    plt.gca().set_xticks([min_limit, max_limit])
    plt.gca().set_xticklabels(label_names, fontsize=14, fontweight="bold")
    plt.gca().set_yticks(range(len(tokens)))
    plt.gca().set_yticklabels(tokens)
    plt.gca().yaxis.tick_right()
    for xtick, color in zip(plt.gca().get_yticklabels(), colors):
        xtick.set_color(color)
        xtick.set_fontweight("bold")
        xtick.set_verticalalignment("center")
    for xtick, color in zip(plt.gca().get_xticklabels(), ["#B8293D", "#019875"]):
        xtick.set_color(color)
    # plt.title(title, fontsize=14, fontweight="bold")
    plt.title(title)
    plt.tight_layout()


def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """
    importance: (sent_len)
    """
    if no_cls_sep:
        importance = importance[1:-1]
        tokenized_text = tokenized_text[1:-1]
    importance = importance / np.abs(importance).max() / 1.5  # Normalize
    if discrete:
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6
    # Monospace block so the ljust-padded prefix stays aligned; per-token spans are appended below.
    html = "<pre style='color: black;'>" + prefix
    for i in range(len(tokenized_text)):
        if importance[i] >= 0:
            rgba = matplotlib.colormaps.get_cmap('Greens')(importance[i])  # Wistia
        else:
            rgba = matplotlib.colormaps.get_cmap('Reds')(np.abs(importance[i]))  # Wistia
        text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(importance[i]) > 0.9 else ""
        color = f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); " + text_color
        html += f"<span style='{color}'>"
        html += tokenized_text[i].replace('<', "[").replace(">", "]")
        html += "</span> "
    html += "</pre>"
    # display(HTML(html))
    return html
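
# Note: the returned string is plain HTML (a <pre> block of colored <span>s, one per token),
# so it can be rendered with IPython's display(HTML(...)) in a notebook or passed to a
# gradio gr.HTML output component in an interface built around these helpers.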
def print_preview(decompx_outputs_df, idx=0, discrete=False):
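    """
    Build an HTML preview of the stored DecompX importance scores for one row.

    Assumes decompx_outputs_df carries the columns used below: "tokens", "label",
    "importance_last_layer_aggregated" (indexed as [0, :]) and
    "importance_last_layer_classifier" (indexed as [:, label], one column per class).
    """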
    html = ""
    NO_CLS_SEP = False
    df = decompx_outputs_df
    for col in ["importance_last_layer_aggregated", "importance_last_layer_classifier"]:
        if col in df and df[col][idx] is not None:
            if "aggregated" in col:
                sentence_importance = df[col].iloc[idx][0, :]
            if "classifier" in col:
                for label in range(df[col].iloc[idx].shape[-1]):
                    sentence_importance = df[col].iloc[idx][:, label]
                    html += print_importance(
                        sentence_importance,
                        df["tokens"].iloc[idx],
                        prefix=f"{col.split('_')[-1]} Label{label}:".ljust(20),
                        no_cls_sep=NO_CLS_SEP,
                        discrete=False
                    )
                break
                # Unreachable: the break above exits the column loop for the classifier case.
                sentence_importance = df[col].iloc[idx][:, df["label"].iloc[idx]]
            html += print_importance(
                sentence_importance,
                df["tokens"].iloc[idx],
                prefix=f"{col.split('_')[-1]}:".ljust(20),
                no_cls_sep=NO_CLS_SEP,
                discrete=discrete
            )
    return "<div>" + html + "</div>"
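

# Minimal sanity-check sketch with made-up data (the real app is expected to feed these
# helpers with DecompX outputs for the models listed in MODELS); the token list and
# scores below are purely illustrative.
if __name__ == "__main__":
    toy_tokens = ["<s>", "a", "truly", "great", "film", "</s>"]
    toy_scores = np.array([0.02, -0.10, 0.35, 0.90, 0.25, 0.01])
    # Bar chart of the signed per-token scores between the two SST-2 labels.
    plot_clf(toy_tokens, toy_scores, label_names=["negative", "positive"], title="toy example")
    plt.show()
    # HTML heatmap string for the same scores.
    print(print_importance(toy_scores, toy_tokens, prefix="toy: ".ljust(20)))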