# Spaces:
# Build error
# Build error
from typing import AnyStr, Dict | |
import itertools | |
import streamlit as st | |
import en_core_web_lg | |
import torch.nn.parameter | |
from bs4 import BeautifulSoup | |
import numpy as np | |
import base64 | |
from spacy_streamlit.util import get_svg | |
from streamlit.proto.SessionState_pb2 import SessionState | |
from custom_renderer import render_sentence_custom | |
from sentence_transformers import SentenceTransformer | |
from transformers import AutoTokenizer, AutoModelForTokenClassification | |
from transformers import pipeline | |
import os | |
# Torch device: the CUDA GPU when torch can see one, otherwise the CPU.
# NOTE(review): not referenced in the visible code (the summarizer chooses its own device index).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Bordered, horizontally scrollable container; ``HTML_WRAPPER.format(...)`` inserts the HTML to show.
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
margin-bottom: 2.5rem">{}</div> """
def get_sentence_embedding_model():
    """Load the MiniLM sentence-transformer used to compare entity strings by embedding similarity."""
    checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'
    return SentenceTransformer(checkpoint)
def get_spacy():
    """Return the large English spaCy pipeline (``en_core_web_lg``)."""
    return en_core_web_lg.load()
def get_transformer_pipeline():
    """Build the XLM-RoBERTa (CoNLL-03, English) token-classification pipeline with grouped entity spans."""
    checkpoint = "xlm-roberta-large-finetuned-conll03-english"
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ner_weights = AutoModelForTokenClassification.from_pretrained(checkpoint)
    return pipeline("ner", model=ner_weights, tokenizer=ner_tokenizer, grouped_entities=True)
def get_summarizer_model():
    """Build the PEGASUS (CNN/DailyMail) summarization pipeline.

    The pipeline runs on GPU 0 when CUDA is available and on the CPU (device -1) otherwise.
    """
    checkpoint = 'google/pegasus-cnn_dailymail'
    device_index = 0 if torch.cuda.is_available() else -1
    return pipeline("summarization", model=checkpoint, tokenizer=checkpoint, device=device_index)
# Page setup: browser-tab title, centred layout, and the three custom menu entries all set to None.
st.set_page_config(
    page_title="📜 Hallucination detection in summaries 📜",
    page_icon="",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=dict.fromkeys(('Get help', 'Report a bug', 'About')),
)
def list_all_article_names() -> list:
    """Return the selectable article names.

    These are the ``.txt`` files in ./sample-articles, sorted by file name, with only the trailing
    ``.txt`` removed. The pseudo-entry "Provide your own input" is always appended last.
    """
    suffix = '.txt'
    # Slice off only the extension; str.replace would also remove an inner ".txt", producing a
    # name whose file (``<name>.txt``) does not exist when the article is fetched later.
    filenames = [file[:-len(suffix)] for file in sorted(os.listdir('./sample-articles/'))
                 if file.endswith(suffix)]
    # Append free use possibility:
    filenames.append("Provide your own input")
    return filenames
def _read_text(path: str) -> str:
    """Return the whole content of the text file at *path*, decoded as UTF-8."""
    # An explicit encoding keeps reads independent of the platform locale; the texts may contain
    # non-ASCII punctuation such as curly quotes.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()


def fetch_article_contents(filename: str) -> str:
    """Return sample article *filename* (given without ``.txt``) from ./sample-articles.

    The pseudo-entry "Provide your own input" returns a single space instead of reading a file.
    """
    if filename == "Provide your own input":
        return " "
    return _read_text(f'./sample-articles/{filename}.txt')


def fetch_summary_contents(filename: str) -> str:
    """Return the pre-generated summary of example article *filename* from ./sample-summaries."""
    return _read_text(f'./sample-summaries/{filename}.txt')


def fetch_entity_specific_contents(filename: str) -> str:
    """Return the entity-matching explanation for example article *filename* from ./entity-specific-text."""
    return _read_text(f'./entity-specific-text/{filename}.txt')


def fetch_dependency_specific_contents(filename: str) -> str:
    """Return the dependency explanation for example article *filename* from ./dependency-specific-text."""
    return _read_text(f'./dependency-specific-text/{filename}.txt')


