# roberta-wiki / lib / .ipynb_checkpoints / utils-checkpoint.py
import requests, wikipedia, re
from rank_bm25 import BM25Okapi
import streamlit as st
import pandas as pd
import spacy
####################################
## Streamlit app helper functions ##
####################################
def get_examples():
"""
Function for loading example questions
and contexts from examples.csv
Parameters: None
-----------
Returns:
--------
ex_queries, ex_questions, ex_contexts : list(str), list(list(str)), list(str)
Example search query, question, and context strings
(each entry of ex_questions is a list of three question strings)
"""
examples = pd.read_csv('examples.csv')
ex_questions = [q.split(':') for q in list(examples['question'])]
ex_contexts = list(examples['context'])
ex_queries = list(examples['query'])
return ex_queries, ex_questions, ex_contexts
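# A minimal sketch of the examples.csv layout get_examples assumes, based on the
# column names used above; the sample row content below is hypothetical:
#
#   query,question,context
#   "Eiffel Tower","How tall is it?:When was it built?:Who designed it?","The Eiffel Tower is ..."
#
# Example usage (hypothetical values):
#   ex_queries, ex_questions, ex_contexts = get_examples()
#   ex_questions[0]  # -> ['How tall is it?', 'When was it built?', 'Who designed it?']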
def basic_clear_boxes():
"""
    Clears the question, context, and response fields
"""
for field in ['question','context','response']:
st.session_state['basic'][field] = ''
def basic_ex_click(examples, i):
"""
Fills in the chosen example
"""
st.session_state['basic']['question'] = examples[1][i][0]
st.session_state['basic']['context'] = examples[2][i]
def semi_clear_query():
"""
    Clears the search query field and the
    page options and selected pages lists
"""
st.session_state['semi']['query'] = ''
for field in ['selected_pages','page_options']:
st.session_state['semi'][field] = []
def semi_clear_question():
"""
    Clears the question and response fields
"""
for field in ['question','response']:
st.session_state['semi'][field] = ''
def semi_ex_query_click(examples,i):
"""
Fills in the query example and
populates question examples when query
example button is clicked
"""
st.session_state['semi']['query'] = examples[0][i]
st.session_state['semi']['ex_questions'] = examples[1][i]
def semi_ex_question_click(i):
"""
Fills in the question example
"""
st.session_state['semi']['question'] = st.session_state['semi']['ex_questions'][i]
def auto_clear_boxes():
"""
Clears the response and question fields
"""
for field in ['question','response']:
st.session_state['auto'][field]=''
def auto_ex_click(examples,i):
"""
Fills in the chosen example question
"""
st.session_state['auto']['question'] = examples[1][i][0]
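# The helper functions above assume the app has already initialized nested dicts
# in st.session_state (presumably done elsewhere in the app's page scripts); a
# minimal sketch of that assumed layout, using only the keys referenced above:
#
#   if 'basic' not in st.session_state:
#       st.session_state['basic'] = {'question': '', 'context': '', 'response': ''}
#   if 'semi' not in st.session_state:
#       st.session_state['semi'] = {
#           'query': '', 'question': '', 'response': '',
#           'page_options': [], 'selected_pages': [], 'ex_questions': [],
#       }
#   if 'auto' not in st.session_state:
#       st.session_state['auto'] = {'question': '', 'response': ''}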
###########################
## Query helper function ##
###########################
def generate_query(nlp,text):
"""
Process text into a search query,
only retaining nouns, proper nouns,
numerals, verbs, and adjectives
Parameters:
-----------
    nlp : spacy.language.Language
spaCy pipeline for processing search query
text : str
The input text to be processed
Returns:
--------
query : str
The condensed search query
"""
tokens = nlp(text)
keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
    query = ' '.join(token.text for token in tokens if token.pos_ in keep)
return query
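# Example usage (a sketch; assumes the 'en_core_web_sm' spaCy model is installed,
# which may differ from the model the app actually loads):
#
#   nlp = spacy.load('en_core_web_sm')
#   generate_query(nlp, 'Who was the first president of the United States?')
#   # -> something like 'was first president United States'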
##############################
## Document retriever class ##
##############################
class ContextRetriever:
"""
Retrieves documents from Wikipedia based on a query,
    and prepares context paragraphs for a RoBERTa model
"""
def __init__(self,url='https://en.wikipedia.org/w/api.php'):
self.url = url
self.pageids = None
self.pages = None
self.paragraphs = None
def get_pageids(self,query,topn = None):
"""
Retrieve page ids corresponding to a search query
Parameters:
-----------
query : str
A query to use for Wikipedia page search
        topn : int or None
            If topn is provided, only the pageids and titles of the
            top topn search results are stored
Returns: None, but stores:
--------
self.pageids : list(tuple(int,str))
A list of Wikipedia (pageid,title) tuples resulting
from the search
"""
params = {
'action':'query',
'list':'search',
'srsearch':query,
'format':'json',
}
results = requests.get(self.url, params=params).json()
pageids = [(page['pageid'],page['title']) for page in results['query']['search']]
pageids = pageids[:topn]
self.pageids = pageids
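    # Example (hypothetical output; actual pageids and titles depend on
    # Wikipedia's current search index):
    #
    #   retriever = ContextRetriever()
    #   retriever.get_pageids('George Washington first president', topn=3)
    #   retriever.pageids  # -> [(<pageid>, 'George Washington'), ...]
    #
    # Note: the MediaWiki search endpoint caps results per request (10 by
    # default); an 'srlimit' entry can be added to params if more are needed.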
def ids_to_pages(self,ids):
"""
Use MediaWiki API to retrieve page content corresponding to
a list of pageids
Parameters:
-----------
ids : list(tuple(int,str))
A list of Wikipedia (pageid,title) tuples
        Returns:
        --------
        pages : list(tuple(str,str))
            The k-th entry is a tuple consisting of the title and page content
            of the page corresponding to the k-th entry of ids
"""
pages = []
for pageid in ids:
try:
page = wikipedia.page(pageid=pageid[0],auto_suggest=False)
pages.append((page.title, page.content))
except wikipedia.DisambiguationError:
continue
return pages
def get_all_pages(self):
"""
Use MediaWiki API to retrieve page content corresponding to
the list of pageids in self.pageids
Parameters: None
-----------
Returns: None, but stores
--------
self.pages : list(tuple(str,str))
            The k-th entry is a tuple consisting of the title and page content
of the page corresponding to the k-th entry of self.pageids
"""
assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
self.pages = self.ids_to_pages(self.pageids)
def pages_to_paragraphs(self,pages):
"""
Process a list of pages into a list of paragraphs from those pages
Parameters:
-----------
        pages : list(tuple(str,str))
            A list of (title, page content) tuples, e.g. as returned by ids_to_pages
Returns:
--------
paragraphs : dict
            keys are page titles (as strings)
            paragraphs[title] is a list of paragraphs (as strings)
            extracted from that page
"""
        # Wikipedia page content typically ends with these headings. We only keep
        # content appearing before the first instance of any of these
pattern = '|'.join([
'== References ==',
'== Further reading ==',
'== External links',
'== See also ==',
'== Sources ==',
'== Notes ==',
'== Further references ==',
'== Footnotes ==',
'=== Notes ===',
'=== Sources ===',
'=== Citations ===',
])
pattern = re.compile(pattern)
paragraphs = {}
for page in pages:
# Truncate page to the first index of the start of a matching heading,
# or the end of the page if no matches exist
title, content = page
idx = min([match.start() for match in pattern.finditer(content)]+[len(content)])
content = content[:idx]
            # Split into paragraphs, omitting heading lines (which start with '='),
            # empty lines, and lines starting with whitespace (e.g. '\t\t'), which sometimes appear
paragraphs[title] = [
p for p in content.split('\n') if p \
and not p.startswith('=') \
and not p.startswith('\t\t') \
and not p.startswith(' ')
]
return paragraphs
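    # A small illustrative sketch (the page text here is made up):
    #
    #   pages = [('Example', 'Intro para.\n\n== History ==\nHistory para.\n\n== References ==\nSmith 2020.')]
    #   ContextRetriever().pages_to_paragraphs(pages)
    #   # -> {'Example': ['Intro para.', 'History para.']}
    #
    # Everything from the first matching heading onward ('== References ==' here)
    # is dropped, and the heading lines themselves are filtered out.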
def get_all_paragraphs(self):
"""
Process self.pages into list of paragraphs from pages
Parameters: None
-----------
Returns: None, but stores
--------
self.paragraphs : dict
            keys are titles of pages in self.pages (as strings)
            self.paragraphs[title] is a list of paragraphs (as strings)
            extracted from that page
"""
        assert self.pages is not None, "No page content exists. Get pages first using self.get_all_pages"
self.paragraphs = self.pages_to_paragraphs(self.pages)
def rank_paragraphs(self,paragraphs,query,topn=10):
"""
Ranks the elements of paragraphs in descending order
by relevance to query using BM25 Okapi, and returns top
topn results
Parameters:
-----------
paragraphs : dict
keys are titles of pages (as strings)
paragraphs[page] is a list of paragraphs (as strings)
extracted from page
query : str
The query to use in ranking paragraphs by relevance
topn : int or None
The number of most relevant paragraphs to return
If None, will return roughly the top 1/4 of the
paragraphs
Returns:
--------
best_paragraphs : list(list(str,str))
The k-th entry is a list [title,paragraph] for the k-th
most relevant paragraph, where title is the title of the
Wikipedia article from which that paragraph was sourced
"""
corpus, titles, page_nums = [],[],[]
# Compile paragraphs into corpus
for i,page in enumerate(paragraphs):
titles.append(page)
paras = paragraphs[page]
corpus += paras
page_nums += len(paras)*[i]
# Tokenize corpus and query and initialize bm25 object
tokenized_corpus = [p.split(" ") for p in corpus]
bm25 = BM25Okapi(tokenized_corpus)
tokenized_query = query.split(" ")
# Compute scores and compile tuples (paragraph number, score, page number)
# before sorting tuples by score
bm_scores = bm25.get_scores(tokenized_query)
paragraph_data = [[i,score,page_nums[i]] for i,score in enumerate(bm_scores)]
paragraph_data.sort(reverse=True,key=lambda p:p[1])
# Grab topn best [title,paragraph] pairs sorted by bm25 score
topn = len(paragraph_data)//4+1 if topn is None else min(topn,len(paragraph_data))
best_paragraphs = [[titles[p[2]],corpus[p[0]]] for p in paragraph_data[:topn]]
return best_paragraphs
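    # A hedged end-to-end sketch of the retriever (results depend on live
    # Wikipedia content, so the output is only indicative):
    #
    #   retriever = ContextRetriever()
    #   retriever.get_pageids('Eiffel Tower height', topn=3)
    #   retriever.get_all_pages()
    #   retriever.get_all_paragraphs()
    #   best = retriever.rank_paragraphs(retriever.paragraphs, 'Eiffel Tower height', topn=5)
    #   # best[0] -> ['Eiffel Tower', '<paragraph judged most relevant to the query>']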
def generate_answer(pipeline, paragraphs, question):
"""
Generate an answer using a question-answer pipeline
Parameters:
-----------
pipeline : transformers.QuestionAnsweringPipeline
The question answering pipeline object
paragraphs : list(list(str,str))
The k-th entry is a list [title,paragraph] consisting
of a context paragraph and the title of the page from which the
paragraph was sourced
question : str
A question that is to be answered based on context given
in the entries of paragraphs
Returns:
--------
response : str
A response indicating the answer that was discovered,
or indicating that no answer could be found.
"""
    # For each paragraph, format the input to the QA pipeline...
    for paragraph in paragraphs:
        qa_input = {
            'context': paragraph[1],
            'question': question,
        }
        # ...and pass it to the QA pipeline
        output = pipeline(**qa_input)
# Append answers and scores. Report score of
# zero for paragraphs without answer, so they are
# deprioritized when the max is taken below
if output['answer']!='':
paragraph += [output['answer'],output['score']]
else:
paragraph += ['',0]
# Get paragraph with max confidence score and collect data
best_paragraph = max(paragraphs,key = lambda x:x[3])
best_answer = best_paragraph[2]
best_context_page = best_paragraph[0]
best_context = best_paragraph[1]
    # Compose the response string to return
if best_answer == "":
response = "I cannot find the answer to your question."
else:
response = f"""
My answer is: {best_answer}
...and here's where I found it:
Page title: {best_context_page}
Paragraph containing answer:
{best_context}
"""
return response
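# A hedged end-to-end demo tying the pieces together. It assumes the
# 'en_core_web_sm' spaCy model is installed and uses 'deepset/roberta-base-squad2'
# purely as a stand-in checkpoint for whatever RoBERTa model the app actually loads.
if __name__ == '__main__':
    from transformers import pipeline as hf_pipeline

    question = 'How tall is the Eiffel Tower?'
    nlp = spacy.load('en_core_web_sm')
    qa_pipeline = hf_pipeline('question-answering', model='deepset/roberta-base-squad2')

    # Condense the question into a search query, retrieve and rank paragraphs,
    # then extract an answer with the QA pipeline
    query = generate_query(nlp, question)
    retriever = ContextRetriever()
    retriever.get_pageids(query, topn=3)
    retriever.get_all_pages()
    retriever.get_all_paragraphs()
    paragraphs = retriever.rank_paragraphs(retriever.paragraphs, query, topn=5)
    print(generate_answer(qa_pipeline, paragraphs, question))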