import requests, wikipedia, re
from rank_bm25 import BM25Okapi
import streamlit as st
import pandas as pd
import spacy

####################################
## Streamlit app helper functions ##
####################################

def get_examples():
    """
    Function for loading example questions
    and contexts from examples.csv
    Parameters: None
    -----------
    Returns:
    --------
    ex_queries, ex_questions, ex_contexts : list(str), list(list(str)), list(str)
        Example search query, question, and context strings
        (each entry of ex_questions is a list of three question strings)
    """
    examples = pd.read_csv('examples.csv')
    ex_questions = [q.split(':') for q in list(examples['question'])]
    ex_contexts = list(examples['context'])
    ex_queries = list(examples['query'])
    return ex_queries, ex_questions, ex_contexts
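# A minimal sketch of the examples.csv layout that get_examples expects
# (column names come from the code above; the row values are purely
# illustrative assumptions):
#
#   query,question,context
#   "Eiffel Tower","How tall is it?:Who designed it?:When was it built?","The Eiffel Tower is ..."
#
# get_examples() returns the 'query' and 'context' columns as lists, and the
# 'question' column split on ':' into per-row lists of three questions.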
def basic_clear_boxes():
    """
    Clears the question, context, response
    """
    for field in ['question','context','response']:
        st.session_state['basic'][field] = ''

def basic_ex_click(examples, i):
    """
    Fills in the chosen example
    """
    st.session_state['basic']['question'] = examples[1][i][0]
    st.session_state['basic']['context'] = examples[2][i]

def semi_clear_query():
    """
    Clears the search query field
    and page options list
    """
    st.session_state['semi']['query'] = ''
    for field in ['selected_pages','page_options']:
        st.session_state['semi'][field] = []
def semi_clear_question():
    """
    Clears the question and response fields
    """
    for field in ['question','response']:
        st.session_state['semi'][field] = ''
def semi_ex_query_click(examples, i):
    """
    Fills in the query example and
    populates question examples when query
    example button is clicked
    """
    st.session_state['semi']['query'] = examples[0][i]
    st.session_state['semi']['ex_questions'] = examples[1][i]

def semi_ex_question_click(i):
    """
    Fills in the question example
    """
    st.session_state['semi']['question'] = st.session_state['semi']['ex_questions'][i]

def auto_clear_boxes():
    """
    Clears the response and question fields
    """
    for field in ['question','response']:
        st.session_state['auto'][field] = ''

def auto_ex_click(examples, i):
    """
    Fills in the chosen example question
    """
    st.session_state['auto']['question'] = examples[1][i][0]
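# A hedged sketch of how these callbacks might be wired to Streamlit widgets
# in the app script (button labels and layout are assumptions, not the app's
# actual UI):
#
#   examples = get_examples()  # (ex_queries, ex_questions, ex_contexts)
#   st.button('Clear', on_click=basic_clear_boxes)
#   st.button('Example 1', on_click=basic_ex_click, args=(examples, 0))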
###########################
## Query helper function ##
###########################

def generate_query(nlp, text):
    """
    Process text into a search query,
    only retaining nouns, proper nouns,
    numerals, verbs, and adjectives
    Parameters:
    -----------
    nlp : spacy.Pipe
        spaCy pipeline for processing search query
    text : str
        The input text to be processed
    Returns:
    --------
    query : str
        The condensed search query
    """
    tokens = nlp(text)
    keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
    query = ' '.join(token.text for token in tokens
                     if token.pos_ in keep)
    return query
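# A minimal usage sketch for generate_query (the spaCy model name
# 'en_core_web_sm' and the sample sentence are assumptions):
#
#   nlp = spacy.load('en_core_web_sm')
#   generate_query(nlp, "Who designed the Eiffel Tower in Paris?")
#   # -> e.g. "designed Eiffel Tower Paris"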
##############################
## Document retriever class ##
##############################

class ContextRetriever:
    """
    Retrieves documents from Wikipedia based on a query,
    and prepares context paragraphs for a RoBERTa model
    """
    def __init__(self, url='https://en.wikipedia.org/w/api.php'):
        self.url = url
        self.pageids = None
        self.pages = None
        self.paragraphs = None
    def get_pageids(self, query, topn=None):
        """
        Retrieve page ids corresponding to a search query
        Parameters:
        -----------
        query : str
            A query to use for Wikipedia page search
        topn : int or None
            If topn is provided, will only return pageids
            for topn search results
        Returns: None, but stores:
        --------
        self.pageids : list(tuple(int,str))
            A list of Wikipedia (pageid,title) tuples resulting
            from the search
        """
        params = {
            'action':'query',
            'list':'search',
            'srsearch':query,
            'format':'json',
        }
        results = requests.get(self.url, params=params).json()
        pageids = [(page['pageid'],page['title']) for page in results['query']['search']]
        pageids = pageids[:topn]
        self.pageids = pageids
    def ids_to_pages(self, ids):
        """
        Use MediaWiki API to retrieve page content corresponding to
        a list of pageids
        Parameters:
        -----------
        ids : list(tuple(int,str))
            A list of Wikipedia (pageid,title) tuples
        Returns:
        --------
        pages : list(tuple(str,str))
            The k-th entry is a tuple consisting of the title and page content
            of the page corresponding to the k-th entry of ids
        """
        pages = []
        for pageid in ids:
            try:
                page = wikipedia.page(pageid=pageid[0], auto_suggest=False)
                pages.append((page.title, page.content))
            except wikipedia.DisambiguationError:
                continue
        return pages
    def get_all_pages(self):
        """
        Use MediaWiki API to retrieve page content corresponding to
        the list of pageids in self.pageids
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.pages : list(tuple(str,str))
            The k-th entry is a tuple consisting of the title and page content
            of the page corresponding to the k-th entry of self.pageids
        """
        assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
        self.pages = self.ids_to_pages(self.pageids)
    def pages_to_paragraphs(self, pages):
        """
        Process a list of pages into a list of paragraphs from those pages
        Parameters:
        -----------
        pages : list(tuple(str,str))
            A list of (title, page content) tuples, as returned by
            self.ids_to_pages
        Returns:
        --------
        paragraphs : dict
            keys are titles of pages from pages (as strings)
            paragraphs[page] is a list of paragraphs (as strings)
            extracted from page
        """
        # Content from WikiMedia has these headings. We only grab content appearing
        # before the first instance of any of these
        pattern = '|'.join([
            '== References ==',
            '== Further reading ==',
            '== External links',
            '== See also ==',
            '== Sources ==',
            '== Notes ==',
            '== Further references ==',
            '== Footnotes ==',
            '=== Notes ===',
            '=== Sources ===',
            '=== Citations ===',
        ])
        pattern = re.compile(pattern)
        paragraphs = {}
        for page in pages:
            # Truncate page to the first index of the start of a matching heading,
            # or the end of the page if no matches exist
            title, content = page
            idx = min([match.start() for match in pattern.finditer(content)] + [len(content)])
            content = content[:idx]
            # Split into paragraphs, omitting lines with headings (start with '='),
            # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
            paragraphs[title] = [
                p for p in content.split('\n') if p
                and not p.startswith('=')
                and not p.startswith('\t\t')
                and not p.startswith(' ')
            ]
        return paragraphs
    def get_all_paragraphs(self):
        """
        Process self.pages into list of paragraphs from pages
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.paragraphs : dict
            keys are titles of pages from self.pages (as strings)
            self.paragraphs[page] is a list of paragraphs (as strings)
            extracted from page
        """
        assert self.pages is not None, "No page content exists. Get pages first using self.get_all_pages"
        self.paragraphs = self.pages_to_paragraphs(self.pages)
    def rank_paragraphs(self, paragraphs, query, topn=10):
        """
        Ranks the elements of paragraphs in descending order
        by relevance to query using BM25 Okapi, and returns top
        topn results
        Parameters:
        -----------
        paragraphs : dict
            keys are titles of pages (as strings)
            paragraphs[page] is a list of paragraphs (as strings)
            extracted from page
        query : str
            The query to use in ranking paragraphs by relevance
        topn : int or None
            The number of most relevant paragraphs to return
            If None, will return roughly the top 1/4 of the
            paragraphs
        Returns:
        --------
        best_paragraphs : list(list(str,str))
            The k-th entry is a list [title,paragraph] for the k-th
            most relevant paragraph, where title is the title of the
            Wikipedia article from which that paragraph was sourced
        """
        corpus, titles, page_nums = [], [], []
        # Compile paragraphs into corpus
        for i, page in enumerate(paragraphs):
            titles.append(page)
            paras = paragraphs[page]
            corpus += paras
            page_nums += len(paras) * [i]
        # Tokenize corpus and query and initialize bm25 object
        tokenized_corpus = [p.split(" ") for p in corpus]
        bm25 = BM25Okapi(tokenized_corpus)
        tokenized_query = query.split(" ")
        # Compute scores and compile tuples (paragraph number, score, page number)
        # before sorting tuples by score
        bm_scores = bm25.get_scores(tokenized_query)
        paragraph_data = [[i, score, page_nums[i]] for i, score in enumerate(bm_scores)]
        paragraph_data.sort(reverse=True, key=lambda p: p[1])
        # Grab topn best [title,paragraph] pairs sorted by bm25 score
        topn = len(paragraph_data)//4 + 1 if topn is None else min(topn, len(paragraph_data))
        best_paragraphs = [[titles[p[2]], corpus[p[0]]] for p in paragraph_data[:topn]]
        return best_paragraphs
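# A minimal end-to-end sketch of the retrieval flow above (the query string
# and topn values are illustrative assumptions):
#
#   retriever = ContextRetriever()
#   retriever.get_pageids("Eiffel Tower", topn=3)
#   retriever.get_all_pages()
#   retriever.get_all_paragraphs()
#   best = retriever.rank_paragraphs(retriever.paragraphs, "Eiffel Tower", topn=10)
#   # best is a list of [title, paragraph] pairs, most relevant first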
def generate_answer(pipeline, paragraphs, question):
    """
    Generate an answer using a question-answer pipeline
    Parameters:
    -----------
    pipeline : transformers.QuestionAnsweringPipeline
        The question answering pipeline object
    paragraphs : list(list(str,str))
        The k-th entry is a list [title,paragraph] consisting
        of a context paragraph and the title of the page from which the
        paragraph was sourced
    question : str
        A question that is to be answered based on context given
        in the entries of paragraphs
    Returns:
    --------
    response : str
        A response indicating the answer that was discovered,
        or indicating that no answer could be found.
    """
    # For each paragraph, format input to QA pipeline...
    for paragraph in paragraphs:
        qa_input = {
            'context': paragraph[1],
            'question': question,
        }
        # ...and pass to QA pipeline
        output = pipeline(**qa_input)
        # Append answer and score. Report a score of
        # zero for paragraphs without an answer, so they are
        # deprioritized when the max is taken below
        if output['answer'] != '':
            paragraph += [output['answer'], output['score']]
        else:
            paragraph += ['', 0]
    # Get paragraph with max confidence score and collect data
    best_paragraph = max(paragraphs, key=lambda x: x[3])
    best_answer = best_paragraph[2]
    best_context_page = best_paragraph[0]
    best_context = best_paragraph[1]
    # Build the response string
    if best_answer == "":
        response = "I cannot find the answer to your question."
    else:
        response = f"""
        My answer is: {best_answer}
        ...and here's where I found it:
        Page title: {best_context_page}
        Paragraph containing answer:
        {best_context}
        """
    return response
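# A minimal sketch of how generate_answer might be wired up (the model name
# 'deepset/roberta-base-squad2', the query, and the question are assumptions,
# not the app's actual configuration):
#
#   from transformers import pipeline as hf_pipeline
#   qa_pipeline = hf_pipeline('question-answering', model='deepset/roberta-base-squad2')
#   best = retriever.rank_paragraphs(retriever.paragraphs, "Eiffel Tower")
#   print(generate_answer(qa_pipeline, best, "Who designed the Eiffel Tower?"))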