def fetch_ranked_summaries(filename: str, ranknumber: int) -> str:
    """Return candidate summary *ranknumber* of *filename* from ./ranked-summaries/<filename>/Rank<n>.txt."""
    return _read_text(f'./ranked-summaries/{filename}/Rank{ranknumber}.txt')


def fetch_dependency_svg(filename: str) -> list:
    """Return the lines of ./dependency-images/<filename>.txt with trailing whitespace stripped.

    NOTE(review): each line is presumably one pre-rendered SVG/HTML snippet; confirm with the data files.
    """
    with open(f'./dependency-images/{filename}.txt', 'r', encoding='utf-8') as f:
        return [line.rstrip() for line in f]
def display_summary(summary_content: str):
    """Store *summary_content* as the current summary and return it wrapped in HTML_WRAPPER.

    Side effect: sets ``st.session_state.summary_output``, which later sections read.
    """
    st.session_state.summary_output = summary_content
    return HTML_WRAPPER.format(BeautifulSoup(summary_content, features="html.parser"))
def get_all_entities_per_sentence(text):
    """Return one list of entity strings per spaCy sentence of *text*.

    Each inner list holds the spaCy entities of that sentence (global ``nlp``), followed by the
    ``"word"`` of every entity the XLM-RoBERTa pipeline (global ``ner_model``) finds in the
    sentence text. Duplicates between the two models are kept.
    """
    per_sentence = []
    for sent in nlp(text).sents:
        found = [str(span) for span in sent.ents]
        # A Flair tagger was also tried here; it is currently disabled.
        found.extend(str(prediction["word"]) for prediction in ner_model(str(sent)))
        per_sentence.append(found)
    return per_sentence
def get_all_entities(text):
    """Return every entity string found in *text*, flattened across its sentences in order."""
    return [entity
            for sentence_entities in get_all_entities_per_sentence(text)
            for entity in sentence_entities]
def _drop_contained_entities(entities):
    """Deduplicate *entities* and drop entries contained in a longer entry of the same list.

    Comparison is case-insensitive. Entries that differ only in letter case are collapsed onto
    their first occurrence instead of removing each other. Order is preserved.
    """
    lowered = [entity.lower() for entity in entities]
    kept = []
    seen = set()
    for entity, entity_lower in zip(entities, lowered):
        if entity_lower in seen:
            continue
        if any(entity_lower != other_lower and entity_lower in other_lower for other_lower in lowered):
            continue
        seen.add(entity_lower)
        kept.append(entity)
    return kept


def get_and_compare_entities(first_time: bool):
    """Split the current summary's entities into those supported by the article and those that are not.

    Args:
        first_time: when True, extract the entities of ``st.session_state.article_text`` and cache
            them in ``st.session_state.entities_article``. When False, reuse that cache, which must
            already have been filled by an earlier call with True.

    Returns:
        ``(matched_entities, unmatched_entities)``: entities of ``st.session_state.summary_output``
        in first-seen order. An entity is matched when it is a case-insensitive substring of some
        article entity, or when its embedding inner product with some article entity exceeds 0.9.
        Entries contained in a longer entry of the same list are dropped.
    """
    if first_time:
        article_content = st.session_state.article_text
        entities_article = list(itertools.chain.from_iterable(get_all_entities_per_sentence(article_content)))
        st.session_state.entities_article = entities_article
    else:
        entities_article = st.session_state.entities_article
    summary_content = st.session_state.summary_output
    entities_summary = list(itertools.chain.from_iterable(get_all_entities_per_sentence(summary_content)))

    article_lower = [art_entity.lower() for art_entity in entities_article]

    # Each distinct string is encoded at most once per call. Without this, every article entity
    # would be re-encoded for every summary entity that has no textual match.
    embedding_cache = {}

    def embed(phrase):
        if phrase not in embedding_cache:
            embedding_cache[phrase] = sentence_embedding_model.encode(phrase, show_progress_bar=False)
        return embedding_cache[phrase]

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        entity_lower = entity.lower()
        if any(entity_lower in art_lower for art_lower in article_lower):
            matched_entities.append(entity)
        # NOTE(review): the inner product equals cosine similarity only if the embeddings are
        # unit-normalised; confirm for this sentence-transformer.
        elif any(np.inner(embed(entity), embed(art_entity)) > 0.9 for art_entity in entities_article):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)

    return _drop_contained_entities(matched_entities), _drop_contained_entities(unmatched_entities)
