Upload 9 files

Files changed:
- app.py +362 -175
- examples.csv +4 -0
- lib/.DS_Store +0 -0
- lib/.ipynb_checkpoints/utils-checkpoint.py +287 -77
- lib/__pycache__/utils.cpython-310.pyc +0 -0
- lib/utils.py +286 -47
app.py
CHANGED
@@ -8,28 +8,17 @@ from transformers import (
     pipeline,
 )
 import spacy
-
-# * formatting input into document retrieval query (spaCy)
-# * document retrieval based on query (wikipedia library)
-# * document postprocessing into passages
-# * ranking passage based on BM25 scores for query (rank_bm25)
-# * feeding passages into RoBERTa an reporting answer(s) and passages as evidence
-# decide what to do with examples
-
-### CAN REMOVE:#####
-# * context collection
-# *
-
-###
-
-# Build trainer using model and tokenizer from Hugging Face repo
+from lib.utils import (
+    ContextRetriever,
+    get_examples,
+    generate_query,
+    generate_answer,
+)

+#################################
+### Model retrieval functions ###
+#################################

 @st.cache_resource(show_spinner=False)
 def get_pipeline():
     """
@@ -70,57 +59,6 @@ def get_spacy():
     )
     return nlp

-def generate_query(nlp,text):
-    """
-    Process text into a search query,
-    only retaining nouns, proper nouns,
-    numerals, verbs, and adjectives
-    Parameters:
-    -----------
-    nlp : spacy.Pipe
-        spaCy pipeline for processing search query
-    text : str
-        The input text to be processed
-    Returns:
-    --------
-    query : str
-        The condensed search query
-    """
-    tokens = nlp(text)
-    keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
-    query = ' '.join(token.text for token in tokens \
-                     if token.pos_ in keep)
-    return query
-
-def fill_in_example(i):
-    """
-    Function for context-question example button click
-    """
-    st.session_state['response'] = ''
-    st.session_state['question'] = ex_q[i]
-
-def clear_boxes():
-    """
-    Function for field clear button click
-    """
-    st.session_state['response'] = ''
-    st.session_state['question'] = ''
-
-# def get_examples():
-#     """
-#     Retrieve pre-made examples from a .csv file
-#     Parameters: None
-#     -----------
-#     Returns:
-#     --------
-#     questions, contexts : list, list
-#         Lists of examples of corresponding question-context pairs
-#     """
-#     examples = pd.read_csv('examples.csv')
-#     questions = list(examples['question'])
-#     return questions
-
 #############
 ### Setup ###
 #############
@@ -134,40 +72,71 @@ else:
     device = "cpu"

 # Initialize session state variables
-st.session_state
+for tab in ['basic','semi','auto']:
+    if tab not in st.session_state:
+        st.session_state[tab] = {}
+    for field in ['question','context','query','response']:
+        if field not in st.session_state[tab]:
+            st.session_state[tab][field] = ''
+for field in ['page_options','selected_pages']:
+    if field not in st.session_state['semi']:
+        st.session_state['semi'][field] = []
+
+# Retrieve models
 with st.spinner('Loading the model...'):
     qa_pipeline = get_pipeline()
     nlp = get_spacy()

-#
-st.
-3. RoBERTa will search the best candidate passages to find the answer to your question
+# Retrieve example questions and contexts
+examples = get_examples()
+# ex_queries, ex_questions, ex_contexts = get_examples()
+if 'ex_questions' not in st.session_state['semi']:
+    st.session_state['semi']['ex_questions'] = len(examples[1][0])*['']
+
+################################
+### Initialize App Structure ###
+################################
+
+tabs = st.tabs([
+    'RoBERTa Q&A model',
+    'Basic extractive Q&A',
+    'User-guided Wiki Q&A',
+    'Automated Wiki Q&A',
+])
+
+with tabs[0]:
+    intro_container = st.container()
+with tabs[1]:
+    basic_title_container = st.container()
+    basic_example_container = st.container()
+    basic_input_container = st.container()
+    basic_response_container = st.container()
+with tabs[2]:
+    semi_title_container = st.container()
+    semi_query_container = st.container()
+    semi_page_container = st.container()
+    semi_input_container = st.container()
+    semi_response_container = st.container()
+with tabs[3]:
+    auto_title_container = st.container()
+    auto_example_container = st.container()
+    auto_input_container = st.container()
+    auto_response_container = st.container()
+
+##############################
+### Populate tab - Welcome ###
+##############################
+
+with intro_container:
+    # Intro text
+    st.header('RoBERTa Q&A model')
     st.markdown('''
+    This app demonstrates the answer-retrieval capabilities of a fine-tuned RoBERTa (Robustly optimized Bidirectional Encoder Representations from Transformers) model.
+    ''')
+    with st.expander('Click to read more about the model...'):
+        st.markdown('''
 * [Click here](https://huggingface.co/etweedy/roberta-base-squad-v2) to visit the Hugging Face model card for this fine-tuned model.
-* To create this model, the [RoBERTa base model](https://huggingface.co/roberta-base)
+        * To create this model, I fine-tuned the [RoBERTa base model](https://huggingface.co/roberta-base) on Version 2 of [SQuAD (Stanford Question Answering Dataset)](https://huggingface.co/datasets/squad_v2), a dataset of context-question-answer triples.
 * The objective of the model is "extractive question answering", the task of retrieving the answer to the question from a given context text corpus.
 * SQuAD Version 2 incorporates the 100,000 samples from Version 1.1, along with 50,000 'unanswerable' questions, i.e. samples in which the question cannot be answered using the context given.
 * The original base RoBERTa model was introduced in [this paper](https://arxiv.org/abs/1907.11692) and [this repository](https://github.com/facebookresearch/fairseq/tree/main/examples/roberta). Here's a citation for that base model:
@@ -195,102 +164,320 @@ with st.expander('Click to read more about the model...'):
     bibsource = {dblp computer science bibliography, https://dblp.org}
 }
 ```
-    ''')
-
-# example_container = st.container()
-input_container = st.container()
-button_container = st.container()
-response_container = st.container()
-
-###
-
-# with example_container:
-#     ex_cols = st.columns(len(ex_q)+1)
-#     for i in range(len(ex_q)):
-#         with ex_cols[i]:
-#             st.button(
-#                 label = f'Try example {i+1}',
-#                 key = f'ex_button_{i+1}',
-#                 on_click = fill_in_example,
-#                 args=(i,),
-#             )
-#     with ex_cols[-1]:
-#         st.button(
-#             label = "Clear all fields",
-#             key = "clear_button",
-#             on_click = clear_boxes,
-#         )
-
-    value=st.session_state['question'],
-    key='
-    label_visibility='
-if query_submitted:
-    st.session_state['question'] = question
-    query = generate_query(nlp,question)
-    retriever = ContextRetriever()
-    retriever.get_pageids(query)
-    retriever.get_pages()
-    retriever.get_paragraphs()
-    retriever.rank_paragraphs(query)
-    #
-    st.session_state['response'] = "I cannot find the answer to your question."
-    else:
-        st.session_state['response'] = f"""
-        My answer is: {best_answer}
-
-        ...and here's where I found it:
-
-        {best_context}
-    key = "
-    on_click =
-with
-    st.
+        ''')
+    st.markdown('''
+    Use the menu on the left side to navigate between different app components:
+    1. A basic Q&A tool which allows the user to ask the model to search a user-provided context paragraph for the answer to a user-provided question.
+    2. A user-guided Wiki Q&A tool which allows the user to search for one or more Wikipedia pages and ask the model to search those pages for the answer to a user-provided question.
+    3. An automated Wiki Q&A tool which asks the model to retrieve its own Wikipedia pages in order to answer the user-provided question.
+    ''')

+################################
+### Populate tab - basic Q&A ###
+################################

+from lib.utils import basic_clear_boxes, basic_ex_click

+with basic_title_container:
+    ### Intro text ###
+    st.header('Basic extractive Q&A')
+    st.markdown('''
+    The basic functionality of a RoBERTa model for extractive question-answering is to attempt to extract the answer to a user-provided question from a piece of user-provided context text. The model is also trained to recognize when the context doesn't provide the answer.
+
+    Please type or paste a context paragraph and question you'd like to ask about it. The model will attempt to answer the question based on the context you provided, or report that it cannot find the answer in the context. Your results will appear below the question field when the model is finished running.
+
+    Alternatively, you can try an example by clicking one of the buttons below:
+    ''')
+
+### Populate example button container ###
+with basic_example_container:
+    basic_ex_cols = st.columns(len(examples[0])+1)
+    for i in range(len(examples[0])):
+        with basic_ex_cols[i]:
+            st.button(
+                label = f'example {i+1}',
+                key = f'basic_ex_button_{i+1}',
+                on_click = basic_ex_click,
+                args = (examples,i,),
+            )
+    with basic_ex_cols[-1]:
+        st.button(
+            label = "Clear all fields",
+            key = "basic_clear_button",
+            on_click = basic_clear_boxes,
+        )
+### Populate user input container ###
+with basic_input_container:
+    with st.form(key='basic_input_form',clear_on_submit=False):
+        # Context input field
+        context = st.text_area(
+            label='Context',
+            value=st.session_state['basic']['context'],
+            key='basic_context_field',
+            label_visibility='collapsed',
+            placeholder='Enter your context paragraph here.',
+            height=300,
+        )
         # Question input field
         question = st.text_input(
             label='Question',
+            value=st.session_state['basic']['question'],
+            key='basic_question_field',
+            label_visibility='collapsed',
             placeholder='Enter your question here.',
         )
         # Form submit button
         query_submitted = st.form_submit_button("Submit")
+        if query_submitted and question != '':
             # update question, context in session state
+            st.session_state['basic']['question'] = question
+            st.session_state['basic']['context'] = context
             with st.spinner('Generating response...'):
+                # Generate dictionary from inputs
+                query = {
+                    'context':st.session_state['basic']['context'],
+                    'question':st.session_state['basic']['question'],
+                }
+                # Pass to QA pipeline
+                response = qa_pipeline(**query)
+                answer = response['answer']
+                confidence = response['score']
+                # Reformat empty answer to message
+                if answer == '':
+                    answer = "I don't have an answer based on the context provided."
+                # Update response in session state
+                st.session_state['basic']['response'] = f"""
+                Answer: {answer}\n
+                Confidence: {confidence:.2%}
                 """
+### Populate response container ###
+with basic_response_container:
+    st.write(st.session_state['basic']['response'])
+
+#################################
+### Populate tab - guided Q&A ###
+#################################

+from lib.utils import (
+    semi_ex_query_click,
+    semi_ex_question_click,
+    semi_clear_query,
+    semi_clear_question,
+)
+
+### Intro text ###
+with semi_title_container:
+    st.header('User-guided Wiki Q&A')
+    st.markdown('''
+    This component allows you to perform a Wikipedia search for source material to feed as contexts to the RoBERTa question-answering model.
+    ''')
+    with st.expander("Click here to find out what's happening behind the scenes..."):
+        st.markdown('''
+        1. A Wikipedia search is performed using your query, resulting in a list of pages which then populate the drop-down menu.
+        2. The pages you select are retrieved and broken up into paragraphs. Wikipedia queries and page collection use the [wikipedia library](https://pypi.org/project/wikipedia/), a wrapper for the [MediaWiki API](https://www.mediawiki.org/wiki/API).
+        3. The paragraphs are ranked in descending order of relevance to your question, using the [Okapi BM25 score](https://en.wikipedia.org/wiki/Okapi_BM25) as implemented in the [rank_bm25 library](https://github.com/dorianbrown/rank_bm25).
+        4. Among these ranked paragraphs, approximately the top 25% are fed as context to the RoBERTa model, from which it will attempt to extract the answer to your question. The 'hit' having the highest confidence (prediction probability) from the model is reported as the answer.
+        ''')
+
+### Populate query container ###
+with semi_query_container:
+    st.markdown('First submit a search query, or choose one of the examples.')
+    semi_query_cols = st.columns(len(examples[0])+1)
+    # Buttons for query examples
+    for i in range(len(examples[0])):
+        with semi_query_cols[i]:
+            st.button(
+                label = f'query {i+1}',
+                key = f'semi_query_button_{i+1}',
+                on_click = semi_ex_query_click,
+                args=(examples,i,),
+            )
+    # Button for clearing query field
+    with semi_query_cols[-1]:
+        st.button(
+            label = "Clear query",
+            key = "semi_clear_query",
+            on_click = semi_clear_query,
+        )
+    # Search query input form
+    with st.form(key='semi_query_form',clear_on_submit=False):
+        query = st.text_input(
+            label='Search query',
+            value=st.session_state['semi']['query'],
+            key='semi_query_field',
+            label_visibility='collapsed',
+            placeholder='Enter your Wikipedia search query here.',
+        )
+        query_submitted = st.form_submit_button("Submit")
+
+        if query_submitted and query != '':
+            st.session_state['semi']['query'] = query
+            # Retrieve Wikipedia page list from
+            # search results and store in session state
+            with st.spinner('Retrieving Wiki pages...'):
+                retriever = ContextRetriever()
+                retriever.get_pageids(query)
+                st.session_state['semi']['page_options'] = retriever.pageids
+                st.session_state['semi']['selected_pages'] = []
+
+### Populate page selection container ###
+with semi_page_container:
+    st.markdown('Next select any number of Wikipedia pages to provide to RoBERTa:')
+    # Page title selection form
+    with st.form(key='semi_page_form',clear_on_submit=False):
+        selected_pages = st.multiselect(
+            label = "Choose Wiki pages for Q&A model:",
+            options = st.session_state['semi']['page_options'],
+            default = st.session_state['semi']['selected_pages'],
+            label_visibility = 'collapsed',
+            key = "semi_page_selectbox",
+            format_func = lambda x:x[1],
+        )
+        pages_submitted = st.form_submit_button("Submit")
+        if pages_submitted:
+            st.session_state['semi']['selected_pages'] = selected_pages
+
+### Populate question input container ###
+with semi_input_container:
+    st.markdown('Finally submit a question for RoBERTa to answer based on the above pages or choose one of the examples.')
+    # Question example buttons
+    semi_ques_cols = st.columns(len(examples[0])+1)
+    for i in range(len(examples[0])):
+        with semi_ques_cols[i]:
+            st.button(
+                label = f'question {i+1}',
+                key = f'semi_ques_button_{i+1}',
+                on_click = semi_ex_question_click,
+                args=(i,),
+            )
+    # Question field clear button
+    with semi_ques_cols[-1]:
+        st.button(
+            label = "Clear question",
+            key = "semi_clear_question",
+            on_click = semi_clear_question,
+        )
+    # Question submission form
+    with st.form(key = "semi_question_form",clear_on_submit=False):
+        question = st.text_input(
+            label='Question',
+            value=st.session_state['semi']['question'],
+            key='semi_question_field',
+            label_visibility='collapsed',
+            placeholder='Enter your question here.',
+        )
+        question_submitted = st.form_submit_button("Submit")
+        if question_submitted and len(question)>0 and len(st.session_state['semi']['selected_pages'])>0:
+            st.session_state['semi']['response'] = ''
+            st.session_state['semi']['question'] = question
+            # Retrieve pages corresponding to user selections,
+            # extract paragraphs, and retrieve top 10 paragraphs,
+            # ranked by relevance to user question
+            with st.spinner("Retrieving documentation..."):
+                retriever = ContextRetriever()
+                pages = retriever.ids_to_pages(selected_pages)
+                paragraphs = retriever.pages_to_paragraphs(pages)
+                best_paragraphs = retriever.rank_paragraphs(
+                    paragraphs, question,
+                )
+            with st.spinner("Generating response..."):
+                # Generate a response and update the session state
+                response = generate_answer(
+                    pipeline = qa_pipeline,
+                    paragraphs = best_paragraphs,
+                    question = st.session_state['semi']['question'],
+                )
+                st.session_state['semi']['response'] = response
+
+### Populate response container ###
+with semi_response_container:
+    st.write(st.session_state['semi']['response'])
+
+####################################
+### Populate tab - automated Q&A ###
+####################################
+
+from lib.utils import auto_ex_click, auto_clear_boxes
+
+### Intro text ###
+with auto_title_container:
+    st.header('Automated Wiki Q&A')
+    st.markdown('''
+    This component attempts to automate the Wiki-assisted extractive question-answering task. A Wikipedia search will be performed based on your question, and a list of relevant paragraphs will be passed to the RoBERTa model so it can attempt to find an answer.
+    ''')
+    with st.expander("Click here to find out what's happening behind the scenes..."):
+        st.markdown('''
+        When you submit a question, the following steps are performed:
+        1. Your question is condensed into a search query which just retains nouns, verbs, numerals, and adjectives, where part-of-speech tagging is done using the [en_core_web_sm](https://spacy.io/models/en#en_core_web_sm) pipeline in the [spaCy library](https://spacy.io/).
+        2. A Wikipedia search is performed using this query, resulting in several articles. The articles from the top 3 search results are collected and split into paragraphs. Wikipedia queries and article collection use the [wikipedia library](https://pypi.org/project/wikipedia/), a wrapper for the [MediaWiki API](https://www.mediawiki.org/wiki/API).
+        3. The paragraphs are ranked in descending order of relevance to the query, using the [Okapi BM25 score](https://en.wikipedia.org/wiki/Okapi_BM25) as implemented in the [rank_bm25 library](https://github.com/dorianbrown/rank_bm25).
+        4. The ten most relevant paragraphs are fed as context to the RoBERTa model, from which it will attempt to extract the answer to your question. The 'hit' having the highest confidence (prediction probability) from the model is reported as the answer.
+        ''')
+
+    st.markdown('''
+    Please provide a question you'd like the model to try to answer. The model will report back its answer, as well as an excerpt of text from Wikipedia in which it found its answer. Your result will appear below the question field when the model is finished running.
+
+    Alternatively, you can try an example by clicking one of the buttons below:
+    ''')
+
+### Populate example container ###
+with auto_example_container:
+    auto_ex_cols = st.columns(len(examples[0])+1)
+    # Buttons for selecting example questions
+    for i in range(len(examples[0])):
+        with auto_ex_cols[i]:
+            st.button(
+                label = f'example {i+1}',
+                key = f'auto_ex_button_{i+1}',
+                on_click = auto_ex_click,
+                args=(examples,i,),
+            )
+    # Button for clearing question field and response
+    with auto_ex_cols[-1]:
+        st.button(
             label = "Clear all fields",
+            key = "auto_clear_button",
+            on_click = auto_clear_boxes,
         )

+### Populate user input container ###
+with auto_input_container:
+    with st.form(key='auto_input_form',clear_on_submit=False):
+        # Question input field
+        question = st.text_input(
+            label='Question',
+            value=st.session_state['auto']['question'],
+            key='auto_question_field',
+            label_visibility='collapsed',
+            placeholder='Enter your question here.',
+        )
+        # Form submit button
+        question_submitted = st.form_submit_button("Submit")
+        if question_submitted:
+            # update question, context in session state
+            st.session_state['auto']['question'] = question
+            query = generate_query(nlp,question)
+            # query == '' will throw error in document retrieval
+            if len(query)==0:
+                st.session_state['auto']['response'] = 'Please include some nouns, verbs, and/or adjectives in your question.'
+            elif len(question)>0:
+                with st.spinner('Retrieving documentation...'):
+                    # Retrieve ids from top 3 results
+                    retriever = ContextRetriever()
+                    retriever.get_pageids(query,topn=3)
+                    # Retrieve pages then paragraphs
+                    retriever.get_all_pages()
+                    retriever.get_all_paragraphs()
+                    # Get top 10 paragraphs, ranked by relevance to query
+                    best_paragraphs = retriever.rank_paragraphs(retriever.paragraphs, query)
+                with st.spinner('Generating response...'):
+                    # Generate a response and update the session state
+                    response = generate_answer(
+                        pipeline = qa_pipeline,
+                        paragraphs = best_paragraphs,
+                        question = st.session_state['auto']['question'],
+                    )
+                    st.session_state['auto']['response'] = response
+### Populate response container ###
+with auto_response_container:
+    st.write(st.session_state['auto']['response'])
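For reference, the basic tab's submit handler reduces to a standard extractive-QA pipeline call. The following is a minimal sketch, not the Space's exact code: the model name is taken from the model card linked above, the example context comes from examples.csv, and the handle_impossible_answer flag is an assumption about how the Space's get_pipeline() surfaces empty answers.

```python
# Minimal sketch of the call pattern behind the "Basic extractive Q&A" tab.
from transformers import pipeline

# Fine-tuned checkpoint referenced in the model card above
qa_pipeline = pipeline(
    "question-answering",
    model="etweedy/roberta-base-squad-v2",
)

query = {
    "context": "Twinkies were invented on April 6, 1930, by Canadian-born baker James Alexander Dewar.",
    "question": "Who invented Twinkies?",
}
# handle_impossible_answer=True allows an empty answer, which app.py
# replaces with a fallback message (assumption about the Space's config).
response = qa_pipeline(**query, handle_impossible_answer=True)
print(response["answer"], f"confidence: {response['score']:.2%}")
```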
examples.csv
ADDED
@@ -0,0 +1,4 @@
+question,context,query
+What did Robert Oppenheimer remark after the Trinity test?:When was Robert Oppenheimer born?:What did Robert Oppenheimer do after World War II?,"Oppenheimer attended Harvard University, where he earned a bachelor's degree in chemistry in 1925. He studied physics at the University of Cambridge and University of Göttingen, where he received his PhD in 1927. He held academic positions at the University of California, Berkeley, and the California Institute of Technology, and made significant contributions to theoretical physics, including in quantum mechanics and nuclear physics. During World War II, he was recruited to work on the Manhattan Project, and in 1943 was appointed as director of the Los Alamos Laboratory in New Mexico, tasked with developing the weapons. Oppenheimer's leadership and scientific expertise were instrumental in the success of the project. He was among those who observed the Trinity test on July 16, 1945, in which the first atomic bomb was successfully detonated. He later remarked that the explosion brought to his mind words from the Hindu scripture Bhagavad Gita: ""Now I am become Death, the destroyer of worlds."" In August 1945, the atomic bombs were used on the Japanese cities of Hiroshima and Nagasaki, the only use of nuclear weapons in war.",Oppenheimer
+Why did Twinkies change to vanilla cream?:When were Twinkies invented?:What is the urban legend about Twinkies?,"Twinkies were invented on April 6, 1930, by Canadian-born baker James Alexander Dewar for the Continental Baking Company in Schiller Park, Illinois. Realizing that several machines used for making cream-filled strawberry shortcake sat idle when strawberries were out of season, Dewar conceived a snack cake filled with banana cream, which he dubbed the Twinkie. Ritchy Koph said he came up with the name when he saw a billboard in St. Louis for ""Twinkle Toe Shoes"". During World War II, bananas were rationed, and the company was forced to switch to vanilla cream. This change proved popular, and banana-cream Twinkies were not widely re-introduced. The original flavor was occasionally found in limited time only promotions, but the company used vanilla cream for most Twinkies. In 1988, Fruit and Cream Twinkies were introduced with a strawberry filling swirled into the cream. The product was soon dropped. Vanilla's dominance over banana flavoring was challenged in 2005, following a month-long promotion of the movie King Kong. Hostess saw its Twinkie sales rise 20 percent during the promotion, and in 2007 restored the banana-cream Twinkie to its snack lineup although they are now made with 2% banana purée.",Twinkie
+When was Pinkfong founded?:What is Pinkfong's most famous song?:Who represents the company?,"""Baby Shark"" is a children's song associated with a dance involving hand movements that originated as a campfire song dating back to at least the 20th century. In 2016, ""Baby Shark"" became very popular when Pinkfong, a South Korean entertainment company, released a version of the song with a YouTube music video that went viral across social media, online video, and radio. In January 2022, it became the first YouTube video to reach 10 billion views. In November 2020, Pinkfong's version became the most-viewed YouTube video of all time, with over 12 billion views as of April 2023. ""Baby Shark"" originated as a campfire song or chant. The original song dates back to at least the 20th century, potentially created by camp counselors inspired by the movie Jaws. In the chant, each member of a family of sharks is introduced, with campers using their hands to imitate the sharks' jaws. Different versions of the song have the sharks hunting fish, eating a sailor, or killing people who then go to heaven. Various entities have copyrighted original videos and sound recordings of the song, and some have trademarked merchandise based on their versions. However, according to The New York Times, the underlying song and characters are believed to be in the public domain.",Pinkfong
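Each row of examples.csv packs three example questions into the single question column, separated by colons; get_examples in lib/utils.py splits them back apart. A small sketch of that parsing, assuming the CSV above sits in the working directory:

```python
import pandas as pd

examples = pd.read_csv("examples.csv")
# Each 'question' cell holds three questions joined by ':'
ex_questions = [q.split(":") for q in examples["question"]]
ex_contexts = list(examples["context"])
ex_queries = list(examples["query"])

print(ex_queries[0])       # 'Oppenheimer'
print(ex_questions[0][0])  # 'What did Robert Oppenheimer remark after the Trinity test?'
```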
lib/.DS_Store
CHANGED
Binary files a/lib/.DS_Store and b/lib/.DS_Store differ
lib/.ipynb_checkpoints/utils-checkpoint.py
CHANGED
@@ -1,42 +1,119 @@
-import requests, wikipedia, re
-import
-)
+import requests, wikipedia, re
 from rank_bm25 import BM25Okapi
+import streamlit as st
+import pandas as pd
+import spacy
+
+####################################
+## Streamlit app helper functions ##
+####################################
+
+def get_examples():
+    """
+    Function for loading example questions
+    and contexts from examples.csv
+    Parameters: None
+    -----------
+    Returns:
+    --------
+    ex_queries, ex_questions, ex_contexts : list(str), list(list(str)), list(str)
+        Example search query, question, and context strings
+        (each entry of ex_questions is a list of three question strings)
+    """
+    examples = pd.read_csv('examples.csv')
+    ex_questions = [q.split(':') for q in list(examples['question'])]
+    ex_contexts = list(examples['context'])
+    ex_queries = list(examples['query'])
+    return ex_queries, ex_questions, ex_contexts
+
+def basic_clear_boxes():
+    """
+    Clears the question, context, response
+    """
+    for field in ['question','context','response']:
+        st.session_state['basic'][field] = ''
+
+def basic_ex_click(examples, i):
+    """
+    Fills in the chosen example
+    """
+    st.session_state['basic']['question'] = examples[1][i][0]
+    st.session_state['basic']['context'] = examples[2][i]
+
+def semi_clear_query():
+    """
+    Clears the search query field
+    and page options list
+    """
+    st.session_state['semi']['query'] = ''
+    for field in ['selected_pages','page_options']:
+        st.session_state['semi'][field] = []
+
+def semi_clear_question():
+    """
+    Clears the question and response field
+    and selected pages list
+    """
+    for field in ['question','response']:
+        st.session_state['semi'][field] = ''
+
+def semi_ex_query_click(examples,i):
+    """
+    Fills in the query example and
+    populates question examples when query
+    example button is clicked
+    """
+    st.session_state['semi']['query'] = examples[0][i]
+    st.session_state['semi']['ex_questions'] = examples[1][i]
+
+def semi_ex_question_click(i):
+    """
+    Fills in the question example
+    """
+    st.session_state['semi']['question'] = st.session_state['semi']['ex_questions'][i]
+
+def auto_clear_boxes():
+    """
+    Clears the response and question fields
+    """
+    for field in ['question','response']:
+        st.session_state['auto'][field]=''
+
+def auto_ex_click(examples,i):
+    """
+    Fills in the chosen example question
+    """
+    st.session_state['auto']['question'] = examples[1][i][0]
+
+###########################
+## Query helper function ##
+###########################
+
+def generate_query(nlp,text):
+    """
+    Process text into a search query,
+    only retaining nouns, proper nouns,
+    numerals, verbs, and adjectives
+    Parameters:
+    -----------
+    nlp : spacy.Pipe
+        spaCy pipeline for processing search query
+    text : str
+        The input text to be processed
+    Returns:
+    --------
+    query : str
+        The condensed search query
+    """
+    tokens = nlp(text)
+    keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
+    query = ' '.join(token.text for token in tokens \
+                     if token.pos_ in keep)
+    return query
+
+##############################
+## Document retriever class ##
+##############################

 class ContextRetriever:
     """
@@ -49,17 +126,21 @@ class ContextRetriever:
         self.pages = None
         self.paragraphs = None

-    def get_pageids(self,query):
+    def get_pageids(self,query,topn = None):
         """
         Retrieve page ids corresponding to a search query
         Parameters:
         -----------
         query : str
             A query to use for Wikipedia page search
+        topn : int or None
+            If topn is provided, will only return pageids
+            for topn search results
         Returns: None, but stores:
         --------
-        self.pageids : list(int)
-            A list of Wikipedia
+        self.pageids : list(tuple(int,str))
+            A list of Wikipedia (pageid,title) tuples resulting
+            from the search
         """
         params = {
             'action':'query',
@@ -68,39 +149,62 @@ class ContextRetriever:
             'format':'json',
         }
         results = requests.get(self.url, params=params).json()
-        pageids = [page['pageid'] for page in results['query']['search']]
+        pageids = [(page['pageid'],page['title']) for page in results['query']['search']]
+        pageids = pageids[:topn]
         self.pageids = pageids

-    def
+    def ids_to_pages(self,ids):
         """
         Use MediaWiki API to retrieve page content corresponding to
+        a list of pageids
+        Parameters:
         -----------
+        ids : list(tuple(int,str))
+            A list of Wikipedia (pageid,title) tuples
         Returns: None, but stores
         --------
+        pages : list(tuple(str,str))
+            The k-th entry is a tuple consisting of the title and page content
+            of the page corresponding to the k-th entry of ids
         """
+        pages = []
-        for pageid in self.pageids:
+        for pageid in ids:
             try:
+                page = wikipedia.page(pageid=pageid[0],auto_suggest=False)
+                pages.append((page.title, page.content))
+            except wikipedia.DisambiguationError:
                 continue
+        return pages
+
+    def get_all_pages(self):
         """
+        Use MediaWiki API to retrieve page content corresponding to
+        the list of pageids in self.pageids
         Parameters: None
         -----------
         Returns: None, but stores
         --------
-        self.
+        self.pages : list(tuple(str,str))
+            The k-th entry is a tuple consisting of the title and page content
+            of the page corresponding to the k-th entry of self.pageids
+        """
+        assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
+        self.pages = self.ids_to_pages(self.pageids)
+
+    def pages_to_paragraphs(self,pages):
+        """
+        Process a list of pages into a list of paragraphs from those pages
+        Parameters:
+        -----------
+        pages : list(str)
+            A list of Wikipedia page content dumps, as strings
+        Returns:
+        --------
+        paragraphs : dict
+            keys are titles of pages from pages (as strings)
+            paragraphs[page] is a list of paragraphs (as strings)
+            extracted from page
         """
-        assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
         # Content from WikiMedia has these headings. We only grab content appearing
         # before the first instance of any of these
         pattern = '|'.join([
@@ -117,38 +221,144 @@ class ContextRetriever:
             '=== Citations ===',
         ])
         pattern = re.compile(pattern)
-        paragraphs =
-        for page in
+        paragraphs = {}
+        for page in pages:
             # Truncate page to the first index of the start of a matching heading,
             # or the end of the page if no matches exist
+            title, content = page
+            idx = min([match.start() for match in pattern.finditer(content)]+[len(content)])
+            content = content[:idx]
             # Split into paragraphs, omitting lines with headings (start with '='),
             # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
-            paragraphs
-                p for p in
+            paragraphs[title] = [
+                p for p in content.split('\n') if p \
                 and not p.startswith('=') \
-                and not p.startswith('\t\t')
+                and not p.startswith('\t\t') \
+                and not p.startswith(' ')
             ]
+        return paragraphs
+
+    def get_all_paragraphs(self):
+        """
+        Process self.pages into list of paragraphs from pages
+        Parameters: None
+        -----------
+        Returns: None, but stores
+        --------
+        self.paragraphs : dict
+            keys are titles of pages from self.pages (as strings)
+            self.paragraphs[page] is a list of paragraphs (as strings)
+            extracted from page
+        """
+        assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
+        # Content from WikiMedia has these headings. We only grab content appearing
+        # before the first instance of any of these
+        self.paragraphs = self.pages_to_paragraphs(self.pages)

-    def rank_paragraphs(self,query,topn=10):
+    def rank_paragraphs(self,paragraphs,query,topn=10):
         """
-        Ranks the elements of
-        by relevance to query using
+        Ranks the elements of paragraphs in descending order
+        by relevance to query using BM25 Okapi, and returns top
+        topn results
         Parameters:
         -----------
+        paragraphs : dict
+            keys are titles of pages (as strings)
+            paragraphs[page] is a list of paragraphs (as strings)
+            extracted from page
         query : str
             The query to use in ranking paragraphs by relevance
-        topn : int
+        topn : int or None
             The number of most relevant paragraphs to return
+            If None, will return roughly the top 1/4 of the
+            paragraphs
+        Returns:
         --------
-        The
+        best_paragraphs : list(list(str,str))
+            The k-th entry is a list [title,paragraph] for the k-th
+            most relevant paragraph, where title is the title of the
+            Wikipedia article from which that paragraph was sourced
         """
+        corpus, titles, page_nums = [],[],[]
+        # Compile paragraphs into corpus
+        for i,page in enumerate(paragraphs):
+            titles.append(page)
+            paras = paragraphs[page]
+            corpus += paras
+            page_nums += len(paras)*[i]
+
+        # Tokenize corpus and query and initialize bm25 object
+        tokenized_corpus = [p.split(" ") for p in corpus]
+        bm25 = BM25Okapi(tokenized_corpus)
         tokenized_query = query.split(" ")
-        self.best_paragraphs = bm25.get_top_n(tokenized_query,self.paragraphs,n=topn)

+        # Compute scores and compile tuples (paragraph number, score, page number)
+        # before sorting tuples by score
+        bm_scores = bm25.get_scores(tokenized_query)
+        paragraph_data = [[i,score,page_nums[i]] for i,score in enumerate(bm_scores)]
+        paragraph_data.sort(reverse=True,key=lambda p:p[1])
+
+        # Grab topn best [title,paragraph] pairs sorted by bm25 score
+        topn = len(paragraph_data)//4+1 if topn is None else min(topn,len(paragraph_data))
+
+        best_paragraphs = [[titles[p[2]],corpus[p[0]]] for p in paragraph_data[:topn]]
+        return best_paragraphs
+
+def generate_answer(pipeline,paragraphs, question):
+    """
+    Generate an answer using a question-answer pipeline
+    Parameters:
+    -----------
+    pipeline : transformers.QuestionAnsweringPipeline
+        The question answering pipeline object
+    paragraphs : list(list(str,str))
+        The k-th entry is a list [title,paragraph] consisting
+        of a context paragraph and the title of the page from which the
+        paragraph was sourced
+    question : str
+        A question that is to be answered based on context given
+        in the entries of paragraphs
+
+    Returns:
+    --------
+    response : str
+        A response indicating the answer that was discovered,
+        or indicating that no answer could be found.
+    """
+    # For each paragraph, format input to QA pipeline...
+    for paragraph in paragraphs:
+        input = {
+            'context':paragraph[1],
+            'question':question,
+        }
+        # ...and pass to QA pipeline
+        output = pipeline(**input)
+        # Append answers and scores. Report score of
+        # zero for paragraphs without answer, so they are
+        # deprioritized when the max is taken below
+        if output['answer']!='':
+            paragraph += [output['answer'],output['score']]
+        else:
+            paragraph += ['',0]
+    # Get paragraph with max confidence score and collect data
+    best_paragraph = max(paragraphs,key = lambda x:x[3])
+    best_answer = best_paragraph[2]
+    best_context_page = best_paragraph[0]
+    best_context = best_paragraph[1]
+
+    # Update response in session state
+    if best_answer == "":
+        response = "I cannot find the answer to your question."
+    else:
+        response = f"""
+        My answer is: {best_answer}
+
+        ...and here's where I found it:
+
+        Page title: {best_context_page}
+
+        Paragraph containing answer:
+
+        {best_context}
+        """
+    return response
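Taken together, these utilities implement the retrieval-then-read flow that app.py's automated tab drives: condense the question to a query, search Wikipedia, split the top pages into paragraphs, rank them with BM25, and let the QA pipeline pick the best-scoring answer. A condensed usage sketch, assuming the en_core_web_sm spaCy model is installed and the Space's fine-tuned checkpoint is reachable:

```python
# Condensed sketch of the automated Wiki Q&A flow (mirrors app.py's last tab).
import spacy
from transformers import pipeline
from lib.utils import ContextRetriever, generate_query, generate_answer

nlp = spacy.load("en_core_web_sm")
qa_pipeline = pipeline("question-answering", model="etweedy/roberta-base-squad-v2")

question = "When was Pinkfong founded?"
query = generate_query(nlp, question)  # keep nouns, proper nouns, numerals, verbs, adjectives

retriever = ContextRetriever()
retriever.get_pageids(query, topn=3)   # pageids of the top 3 Wikipedia search hits
retriever.get_all_pages()              # download page content for those ids
retriever.get_all_paragraphs()         # split each page into paragraphs
best_paragraphs = retriever.rank_paragraphs(retriever.paragraphs, query)  # top 10 by BM25

print(generate_answer(
    pipeline=qa_pipeline,
    paragraphs=best_paragraphs,
    question=question,
))
```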
lib/__pycache__/utils.cpython-310.pyc
CHANGED
Binary files a/lib/__pycache__/utils.cpython-310.pyc and b/lib/__pycache__/utils.cpython-310.pyc differ
lib/utils.py
CHANGED
@@ -1,13 +1,119 @@
|
|
1 |
import requests, wikipedia, re
|
2 |
from rank_bm25 import BM25Okapi
|
3 |
-
import
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
class ContextRetriever:
|
13 |
"""
|
@@ -20,17 +126,21 @@ class ContextRetriever:
|
|
20 |
self.pages = None
|
21 |
self.paragraphs = None
|
22 |
|
23 |
-
def get_pageids(self,query):
|
24 |
"""
|
25 |
Retrieve page ids corresponding to a search query
|
26 |
Parameters:
|
27 |
-----------
|
28 |
query : str
|
29 |
A query to use for Wikipedia page search
|
|
|
|
|
|
|
30 |
Returns: None, but stores:
|
31 |
--------
|
32 |
-
self.pageids : list(int)
|
33 |
-
A list of Wikipedia
|
|
|
34 |
"""
|
35 |
params = {
|
36 |
'action':'query',
|
@@ -39,39 +149,62 @@ class ContextRetriever:
|
|
39 |
'format':'json',
|
40 |
}
|
41 |
results = requests.get(self.url, params=params).json()
|
42 |
-
pageids = [page['pageid'] for page in results['query']['search']]
|
|
|
43 |
self.pageids = pageids
|
44 |
-
|
45 |
-
def
|
46 |
"""
|
47 |
Use MediaWiki API to retrieve page content corresponding to
|
48 |
-
|
49 |
-
Parameters:
|
50 |
-----------
|
|
|
|
|
51 |
Returns: None, but stores
|
52 |
--------
|
53 |
-
|
54 |
-
|
|
|
55 |
"""
|
56 |
-
|
57 |
-
|
58 |
-
for pageid in self.pageids:
|
59 |
try:
|
60 |
-
|
61 |
-
|
|
|
62 |
continue
|
63 |
-
|
64 |
-
|
|
|
65 |
"""
|
66 |
-
|
|
|
67 |
Parameters: None
|
68 |
-----------
|
69 |
Returns: None, but stores
|
70 |
--------
|
71 |
-
self.
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
"""
|
74 |
-
assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
|
75 |
# Content from WikiMedia has these headings. We only grab content appearing
|
76 |
# before the first instance of any of these
|
77 |
pattern = '|'.join([
|
@@ -88,38 +221,144 @@ class ContextRetriever:
|
|
88 |
'=== Citations ===',
|
89 |
])
|
90 |
pattern = re.compile(pattern)
|
91 |
-
paragraphs =
|
92 |
-
for page in
|
93 |
# Truncate page to the first index of the start of a matching heading,
|
94 |
# or the end of the page if no matches exist
|
95 |
-
|
96 |
-
|
|
|
97 |
# Split into paragraphs, omitting lines with headings (start with '='),
|
98 |
# empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
|
99 |
-
paragraphs
|
100 |
-
p for p in
|
101 |
and not p.startswith('=') \
|
102 |
-
and not p.startswith('\t\t')
|
|
|
103 |
]
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
-
def rank_paragraphs(self,query,topn=10):
|
107 |
"""
|
108 |
-
Ranks the elements of
|
109 |
-
by relevance to query using
|
|
|
110 |
Parameters:
|
111 |
-----------
|
|
|
|
|
|
|
|
|
112 |
query : str
|
113 |
The query to use in ranking paragraphs by relevance
|
114 |
-
topn : int
|
115 |
The number of most relevant paragraphs to return
|
116 |
-
|
|
|
|
|
117 |
--------
|
118 |
-
|
119 |
-
The
|
|
|
|
|
120 |
"""
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
tokenized_query = query.split(" ")
|
124 |
-
self.best_paragraphs = bm25.get_top_n(tokenized_query,self.paragraphs,n=topn)
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import requests, wikipedia, re
|
2 |
from rank_bm25 import BM25Okapi
|
3 |
+
import streamlit as st
|
4 |
+
import pandas as pd
|
5 |
+
import spacy
|
6 |
+
|
7 |
+
####################################
|
8 |
+
## Streamlit app helper functions ##
|
9 |
+
####################################
|
10 |
+
|
11 |
+
def get_examples():
|
12 |
+
"""
|
13 |
+
Function for loading example questions
|
14 |
+
and contexts from examples.csv
|
15 |
+
Parameters: None
|
16 |
+
-----------
|
17 |
+
Returns:
|
18 |
+
--------
|
19 |
+
ex_queries, ex_questions, ex_contexts : list(str), list(list(str)), list(str)
|
20 |
+
Example search query, question, and context strings
|
21 |
+
(each entry of ex_questions is a list of three question strings)
|
22 |
+
"""
|
23 |
+
examples = pd.read_csv('examples.csv')
|
24 |
+
ex_questions = [q.split(':') for q in list(examples['question'])]
|
25 |
+
ex_contexts = list(examples['context'])
|
26 |
+
ex_queries = list(examples['query'])
|
27 |
+
return ex_queries, ex_questions, ex_contexts
|
28 |
+
|
29 |
+
def basic_clear_boxes():
|
30 |
+
"""
|
31 |
+
Clears the question, context, response
|
32 |
+
"""
|
33 |
+
for field in ['question','context','response']:
|
34 |
+
st.session_state['basic'][field] = ''
|
35 |
+
|
36 |
+
def basic_ex_click(examples, i):
|
37 |
+
"""
|
38 |
+
Fills in the chosen example
|
39 |
+
"""
|
40 |
+
st.session_state['basic']['question'] = examples[1][i][0]
|
41 |
+
st.session_state['basic']['context'] = examples[2][i]
|
42 |
+
|
43 |
+
def semi_clear_query():
|
44 |
+
"""
|
45 |
+
Clears the search query field
|
46 |
+
and page options list
|
47 |
+
"""
|
48 |
+
st.session_state['semi']['query'] = ''
|
49 |
+
for field in ['selected_pages','page_options']:
|
50 |
+
st.session_state['semi'][field] = []
|
51 |
+
|
52 |
+
def semi_clear_question():
|
53 |
+
"""
|
54 |
+
Clears the question and response field
|
55 |
+
and selected pages list
|
56 |
+
"""
|
57 |
+
for field in ['question','response']:
|
58 |
+
st.session_state['semi'][field] = ''
|
59 |
+
|
60 |
+
def semi_ex_query_click(examples,i):
|
61 |
+
"""
|
62 |
+
Fills in the query example and
|
63 |
+
populates question examples when query
|
64 |
+
example button is clicked
|
65 |
+
"""
|
66 |
+
st.session_state['semi']['query'] = examples[0][i]
|
67 |
+
st.session_state['semi']['ex_questions'] = examples[1][i]
|
68 |
+
|
69 |
+
def semi_ex_question_click(i):
|
70 |
+
"""
|
71 |
+
Fills in the question example
|
72 |
+
"""
|
73 |
+
st.session_state['semi']['question'] = st.session_state['semi']['ex_questions'][i]
|
74 |
+
|
75 |
+
def auto_clear_boxes():
|
76 |
+
"""
|
77 |
+
Clears the response and question fields
|
78 |
+
"""
|
79 |
+
for field in ['question','response']:
|
80 |
+
st.session_state['auto'][field]=''
|
81 |
+
|
82 |
+
def auto_ex_click(examples,i):
|
83 |
+
"""
|
84 |
+
Fills in the chosen example question
|
85 |
+
"""
|
86 |
+
st.session_state['auto']['question'] = examples[1][i][0]
|
87 |
+
|

###########################
## Query helper function ##
###########################

def generate_query(nlp, text):
    """
    Process text into a search query,
    only retaining nouns, proper nouns,
    numerals, verbs, and adjectives
    Parameters:
    -----------
    nlp : spacy.Pipe
        spaCy pipeline for processing search query
    text : str
        The input text to be processed
    Returns:
    --------
    query : str
        The condensed search query
    """
    tokens = nlp(text)
    keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
    query = ' '.join(token.text for token in tokens if token.pos_ in keep)
    return query
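
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the uploaded file): the
# 'en_core_web_sm' model name is an assumption; the app loads its spaCy
# pipeline elsewhere.
# ---------------------------------------------------------------------------
def _example_generate_query():
    nlp = spacy.load('en_core_web_sm')                     # assumed model
    text = "When was the Eiffel Tower built, and how tall is it?"
    return generate_query(nlp, text)                       # roughly "Eiffel Tower built tall"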

##############################
## Document retriever class ##
##############################

class ContextRetriever:
    """
    ... (unchanged lines collapsed in the diff view) ...
        self.pages = None
        self.paragraphs = None

    def get_pageids(self, query, topn=None):
        """
        Retrieve page ids corresponding to a search query
        Parameters:
        -----------
        query : str
            A query to use for Wikipedia page search
        topn : int or None
            If topn is provided, will only return pageids
            for topn search results
        Returns: None, but stores:
        --------
        self.pageids : list(tuple(int,str))
            A list of Wikipedia (pageid,title) tuples resulting
            from the search
        """
        params = {
            'action':'query',
            # ... (unchanged lines collapsed in the diff view) ...
            'format':'json',
        }
        results = requests.get(self.url, params=params).json()
        pageids = [(page['pageid'],page['title']) for page in results['query']['search']]
        pageids = pageids[:topn]
        self.pageids = pageids
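
    # Illustration (comment only, not part of the uploaded file): after
    # calling get_pageids('Eiffel Tower', topn=3), self.pageids holds
    # (pageid, title) tuples such as [(9232, 'Eiffel Tower'), ...];
    # the numeric id here is made up, and the remaining search parameters
    # sit in the lines collapsed by the diff view above.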

    def ids_to_pages(self, ids):
        """
        Use MediaWiki API to retrieve page content corresponding to
        a list of pageids
        Parameters:
        -----------
        ids : list(tuple(int,str))
            A list of Wikipedia (pageid,title) tuples
        Returns:
        --------
        pages : list(tuple(str,str))
            The k-th entry is a tuple consisting of the title and page content
            of the page corresponding to the k-th entry of ids
        """
        pages = []
        for pageid in ids:
            try:
                page = wikipedia.page(pageid=pageid[0], auto_suggest=False)
                pages.append((page.title, page.content))
            except wikipedia.DisambiguationError:
                continue
        return pages
def get_all_pages(self):
|
180 |
"""
|
181 |
+
Use MediaWiki API to retrieve page content corresponding to
|
182 |
+
the list of pageids in self.pageids
|
183 |
Parameters: None
|
184 |
-----------
|
185 |
Returns: None, but stores
|
186 |
--------
|
187 |
+
self.pages : list(tuple(str,str))
|
188 |
+
The k-th enry is a tuple consisting of the title and page content
|
189 |
+
of the page corresponding to the k-th entry of self.pageids
|
190 |
+
"""
|
191 |
+
assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
|
192 |
+
self.pages = self.ids_to_pages(self.pageids)
|
193 |
+
|
194 |
+
def pages_to_paragraphs(self,pages):
|
195 |
+
"""
|
196 |
+
Process a list of pages into a list of paragraphs from those pages
|
197 |
+
Parameters:
|
198 |
+
-----------
|
199 |
+
pages : list(str)
|
200 |
+
A list of Wikipedia page content dumps, as strings
|
201 |
+
Returns:
|
202 |
+
--------
|
203 |
+
paragraphs : dict
|
204 |
+
keys are titles of pages from pages (as strings)
|
205 |
+
paragraphs[page] is a list of paragraphs (as strings)
|
206 |
+
extracted from page
|
207 |
"""
|
|
|
208 |
# Content from WikiMedia has these headings. We only grab content appearing
|
209 |
# before the first instance of any of these
|
210 |
pattern = '|'.join([
|
|
|
221 |
'=== Citations ===',
|
222 |
])
|
223 |
pattern = re.compile(pattern)
|
224 |
+
paragraphs = {}
|
225 |
+
for page in pages:
|
226 |
# Truncate page to the first index of the start of a matching heading,
|
227 |
# or the end of the page if no matches exist
|
228 |
+
title, content = page
|
229 |
+
idx = min([match.start() for match in pattern.finditer(content)]+[len(content)])
|
230 |
+
content = content[:idx]
|
231 |
# Split into paragraphs, omitting lines with headings (start with '='),
|
232 |
# empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
|
233 |
+
paragraphs[title] = [
|
234 |
+
p for p in content.split('\n') if p \
|
235 |
and not p.startswith('=') \
|
236 |
+
and not p.startswith('\t\t') \
|
237 |
+
and not p.startswith(' ')
|
238 |
]
|
239 |
+
return paragraphs
|
240 |
+
|
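
    # Illustration (comment only, not part of the uploaded file): given
    #   pages = [('Example', 'First paragraph.\n== Design ==\nSecond paragraph.\n\n\t\tcaption row')]
    # pages_to_paragraphs(pages) returns
    #   {'Example': ['First paragraph.', 'Second paragraph.']}
    # since heading, empty, and tab-indented lines are filtered out.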

    def get_all_paragraphs(self):
        """
        Process self.pages into list of paragraphs from pages
        Parameters: None
        -----------
        Returns: None, but stores
        --------
        self.paragraphs : dict
            keys are titles of pages from self.pages (as strings)
            self.paragraphs[page] is a list of paragraphs (as strings)
            extracted from page
        """
        assert self.pages is not None, "No page content exists. Get pages first using self.get_all_pages"
        # Delegate to pages_to_paragraphs, which drops content after the first
        # WikiMedia section heading in its cutoff list
        self.paragraphs = self.pages_to_paragraphs(self.pages)

    def rank_paragraphs(self, paragraphs, query, topn=10):
        """
        Ranks the elements of paragraphs in descending order
        by relevance to query using BM25 Okapi, and returns top
        topn results
        Parameters:
        -----------
        paragraphs : dict
            keys are titles of pages (as strings)
            paragraphs[page] is a list of paragraphs (as strings)
            extracted from page
        query : str
            The query to use in ranking paragraphs by relevance
        topn : int or None
            The number of most relevant paragraphs to return
            If None, will return roughly the top 1/4 of the
            paragraphs
        Returns:
        --------
        best_paragraphs : list(list(str,str))
            The k-th entry is a list [title,paragraph] for the k-th
            most relevant paragraph, where title is the title of the
            Wikipedia article from which that paragraph was sourced
        """
        corpus, titles, page_nums = [], [], []
        # Compile paragraphs into corpus
        for i, page in enumerate(paragraphs):
            titles.append(page)
            paras = paragraphs[page]
            corpus += paras
            page_nums += len(paras)*[i]

        # Tokenize corpus and query and initialize bm25 object
        tokenized_corpus = [p.split(" ") for p in corpus]
        bm25 = BM25Okapi(tokenized_corpus)
        tokenized_query = query.split(" ")

        # Compute scores and compile tuples (paragraph number, score, page number)
        # before sorting tuples by score
        bm_scores = bm25.get_scores(tokenized_query)
        paragraph_data = [[i, score, page_nums[i]] for i, score in enumerate(bm_scores)]
        paragraph_data.sort(reverse=True, key=lambda p: p[1])

        # Grab topn best [title,paragraph] pairs sorted by bm25 score
        topn = len(paragraph_data)//4 + 1 if topn is None else min(topn, len(paragraph_data))

        best_paragraphs = [[titles[p[2]], corpus[p[0]]] for p in paragraph_data[:topn]]
        return best_paragraphs
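
# ---------------------------------------------------------------------------
# Standalone sketch (illustration only, not part of the uploaded file) of the
# BM25 ranking step that rank_paragraphs performs above, on a toy corpus.
# ---------------------------------------------------------------------------
def _example_bm25_ranking():
    corpus = [
        "The Eiffel Tower is a wrought-iron lattice tower in Paris.",
        "BM25 is a bag-of-words ranking function used in search engines.",
        "Construction of the tower was completed in 1889.",
    ]
    bm25 = BM25Okapi([p.split(" ") for p in corpus])
    scores = bm25.get_scores("When was the Eiffel Tower completed".split(" "))
    # Pair each paragraph with its score and sort by score, descending,
    # mirroring what rank_paragraphs does with its [index, score, page] triples
    return sorted(zip(corpus, scores), key=lambda p: p[1], reverse=True)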

def generate_answer(pipeline, paragraphs, question):
    """
    Generate an answer using a question-answer pipeline
    Parameters:
    -----------
    pipeline : transformers.QuestionAnsweringPipeline
        The question answering pipeline object
    paragraphs : list(list(str,str))
        The k-th entry is a list [title,paragraph] consisting
        of a context paragraph and the title of the page from which the
        paragraph was sourced
    question : str
        A question that is to be answered based on context given
        in the entries of paragraphs

    Returns:
    --------
    response : str
        A response indicating the answer that was discovered,
        or indicating that no answer could be found.
    """
    # For each paragraph, format input to QA pipeline...
    for paragraph in paragraphs:
        input = {
            'context': paragraph[1],
            'question': question,
        }
        # ...and pass to QA pipeline
        output = pipeline(**input)
        # Append answers and scores. Report score of
        # zero for paragraphs without an answer, so they are
        # deprioritized when the max is taken below
        if output['answer'] != '':
            paragraph += [output['answer'], output['score']]
        else:
            paragraph += ['', 0]
    # Get paragraph with max confidence score and collect data
    best_paragraph = max(paragraphs, key=lambda x: x[3])
    best_answer = best_paragraph[2]
    best_context_page = best_paragraph[0]
    best_context = best_paragraph[1]

    # Compose the response string
    if best_answer == "":
        response = "I cannot find the answer to your question."
    else:
        response = f"""
        My answer is: {best_answer}

        ...and here's where I found it:

        Page title: {best_context_page}

        Paragraph containing answer:

        {best_context}
        """
    return response
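
# ---------------------------------------------------------------------------
# End-to-end sketch (illustration only, not part of the uploaded file) of how
# app.py is expected to chain these pieces together. Assumptions: the
# ContextRetriever constructor takes no arguments, the spaCy model is
# 'en_core_web_sm', and the QA checkpoint is 'deepset/roberta-base-squad2';
# the Space builds its own pipeline and spaCy objects in app.py.
# ---------------------------------------------------------------------------
def _example_full_pipeline(question):
    from transformers import pipeline as hf_pipeline
    nlp = spacy.load('en_core_web_sm')
    qa = hf_pipeline('question-answering', model='deepset/roberta-base-squad2')
    retriever = ContextRetriever()
    query = generate_query(nlp, question)
    retriever.get_pageids(query, topn=3)      # search Wikipedia for candidate pages
    retriever.get_all_pages()                 # download page content
    retriever.get_all_paragraphs()            # split pages into paragraphs
    best = retriever.rank_paragraphs(retriever.paragraphs, query, topn=10)
    return generate_answer(qa, best, question)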