# Opening/closing tags used to colour summary entities: green = found in the article, red = not found.
_ENTITY_MARK_RED = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
_ENTITY_MARK_GREEN = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
_ENTITY_MARK_END = "</mark>"


def _render_highlighted_summary(summary_content, matched_entities, unmatched_entities):
    """Return *summary_content* wrapped in HTML_WRAPPER, with matched entities marked green and unmatched ones red.

    All entities are replaced in one regex pass, trying longer entities first. This prevents an
    entity from matching inside a longer entity, or inside markup inserted for another entity
    (e.g. "135" inside the ``rgb(...)`` colour of an earlier <mark>). Empty strings are ignored.
    """
    import re

    openers = {}
    for entity in matched_entities:
        if entity:
            openers.setdefault(entity, _ENTITY_MARK_GREEN)
    for entity in unmatched_entities:
        if entity:
            openers.setdefault(entity, _ENTITY_MARK_RED)
    if openers:
        pattern = re.compile("|".join(re.escape(entity) for entity in sorted(openers, key=len, reverse=True)))
        summary_content = pattern.sub(
            lambda match: openers[match.group(0)] + match.group(0) + _ENTITY_MARK_END, summary_content)
    soup = BeautifulSoup(summary_content, features="html.parser")
    return HTML_WRAPPER.format(soup)


def highlight_entities():
    """Highlight the entities of ``st.session_state.summary_output``.

    The article's entities are (re)computed and cached in session state via
    ``get_and_compare_entities(True)``. Returns the HTML to display.
    """
    summary_content = st.session_state.summary_output
    matched_entities, unmatched_entities = get_and_compare_entities(True)
    return _render_highlighted_summary(summary_content, matched_entities, unmatched_entities)


def highlight_entities_new(summary_str: str):
    """Make *summary_str* the current summary and highlight its entities.

    Uses the article entities already cached in session state (``get_and_compare_entities(False)``).
    Returns the HTML to display.
    """
    st.session_state.summary_output = summary_str
    matched_entities, unmatched_entities = get_and_compare_entities(False)
    return _render_highlighted_summary(summary_str, matched_entities, unmatched_entities)
def render_dependency_parsing(text: Dict):
    """Render one dependency record (a dict as produced by check_dependency) as an SVG on the page.

    The markup comes from ``render_sentence_custom`` using the global spaCy pipeline. Blank lines are
    collapsed, and the result is passed through spacy_streamlit's ``get_svg`` before being written
    as raw HTML.
    """
    svg_markup = render_sentence_custom(text, nlp).replace("\n\n", "\n")
    st.write(get_svg(svg_markup), unsafe_allow_html=True)
def check_dependency(article: bool):
    """Collect entity-anchored ``amod`` / ``pobj`` dependencies from the article or the current summary.

    Args:
        article: True to analyse ``st.session_state.article_text``, False for
            ``st.session_state.summary_output``.

    Returns:
        A list with one dict per kept dependency:
        - ``dep``: the relation.
        - ``cur_word_index`` / ``target_word_index``: dependent and head token offsets, relative
          to the sentence start.
        - ``identifier``: dependent text + relation + head text, used to compare article and summary.
        - ``sentence``: the sentence text.

        A dependency is kept when it is ``amod``, or ``pobj`` whose head is "in", and either its
        dependent's or its head's text is one of the entities extracted for that sentence.
    """
    text = st.session_state.article_text if article else st.session_state.summary_output
    all_entities = get_all_entities_per_sentence(text)
    doc = nlp(text)
    tok_l = doc.to_json()['tokens']
    test_list_dict_output = []
    for i, sentence in enumerate(doc.sents):
        sentence_entities = all_entities[i]
        # tok_l is indexed by token id (the head lookup below relies on this too), and Span.end is
        # exclusive. Slicing therefore yields exactly this sentence's tokens, without including the
        # next sentence's first token.
        for t in tok_l[sentence.start:sentence.end]:
            if t['dep'] not in ('amod', 'pobj'):
                continue
            head = tok_l[t['head']]
            object_here = text[t['start']:t['end']]
            object_target = text[head['start']:head['end']]
            if t['dep'] == 'pobj' and object_target.lower() != 'in':
                continue
            # ONE NEEDS TO BE ENTITY
            if object_here not in sentence_entities and object_target not in sentence_entities:
                continue
            test_list_dict_output.append({
                "dep": t['dep'],
                "cur_word_index": t['id'] - sentence.start,
                "target_word_index": t['head'] - sentence.start,
                "identifier": object_here + t['dep'] + object_target,
                "sentence": str(sentence),
            })
    return test_list_dict_output
def render_svg(svg_file):
    """Return an ``<img>`` tag that embeds the SVG file *svg_file* as a base64 ``data:`` URI.

    Args:
        svg_file: path of the SVG file to embed.

    Returns:
        The HTML string ``<img src="data:image/svg+xml;base64,..."/>``.
    """
    # Decode explicitly as UTF-8: the SVGs may contain non-ASCII glyphs (e.g. curly apostrophes),
    # which must not depend on the platform locale.
    with open(svg_file, "r", encoding="utf-8") as f:
        svg = f.read()
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    return f'<img src="data:image/svg+xml;base64,{b64}"/>'
def generate_abstractive_summary(text, type, min_len=120, max_len=512, **kwargs):
    """Summarize *text* with the global ``summarization_model`` pipeline.

    Args:
        text: article text; surrounding whitespace is stripped and newlines become spaces.
        type: decoding preset. "top_p" adds ``top_k=50, top_p=0.95``; "top_k" adds ``top_k=50``;
            "greedy" and "beam" add nothing.
        min_len: passed to the pipeline as ``min_length``.
        max_len: passed to the pipeline as ``max_length``.
        **kwargs: forwarded to the pipeline call (e.g. ``do_sample``, ``num_beams``).

    Returns:
        The generated summary, with ``<n>`` tokens (presumably PEGASUS newline markers) replaced
        by spaces.

    Raises:
        ValueError: if *type* is not one of the presets above.
    """
    # NOTE(review): top_k/top_p are sampling parameters in Hugging Face generation and presumably
    # only take effect when sampling is enabled (e.g. do_sample=True via **kwargs); confirm.
    # "greedy" and "beam" add nothing themselves; the caller passes e.g. num_beams via **kwargs.
    preset_kwargs = {
        "top_p": {"top_k": 50, "top_p": 0.95},
        "greedy": {},
        "top_k": {"top_k": 50},
        "beam": {},
    }
    if type not in preset_kwargs:
        raise ValueError(f"Unknown decoding type {type!r}; expected one of {sorted(preset_kwargs)}")
    text = text.strip().replace("\n", " ")
    output = summarization_model(text, min_length=min_len, max_length=max_len,
                                 clean_up_tokenization_spaces=True, truncation=True,
                                 **preset_kwargs[type], **kwargs)
    return output[0]['summary_text'].replace("<n>", " ")
# Load all different models (cached) at start time of the Hugging Face Space.
# Streamlit re-executes this script on every widget interaction, so each loader is routed through
# Streamlit's resource cache; otherwise every click would reload all four models.
def _cached_resource(loader):
    """Return *loader* wrapped in the resource-caching decorator offered by the installed Streamlit.

    Newer releases provide ``st.cache_resource`` and older ones ``st.experimental_singleton``.
    Otherwise this falls back to ``st.cache(allow_output_mutation=True)``, so the mutable model
    objects are not re-hashed on access.
    """
    for decorator_name in ("cache_resource", "experimental_singleton"):
        decorator = getattr(st, decorator_name, None)
        if decorator is not None:
            return decorator(loader)
    return st.cache(allow_output_mutation=True)(loader)


sentence_embedding_model = _cached_resource(get_sentence_embedding_model)()
ner_model = _cached_resource(get_transformer_pipeline)()
nlp = _cached_resource(get_spacy)()
summarization_model = _cached_resource(get_summarizer_model)()
# Page
st.title('📜 Hallucination detection 📜')
st.subheader("🔎 Detecting errors in generated abstractive summaries")
# NOTE: an alternative title ('📜 Error detection in summaries 📜') was tried here and is disabled.
# INTRODUCTION
st.header("🧑🏫 Introduction")
# NOTE: the introduction used to be toggled by a "Show introduction text" checkbox; it is now always shown.
st.markdown("""
Recent work using 🤖 **transformers** 🤖 on large text corpora has shown great success when fine-tuned on
several different downstream NLP tasks. One such task is that of text summarization. The goal of text summarization
is to generate concise and accurate summaries from input document(s). There are 2 types of summarization:
- **Extractive summarization** merely copies informative fragments from the input.
- **Abstractive summarization**
may generate novel words. A good abstractive summary should cover principal information in the input and has to be
linguistically fluent. This interactive blogpost will focus on this more difficult task of abstractive summary
generation. Furthermore we will focus mainly on hallucination errors, and less on sentence fluency.""")
# "###" renders an empty heading, used here as vertical spacing.
st.markdown("###")
st.markdown("🤔 **Why is this important?** 🤔 Let's say we want to summarize news articles for a popular "
            "newspaper. If an article tells the story of Elon Musk buying **Twitter**, we don't want our summarization "
            "model to say that he bought **Facebook** instead. Summarization could also be done for financial reports "
            "for example. In such environments, these errors can be very critical, so we want to find a way to "
            "detect them.")
st.markdown("###")
st.markdown("""To generate summaries we will use the 🐎 [PEGASUS](https://huggingface.co/google/pegasus-cnn_dailymail) 🐎
model, producing abstractive summaries from large articles. These summaries often contain sentences with different
kinds of errors. Rather than improving the core model, we will look into possible post-processing steps to detect errors
from the generated summaries. Throughout this blog, we will also explain the results for some methods on specific
examples. These text blocks will be indicated and they change according to the currently selected article.""")
# GENERATING SUMMARIES PART
st.header("🪶 Generating summaries")
st.markdown("Let’s start by selecting an article text for which we want to generate a summary, or you can provide "
            "text yourself. Note that it’s suggested to provide a sufficiently large article, as otherwise the "
            "summary generated from it might not be optimal, leading to suboptimal performance of the post-processing "
            "steps. However, too long articles will be truncated and might miss information in the summary.")
st.markdown("####")
# Choices: the example articles found on disk plus the free-input pseudo-entry; the third entry is preselected.
article_choices = list_all_article_names()
selected_article = st.selectbox('Select an article or provide your own:', article_choices, index=2)
# Pre-fill the editable text area with the chosen article (a single space for free input).
st.session_state.article_text = fetch_article_contents(selected_article)
article_text = st.text_area(label='Full article text', value=st.session_state.article_text, height=250)
summarize_button = st.button(label='🤯 Process article content', help="Start interactive blogpost")
# Everything below is only rendered after the button was pressed; end this script run otherwise.
if not summarize_button:
    st.stop()

st.session_state.article_text = article_text
st.markdown("####")
st.markdown(
    "*Below you can find the generated summary for the article. We will discuss two approaches that we found are "
    "able to detect some common errors. Based on these errors, one could then score different summaries, indicating how "
    "factual a summary is for a given article. The idea is that in production, you could generate a set of "
    "summaries for the same article, with different parameters (or even different models). By using "
    "post-processing error detection, we can then select the best possible summary.*")
st.markdown("####")
if not st.session_state.article_text.strip():
    # Whitespace-only input (the default for "Provide your own input") cannot be summarized. The
    # following sections need a summary in session state, so stop here instead of failing later.
    st.error('**Error**: No article text to summarize. Please provide an article.')
    st.stop()
with st.spinner('Generating summary, this might take a while...'):
    if selected_article != "Provide your own input" and article_text == fetch_article_contents(selected_article):
        # Unmodified example article: use its pre-generated summary (and, later, its stored explanations).
        st.session_state.unchanged_text = True
        summary_content = fetch_summary_contents(selected_article)
    else:
        summary_content = generate_abstractive_summary(article_text, type="beam", do_sample=True, num_beams=15,
                                                       no_repeat_ngram_size=4)
        st.session_state.unchanged_text = False
    summary_displayed = display_summary(summary_content)
    st.write("✍ **Generated summary:** ✍", summary_displayed, unsafe_allow_html=True)
# ENTITY MATCHING PART
st.header("1️⃣ Entity matching")
st.markdown("The first method we will discuss is called **Named Entity Recognition** (NER). NER is the task of "
            "identifying and categorising key information (entities) in text. An entity can be a singular word or a "
            "series of words that consistently refers to the same thing. Common entity classes are person names, "
            "organisations, locations and so on. By applying NER to both the article and its summary, we can spot "
            "possible **hallucinations**. ")
st.markdown("Hallucinations are words generated by the model that are not supported by "
            "the source input. Deep learning based generation is [prone to hallucinate]("
            "https://arxiv.org/pdf/2202.03629.pdf) unintended text. These hallucinations degrade "
            "system performance and fail to meet user expectations in many real-world scenarios. By applying entity matching, we can improve this problem"
            " for the downstream task of summary generation.")
st.markdown(" In theory all entities in the summary (such as dates, locations and so on), "
            "should also be present in the article. Thus we can extract all entities from the summary and compare "
            "them to the entities of the original article, spotting potential hallucinations. The more unmatched "
            "entities we find, the lower the factualness score of the summary. ")
# highlight_entities() also extracts the article's entities and caches them in
# st.session_state.entities_article, which the ranking section reuses later.
with st.spinner("Calculating and matching entities, this takes about 10-20 seconds..."):
    entity_match_html = highlight_entities()
    st.markdown("####")
    st.write(entity_match_html, unsafe_allow_html=True)
# Inline colour samples ("red" / "green") embedded in the explanation below.
red_text = """<font color="black"><span style="background-color: rgb(238, 135, 135); opacity:
1;">red</span></font> """
green_text = """<font color="black">
<span style="background-color: rgb(121, 236, 121); opacity: 1;">green</span>
</font>"""
# NOTE(review): these two are not referenced in the visible code (the highlight functions use their own copies).
markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
st.markdown(
    "We call this technique **entity matching** and here you can see what this looks like when we apply this "
    "method on the summary. Entities in the summary are marked " + green_text + " when the entity also "
    "exists in the article, "
    "while unmatched entities "
    "are marked " + red_text +
    ". Several of the example articles and their summaries indicate different errors we find by using this "
    "technique. Based on the current article, we provide a short explanation of the results below **(only for "
    "example articles)**. ", unsafe_allow_html=True)
# Unmodified example articles come with a stored explanation (./entity-specific-text) of their results.
if st.session_state.unchanged_text:
    entity_specific_text = fetch_entity_specific_contents(selected_article)
    soup = BeautifulSoup(entity_specific_text, features="html.parser")
    st.markdown("####")
    st.write("💡👇 **Specific example explanation** 👇💡", HTML_WRAPPER.format(soup), unsafe_allow_html=True)
# DEPENDENCY PARSING PART
st.header("2️⃣ Dependency comparison")
st.markdown(
    "The second method we use for post-processing is called **Dependency Parsing**: the process in which the "
    "grammatical structure in a sentence is analysed, to find out related words as well as the type of the "
    "relationship between them. For the sentence “Jan’s wife is called Sarah” you would get the following "
    "dependency graph:")
# TODO: st.image("ExampleParsing.svg") did not show the dependency arcs, while embedding the SVG as a
# base64 <img> via render_svg does; find out why.
st.write(render_svg('ExampleParsing.svg'), unsafe_allow_html=True)
st.markdown(
    "Here, *“Jan”* is the *“poss”* (possession modifier) of *“wife”*. If suddenly the summary would read *“Jan’s"
    " husband…”*, there would be a dependency in the summary that is non-existent in the article itself (namely "
    "*“Jan”* is the “poss” of *“husband”*)."
    "However, often new dependencies are introduced in the summary that "
    "are still correct, as can be seen in the example below. ")
st.write(render_svg('SecondExampleParsing.svg'), unsafe_allow_html=True)
st.markdown("*“The borders of Ukraine”* have a different dependency between *“borders”* and "
            "*“Ukraine”* "
            "than *“Ukraine’s borders”*, while both descriptions have the same meaning. So just matching all "
            "dependencies between article and summary (as we did with entity matching) would not be a robust method."
            " More on the different sorts of dependencies and their description can be found [here](https://universaldependencies.org/docs/en/dep/).")
st.markdown("However, we have found that **there are specific dependencies that are often an "
            "indication of a wrongly constructed sentence** when there is no article match. We (currently) use 2 "
            "common dependencies which - when present in the summary but not in the article - are highly "
            "indicative of factualness errors. "
            "Furthermore, we only check dependencies between an existing **entity** and its direct connections. "
            "Below we highlight all unmatched dependencies that satisfy the discussed constraints. We also "
            "discuss the specific results for the currently selected example article.")
with st.spinner("Doing dependency parsing..."):
    if st.session_state.unchanged_text:
        # Example article: show the stored dependency images and explanation.
        for cur_svg_image in fetch_dependency_svg(selected_article):
            st.write(cur_svg_image, unsafe_allow_html=True)
        dep_specific_text = fetch_dependency_specific_contents(selected_article)
        soup = BeautifulSoup(dep_specific_text, features="html.parser")
        st.write("💡👇 **Specific example explanation** 👇💡", HTML_WRAPPER.format(soup), unsafe_allow_html=True)
    else:
        # Custom text: a summary dependency counts as unmatched when its identifier is not a
        # substring of any article dependency identifier; each unmatched one is drawn.
        summary_deps = check_dependency(False)
        article_deps = check_dependency(True)
        total_unmatched_deps = []
        for summ_dep in summary_deps:
            if not any(summ_dep['identifier'] in art_dep['identifier'] for art_dep in article_deps):
                total_unmatched_deps.append(summ_dep)
        if total_unmatched_deps:
            for current_drawing_list in total_unmatched_deps:
                render_dependency_parsing(current_drawing_list)
# CURRENTLY DISABLED
# NOTE(review): it is unclear from the visible code what the marker above refers to; confirm or remove it.
# OUTRO/CONCLUSION
st.header("🤝 Bringing it together")
st.markdown("We have presented 2 methods that try to detect errors in summaries via post-processing steps. Entity "
            "matching can be used to solve hallucinations, while dependency comparison can be used to filter out "
            "some bad sentences (and thus worse summaries). These methods highlight the possibilities of "
            "post-processing AI-made summaries, but are only a first introduction. As the methods were "
            "empirically tested they are definitely not sufficiently robust for general use-cases.")
st.markdown("####")
st.markdown(
    "*Below we generate 3 different kind of summaries, and based on the two discussed methods, their errors are "
    "detected to estimate a summary score. Based on this basic approach, "
    "the best summary (read: the one that a human would prefer or indicate as the best one) "
    "will hopefully be at the top. We currently "
    "only do this for the example articles (for which the different summmaries are already generated). The reason "
    "for this is that HuggingFace spaces are limited in their CPU memory. We also highlight the entities as done "
    "before, but note that the rankings are done on a combination of unmatched entities and "
    "dependencies (with the latter not shown here).*")
st.markdown("####")
# Ranking is only offered for unmodified example articles, whose candidate summaries are stored in
# ./ranked-summaries/<article>/Rank1.txt .. Rank3.txt.
if selected_article != "Provide your own input" and article_text == fetch_article_contents(selected_article):
    with st.spinner("Fetching summaries, ranking them and highlighting entities, this might take a minute or two..."):
        # The article text does not change while ranking, so its dependencies (a full parse plus NER
        # per sentence) are extracted once instead of once per candidate summary.
        article_identifiers = [art_dep['identifier'] for art_dep in check_dependency(True)]
        summaries_list = []
        deduction_points = []
        for summary_number in range(1, 4):
            st.session_state.summary_output = fetch_ranked_summaries(selected_article, summary_number)
            _, unmatched_entities = get_and_compare_entities(False)
            # A summary dependency is unmatched when its identifier is not a substring of any
            # article dependency identifier.
            unmatched_deps = [summ_dep for summ_dep in check_dependency(False)
                              if not any(summ_dep['identifier'] in art_identifier
                                         for art_identifier in article_identifiers)]
            summaries_list.append(st.session_state.summary_output)
            # One deduction point per unmatched entity and per unmatched dependency.
            deduction_points.append(len(unmatched_entities) + len(unmatched_deps))
        # RANKING AND SHOWING THE SUMMARIES
        # Fewest deduction points first; ties are ordered by summary text.
        ranked = sorted(zip(deduction_points, summaries_list))
        sorted_points = [points for points, _ in ranked]
        for points, summary in ranked:
            # Competition ranking ("1224"): tied summaries share the rank of the first of their group.
            cur_rank = sorted_points.index(points) + 1
            st.write(f'🏆 Rank {cur_rank} summary: 🏆', highlight_entities_new(summary), unsafe_allow_html=True)