etweedy committed on
Commit
ddadb1a
1 Parent(s): 42d54ec

Upload 9 files

app.py CHANGED
@@ -8,28 +8,17 @@ from transformers import (
8
  pipeline,
9
  )
10
  import spacy
11
- # import pandas as pd
12
- from lib.utils import ContextRetriever
13
-
14
-
15
- #### TO DO:######
16
- # build out functions for:
17
- # * formatting input into document retrieval query (spaCy)
18
- # * document retrieval based on query (wikipedia library)
19
- # * document postprocessing into passages
20
- # * ranking passage based on BM25 scores for query (rank_bm25)
21
- # * feeding passages into RoBERTa an reporting answer(s) and passages as evidence
22
- # decide what to do with examples
23
-
24
- ### CAN REMOVE:#####
25
- # * context collection
26
- # *
27
 
28
- ########################
29
- ### Helper functions ###
30
- ########################
31
 
32
- # Build trainer using model and tokenizer from Hugging Face repo
33
  @st.cache_resource(show_spinner=False)
34
  def get_pipeline():
35
  """
@@ -70,57 +59,6 @@ def get_spacy():
70
  )
71
  return nlp
72
 
73
- def generate_query(nlp,text):
74
- """
75
- Process text into a search query,
76
- only retaining nouns, proper nouns,
77
- numerals, verbs, and adjectives
78
- Parameters:
79
- -----------
80
- nlp : spacy.Pipe
81
- spaCy pipeline for processing search query
82
- text : str
83
- The input text to be processed
84
- Returns:
85
- --------
86
- query : str
87
- The condensed search query
88
- """
89
- tokens = nlp(text)
90
- keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
91
- query = ' '.join(token.text for token in tokens \
92
- if token.pos_ in keep)
93
- return query
94
-
95
- def fill_in_example(i):
96
- """
97
- Function for context-question example button click
98
- """
99
- st.session_state['response'] = ''
100
- st.session_state['question'] = ex_q[i]
101
-
102
- def clear_boxes():
103
- """
104
- Function for field clear button click
105
- """
106
- st.session_state['response'] = ''
107
- st.session_state['question'] = ''
108
-
109
- # def get_examples():
110
- # """
111
- # Retrieve pre-made examples from a .csv file
112
- # Parameters: None
113
- # -----------
114
- # Returns:
115
- # --------
116
- # questions, contexts : list, list
117
- # Lists of examples of corresponding question-context pairs
118
-
119
- # """
120
- # examples = pd.read_csv('examples.csv')
121
- # questions = list(examples['question'])
122
- # return questions
123
-
124
  #############
125
  ### Setup ###
126
  #############
@@ -134,40 +72,71 @@ else:
134
  device = "cpu"
135
 
136
  # Initialize session state variables
137
- if 'response' not in st.session_state:
138
- st.session_state['response'] = ''
139
- if 'question' not in st.session_state:
140
- st.session_state['question'] = ''
141
-
142
- # Retrieve trained RoBERTa pipeline for Q&A
143
- # and spaCy pipeline for processing search query
144
  with st.spinner('Loading the model...'):
145
  qa_pipeline = get_pipeline()
146
  nlp = get_spacy()
147
 
148
- # # Grab example question-context pairs from csv file
149
- # ex_q, ex_c = get_examples()
150
 
151
- ###################
152
- ### App content ###
153
- ###################
154
 
155
- # Intro text
156
- st.header('RoBERTa answer retieval')
157
- st.markdown('''
158
- This app demonstrates the answer-retrieval capabilities of a fine-tuned RoBERTa (Robustly optimized Bidirectional Encoder Representations from Transformers) model.
159
 
160
- Please type in a question and click submit. When you do, a few things will happen:
161
- 1. A Wikipedia search will be performed based on your question
162
- 2. Candidate passages will be ranked based on a similarity score as compared to your question
163
- 3. RoBERTa will search the best candidate passages to find the answer to your question
164
 
165
- If the model cannot find the answer to your question, it will tell you so.
166
- ''')
167
- with st.expander('Click to read more about the model...'):
168
  st.markdown('''
169
  * [Click here](https://huggingface.co/etweedy/roberta-base-squad-v2) to visit the Hugging Face model card for this fine-tuned model.
170
- * To create this model, the [RoBERTa base model](https://huggingface.co/roberta-base) was fine-tuned on Version 2 of [SQuAD (Stanford Question Answering Dataset)](https://huggingface.co/datasets/squad_v2), a dataset of context-question-answer triples.
171
  * The objective of the model is "extractive question answering", the task of retrieving the answer to the question from a given context text corpus.
172
  * SQuAD Version 2 incorporates the 100,000 samples from Version 1.1, along with 50,000 'unanswerable' questions, i.e. samples in which the question cannot be answered using the context given.
173
  * The original base RoBERTa model was introduced in [this paper](https://arxiv.org/abs/1907.11692) and [this repository](https://github.com/facebookresearch/fairseq/tree/main/examples/roberta). Here's a citation for that base model:
@@ -195,102 +164,320 @@ with st.expander('Click to read more about the model...'):
195
  bibsource = {dblp computer science bibliography, https://dblp.org}
196
  }
197
  ```
198
- ''')
199
- # st.markdown('''
200
- # Please type or paste a context paragraph and question you'd like to ask about it. The model will attempt to answer the question based on the context you provided. If the model cannot find the answer in the context, it will tell you so - the model is also trained to recognize when the context doesn't provide the answer.
201
-
202
- # Your results will appear below the question field when the model is finished running.
 
 
203
 
204
- # Alternatively, you can try an example by clicking one of the buttons below:
205
- # ''')
 
206
 
207
- # Generate containers in order
208
- # example_container = st.container()
209
- input_container = st.container()
210
- button_container = st.container()
211
- response_container = st.container()
212
 
213
- ###########################
214
- ### Populate containers ###
215
- ###########################
 
 
216
 
217
- # Populate example button container
218
- # with example_container:
219
- # ex_cols = st.columns(len(ex_q)+1)
220
- # for i in range(len(ex_q)):
221
- # with ex_cols[i]:
222
- # st.button(
223
- # label = f'Try example {i+1}',
224
- # key = f'ex_button_{i+1}',
225
- # on_click = fill_in_example,
226
- # args=(i,),
227
- # )
228
- # with ex_cols[-1]:
229
- # st.button(
230
- # label = "Clear all fields",
231
- # key = "clear_button",
232
- # on_click = clear_boxes,
233
- # )
234
 
235
- # Populate user input container
236
- with input_container:
237
- with st.form(key='input_form',clear_on_submit=False):
238
  # Question input field
239
  question = st.text_input(
240
  label='Question',
241
- value=st.session_state['question'],
242
- key='question_field',
243
- label_visibility='hidden',
244
  placeholder='Enter your question here.',
245
  )
246
  # Form submit button
247
  query_submitted = st.form_submit_button("Submit")
248
- if query_submitted:
249
  # update question, context in session state
250
- st.session_state['question'] = question
251
- with st.spinner('Retrieving documentation...'):
252
- query = generate_query(nlp,question)
253
- retriever = ContextRetriever()
254
- retriever.get_pageids(query)
255
- retriever.get_pages()
256
- retriever.get_paragraphs()
257
- retriever.rank_paragraphs(query)
258
  with st.spinner('Generating response...'):
259
- # Loop through best_paragraph contexts
260
- # looking for answer in each
261
- best_answer = ""
262
- for context in retriever.best_paragraphs:
263
- input = {
264
- 'context':context,
265
- 'question':st.session_state['question'],
266
- }
267
- # Pass to QA pipeline
268
- response = qa_pipeline(**input)
269
- if response['answer']!='':
270
- best_answer = response['answer']
271
- best_context = context
272
- break
273
- # Update response in session state
274
- if best_answer == "":
275
- st.session_state['response'] = "I cannot find the answer to your question."
276
- else:
277
- st.session_state['response'] = f"""
278
- My answer is: {best_answer}
279
-
280
- ...and here's where I found it:
281
-
282
- {best_context}
283
  """
284
 
285
- # Button for clearing the form
286
- with button_container:
287
- st.button(
288
  label = "Clear all fields",
289
- key = "clear_button",
290
- on_click = clear_boxes,
291
  )
292
 
293
- # Display response
294
- with response_container:
295
- st.write(st.session_state['response'])
296
-
8
  pipeline,
9
  )
10
  import spacy
11
+ from lib.utils import (
12
+ ContextRetriever,
13
+ get_examples,
14
+ generate_query,
15
+ generate_answer,
16
+ )
17
 
18
+ #################################
19
+ ### Model retrieval functions ###
20
+ #################################
21
 
 
22
  @st.cache_resource(show_spinner=False)
23
  def get_pipeline():
24
  """
 
59
  )
60
  return nlp
61
 
62
  #############
63
  ### Setup ###
64
  #############
 
72
  device = "cpu"
73
 
74
  # Initialize session state variables
75
+ for tab in ['basic','semi','auto']:
76
+ if tab not in st.session_state:
77
+ st.session_state[tab] = {}
78
+ for field in ['question','context','query','response']:
79
+ if field not in st.session_state[tab]:
80
+ st.session_state[tab][field] = ''
81
+ for field in ['page_options','selected_pages']:
82
+ if field not in st.session_state['semi']:
83
+ st.session_state['semi'][field] = []
84
+
85
+ # Retrieve models
86
  with st.spinner('Loading the model...'):
87
  qa_pipeline = get_pipeline()
88
  nlp = get_spacy()
89
 
90
+ # Retrieve example questions and contexts
91
+ examples = get_examples()
92
+ # ex_queries, ex_questions, ex_contexts = get_examples()
93
+ if 'ex_questions' not in st.session_state['semi']:
94
+ st.session_state['semi']['ex_questions'] = len(examples[1][0])*['']
95
+
96
+ ################################
97
+ ### Initialize App Structure ###
98
+ ################################
99
 
100
+ tabs = st.tabs([
101
+ 'RoBERTa Q&A model',
102
+ 'Basic extractive Q&A',
103
+ 'User-guided Wiki Q&A',
104
+ 'Automated Wiki Q&A',
105
+ ])
106
 
107
+ with tabs[0]:
108
+ intro_container = st.container()
109
+ with tabs[1]:
110
+ basic_title_container = st.container()
111
+ basic_example_container = st.container()
112
+ basic_input_container = st.container()
113
+ basic_response_container = st.container()
114
+ with tabs[2]:
115
+ semi_title_container = st.container()
116
+ semi_query_container = st.container()
117
+ semi_page_container = st.container()
118
+ semi_input_container = st.container()
119
+ semi_response_container = st.container()
120
+ with tabs[3]:
121
+ auto_title_container = st.container()
122
+ auto_example_container = st.container()
123
+ auto_input_container = st.container()
124
+ auto_response_container = st.container()
125
 
126
+ ##############################
127
+ ### Populate tab - Welcome ###
128
+ ##############################
 
129
 
130
+ with intro_container:
131
+ # Intro text
132
+ st.header('RoBERTa Q&A model')
133
  st.markdown('''
134
+ This app demonstrates the answer-retrieval capabilities of a fine-tuned RoBERTa (Robustly optimized Bidirectional Encoder Representations from Transformers) model.
135
+ ''')
136
+ with st.expander('Click to read more about the model...'):
137
+ st.markdown('''
138
  * [Click here](https://huggingface.co/etweedy/roberta-base-squad-v2) to visit the Hugging Face model card for this fine-tuned model.
139
+ * To create this model, I fine-tuned the [RoBERTa base model](https://huggingface.co/roberta-base) on Version 2 of [SQuAD (Stanford Question Answering Dataset)](https://huggingface.co/datasets/squad_v2), a dataset of context-question-answer triples.
140
  * The objective of the model is "extractive question answering", the task of retrieving the answer to the question from a given context text corpus.
141
  * SQuAD Version 2 incorporates the 100,000 samples from Version 1.1, along with 50,000 'unanswerable' questions, i.e. samples in which the question cannot be answered using the context given.
142
  * The original base RoBERTa model was introduced in [this paper](https://arxiv.org/abs/1907.11692) and [this repository](https://github.com/facebookresearch/fairseq/tree/main/examples/roberta). Here's a citation for that base model:
 
164
  bibsource = {dblp computer science bibliography, https://dblp.org}
165
  }
166
  ```
167
+ ''')
168
+ st.markdown('''
169
+ Use the tabs above to navigate between the different app components:
170
+ 1. A basic Q&A tool which allows the user to ask the model to search a user-provided context paragraph for the answer to a user-provided question.
171
+ 2. A user-guided Wiki Q&A tool which allows the user to search for one or more Wikipedia pages and ask the model to search those pages for the answer to a user-provided question.
172
+ 3. An automated Wiki Q&A tool which asks the model to retrieve its own Wikipedia pages in order to answer the user-provided question.
173
+ ''')
174
 
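For readers who want to try the fine-tuned checkpoint outside of the Streamlit app, here is a minimal sketch that calls it directly through the `transformers` question-answering pipeline; the question and context strings below are illustrative only.

```python
# Minimal illustrative sketch (assumes `pip install transformers torch`); not part of the app code.
from transformers import pipeline

qa = pipeline(
    "question-answering",
    model="etweedy/roberta-base-squad-v2",  # model card linked above
)
result = qa(
    question="When was the first atomic bomb detonated?",
    context=(
        "Oppenheimer was among those who observed the Trinity test on "
        "July 16, 1945, in which the first atomic bomb was successfully detonated."
    ),
)
print(result["answer"], round(result["score"], 3))
```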
175
+ ################################
176
+ ### Populate tab - basic Q&A ###
177
+ ################################
178
 
179
+ from lib.utils import basic_clear_boxes, basic_ex_click
180
 
181
+ with basic_title_container:
182
+ ### Intro text ###
183
+ st.header('Basic extractive Q&A')
184
+ st.markdown('''
185
+ The basic functionality of a RoBERTa model for extractive question-answering is to attempt to extract the answer to a user-provided question from a piece of user-provided context text. The model is also trained to recognize when the context doesn't provide the answer.
186
 
187
+ Please type or paste a context paragraph and question you'd like to ask about it. The model will attempt to answer the question based on the context you provided, or report that it cannot find the answer in the context. Your results will appear below the question field when the model is finished running.
188
 
189
+ Alternatively, you can try an example by clicking one of the buttons below:
190
+ ''')
191
+
192
+ ### Populate example button container ###
193
+ with basic_example_container:
194
+ basic_ex_cols = st.columns(len(examples[0])+1)
195
+ for i in range(len(examples[0])):
196
+ with basic_ex_cols[i]:
197
+ st.button(
198
+ label = f'example {i+1}',
199
+ key = f'basic_ex_button_{i+1}',
200
+ on_click = basic_ex_click,
201
+ args = (examples,i,),
202
+ )
203
+ with basic_ex_cols[-1]:
204
+ st.button(
205
+ label = "Clear all fields",
206
+ key = "basic_clear_button",
207
+ on_click = basic_clear_boxes,
208
+ )
209
+ ### Populate user input container ###
210
+ with basic_input_container:
211
+ with st.form(key='basic_input_form',clear_on_submit=False):
212
+ # Context input field
213
+ context = st.text_area(
214
+ label='Context',
215
+ value=st.session_state['basic']['context'],
216
+ key='basic_context_field',
217
+ label_visibility='collapsed',
218
+ placeholder='Enter your context paragraph here.',
219
+ height=300,
220
+ )
221
  # Question input field
222
  question = st.text_input(
223
  label='Question',
224
+ value=st.session_state['basic']['question'],
225
+ key='basic_question_field',
226
+ label_visibility='collapsed',
227
  placeholder='Enter your question here.',
228
  )
229
  # Form submit button
230
  query_submitted = st.form_submit_button("Submit")
231
+ if query_submitted and question!= '':
232
  # update question, context in session state
233
+ st.session_state['basic']['question'] = question
234
+ st.session_state['basic']['context'] = context
235
  with st.spinner('Generating response...'):
236
+ # Generate dictionary from inputs
237
+ query = {
238
+ 'context':st.session_state['basic']['context'],
239
+ 'question':st.session_state['basic']['question'],
240
+ }
241
+ # Pass to QA pipeline
242
+ response = qa_pipeline(**query)
243
+ answer = response['answer']
244
+ confidence = response['score']
245
+ # Reformat empty answer to message
246
+ if answer == '':
247
+ answer = "I don't have an answer based on the context provided."
248
+ # Update response in session state
249
+ st.session_state['basic']['response'] = f"""
250
+ Answer: {answer}\n
251
+ Confidence: {confidence:.2%}
252
  """
253
+ ### Populate response container ###
254
+ with basic_response_container:
255
+ st.write(st.session_state['basic']['response'])
256
+
257
+ #################################
258
+ ### Populate tab - guided Q&A ###
259
+ #################################
260
 
261
+ from lib.utils import (
262
+ semi_ex_query_click,
263
+ semi_ex_question_click,
264
+ semi_clear_query,
265
+ semi_clear_question,
266
+ )
267
+
268
+ ### Intro text ###
269
+ with semi_title_container:
270
+ st.header('User-guided Wiki Q&A')
271
+ st.markdown('''
272
+ This component allows you to perform a Wikipedia search for source material to feed as contexts to the RoBERTa question-answering model.
273
+ ''')
274
+ with st.expander("Click here to find out what's happening behind the scenes..."):
275
+ st.markdown('''
276
+ 1. A Wikipedia search is performed using your query, resulting in a list of pages which then populate the drop-down menu.
277
+ 2. The pages you select are retrieved and broken up into paragraphs. Wikipedia queries and page collection use the [wikipedia library](https://pypi.org/project/wikipedia/), a wrapper for the [MediaWiki API](https://www.mediawiki.org/wiki/API).
278
+ 3. The paragraphs are ranked in descending order of relevance to your question, using the [Okapi BM25 score](https://en.wikipedia.org/wiki/Okapi_BM25) as implemented in the [rank_bm25 library](https://github.com/dorianbrown/rank_bm25).
279
+ 4. Among these ranked paragraphs, approximately the top 25% are fed as context to the RoBERTa model, from which it will attempt to extract the answer to your question. The 'hit' having the highest confidence (prediction probability) from the model is reported as the answer.
280
+ ''')
281
+
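The BM25 ranking step described above can be sketched in isolation roughly as follows; the paragraph strings are illustrative and `rank_bm25` must be installed.

```python
# Rough sketch of the ranking step: score candidate paragraphs against the question.
from rank_bm25 import BM25Okapi

paragraphs = [
    "Pinkfong is a South Korean entertainment company.",
    "Twinkies were invented in 1930 by James Alexander Dewar.",
    "Baby Shark became very popular when Pinkfong released its version of the song in 2016.",
]
question = "What is Pinkfong's most famous song?"

# Whitespace tokenization, as in ContextRetriever.rank_paragraphs
bm25 = BM25Okapi([p.split(" ") for p in paragraphs])
# The highest-scoring paragraphs are the ones passed to the QA pipeline as contexts
print(bm25.get_top_n(question.split(" "), paragraphs, n=2))
```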
282
+ ### Populate query container ###
283
+ with semi_query_container:
284
+ st.markdown('First submit a search query, or choose one of the examples.')
285
+ semi_query_cols = st.columns(len(examples[0])+1)
286
+ # Buttons for query examples
287
+ for i in range(len(examples[0])):
288
+ with semi_query_cols[i]:
289
+ st.button(
290
+ label = f'query {i+1}',
291
+ key = f'semi_query_button_{i+1}',
292
+ on_click = semi_ex_query_click,
293
+ args=(examples,i,),
294
+ )
295
+ # Button for clearing the query field
296
+ with semi_query_cols[-1]:
297
+ st.button(
298
+ label = "Clear query",
299
+ key = "semi_clear_query",
300
+ on_click = semi_clear_query,
301
+ )
302
+ # Search query input form
303
+ with st.form(key='semi_query_form',clear_on_submit=False):
304
+ query = st.text_input(
305
+ label='Search query',
306
+ value=st.session_state['semi']['query'],
307
+ key='semi_query_field',
308
+ label_visibility='collapsed',
309
+ placeholder='Enter your Wikipedia search query here.',
310
+ )
311
+ query_submitted = st.form_submit_button("Submit")
312
+
313
+ if query_submitted and query != '':
314
+ st.session_state['semi']['query'] = query
315
+ # Retrieve Wikipedia page list from
316
+ # search results and store in session state
317
+ with st.spinner('Retrieving Wiki pages...'):
318
+ retriever = ContextRetriever()
319
+ retriever.get_pageids(query)
320
+ st.session_state['semi']['page_options'] = retriever.pageids
321
+ st.session_state['semi']['selected_pages'] = []
322
+
323
+ ### Populate page selection container ###
324
+ with semi_page_container:
325
+ st.markdown('Next select any number of Wikipedia pages to provide to RoBERTa:')
326
+ # Page title selection form
327
+ with st.form(key='semi_page_form',clear_on_submit=False):
328
+ selected_pages = st.multiselect(
329
+ label = "Choose Wiki pages for Q&A model:",
330
+ options = st.session_state['semi']['page_options'],
331
+ default = st.session_state['semi']['selected_pages'],
332
+ label_visibility = 'collapsed',
333
+ key = "semi_page_selectbox",
334
+ format_func = lambda x:x[1],
335
+ )
336
+ pages_submitted = st.form_submit_button("Submit")
337
+ if pages_submitted:
338
+ st.session_state['semi']['selected_pages'] = selected_pages
339
+
340
+ ### Populate question input container ###
341
+ with semi_input_container:
342
+ st.markdown('Finally submit a question for RoBERTa to answer based on the above pages, or choose one of the examples.')
343
+ # Question example buttons
344
+ semi_ques_cols = st.columns(len(examples[0])+1)
345
+ for i in range(len(examples[0])):
346
+ with semi_ques_cols[i]:
347
+ st.button(
348
+ label = f'question {i+1}',
349
+ key = f'semi_ques_button_{i+1}',
350
+ on_click = semi_ex_question_click,
351
+ args=(i,),
352
+ )
353
+ # Question field clear button
354
+ with semi_ques_cols[-1]:
355
+ st.button(
356
+ label = "Clear question",
357
+ key = "semi_clear_question",
358
+ on_click = semi_clear_question,
359
+ )
360
+ # Question submission form
361
+ with st.form(key = "semi_question_form",clear_on_submit=False):
362
+ question = st.text_input(
363
+ label='Question',
364
+ value=st.session_state['semi']['question'],
365
+ key='semi_question_field',
366
+ label_visibility='collapsed',
367
+ placeholder='Enter your question here.',
368
+ )
369
+ question_submitted = st.form_submit_button("Submit")
370
+ if question_submitted and len(question)>0 and len(st.session_state['semi']['selected_pages'])>0:
371
+ st.session_state['semi']['response'] = ''
372
+ st.session_state['semi']['question'] = question
373
+ # Retrieve pages corresponding to user selections,
374
+ # extract paragraphs, and retrieve top 10 paragraphs,
375
+ # ranked by relevance to user question
376
+ with st.spinner("Retrieving documentation..."):
377
+ retriever = ContextRetriever()
378
+ pages = retriever.ids_to_pages(selected_pages)
379
+ paragraphs = retriever.pages_to_paragraphs(pages)
380
+ best_paragraphs = retriever.rank_paragraphs(
381
+ paragraphs, question,
382
+ )
383
+ with st.spinner("Generating response..."):
384
+ # Generate a response and update the session state
385
+ response = generate_answer(
386
+ pipeline = qa_pipeline,
387
+ paragraphs = best_paragraphs,
388
+ question = st.session_state['semi']['question'],
389
+ )
390
+ st.session_state['semi']['response'] = response
391
+
392
+ ### Populate response container ###
393
+ with semi_response_container:
394
+ st.write(st.session_state['semi']['response'])
395
+
396
+ ####################################
397
+ ### Populate tab - automated Q&A ###
398
+ ####################################
399
+
400
+ from lib.utils import auto_ex_click, auto_clear_boxes
401
+
402
+ ### Intro text ###
403
+ with auto_title_container:
404
+ st.header('Automated Wiki Q&A')
405
+ st.markdown('''
406
+ This component attempts to automate the Wiki-assisted extractive question-answering task. A Wikipedia search will be performed based on your question, and a list of relevant paragraphs will be passed to the RoBERTa model so it can attempt to find an answer.
407
+ ''')
408
+ with st.expander("Click here to find out what's happening behind the scenes..."):
409
+ st.markdown('''
410
+ When you submit a question, the following steps are performed:
411
+ 1. Your question is condensed into a search query which just retains nouns, verbs, numerals, and adjectives, where part-of-speech tagging is done using the [en_core_web_sm](https://spacy.io/models/en#en_core_web_sm) pipeline in the [spaCy library](https://spacy.io/).
412
+ 2. A Wikipedia search is performed using this query, resulting in several articles. The articles from the top 3 search results are collected and split into paragraphs. Wikipedia queries and article collection use the [wikipedia library](https://pypi.org/project/wikipedia/), a wrapper for the [MediaWiki API](https://www.mediawiki.org/wiki/API).
413
+ 3. The paragraphs are ranked in descending order of relevance to the query, using the [Okapi BM25 score](https://en.wikipedia.org/wiki/Okapi_BM25) as implemented in the [rank_bm25 library](https://github.com/dorianbrown/rank_bm25).
414
+ 4. The ten most relevant paragraphs are fed as context to the RoBERTa model, from which it will attempt to extract the answer to your question. The 'hit' having the highest confidence (prediction probability) from the model is reported as the answer.
415
+ ''')
416
+
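Step 1 above, condensing the question into a search query, mirrors `generate_query` in `lib/utils.py` and can be sketched on its own as follows (assumes the `en_core_web_sm` model has been downloaded).

```python
# Sketch of the query-condensation step, mirroring generate_query in lib/utils.py.
# Assumes: python -m spacy download en_core_web_sm
import spacy

nlp = spacy.load("en_core_web_sm", disable=["ner", "parser", "textcat"])
keep = {"PROPN", "NUM", "VERB", "NOUN", "ADJ"}

question = "What did Robert Oppenheimer remark after the Trinity test?"
doc = nlp(question)
# Keep only nouns, proper nouns, numerals, verbs, and adjectives
query = " ".join(token.text for token in doc if token.pos_ in keep)
print(query)  # e.g. "Robert Oppenheimer remark Trinity test"
```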
417
+ st.markdown('''
418
+ Please provide a question you'd like the model to try to answer. The model will report back its answer, as well as an excerpt of text from Wikipedia in which it found its answer. Your result will appear below the question field when the model is finished running.
419
+
420
+ Alternatively, you can try an example by clicking one of the buttons below:
421
+ ''')
422
+
423
+ ### Populate example container ###
424
+ with auto_example_container:
425
+ auto_ex_cols = st.columns(len(examples[0])+1)
426
+ # Buttons for selecting example questions
427
+ for i in range(len(examples[0])):
428
+ with auto_ex_cols[i]:
429
+ st.button(
430
+ label = f'example {i+1}',
431
+ key = f'auto_ex_button_{i+1}',
432
+ on_click = auto_ex_click,
433
+ args=(examples,i,),
434
+ )
435
+ # Button for clearing question field and response
436
+ with auto_ex_cols[-1]:
437
+ st.button(
438
  label = "Clear all fields",
439
+ key = "auto_clear_button",
440
+ on_click = auto_clear_boxes,
441
  )
442
 
443
+ ### Populate user input container ###
444
+ with auto_input_container:
445
+ with st.form(key='auto_input_form',clear_on_submit=False):
446
+ # Question input field
447
+ question = st.text_input(
448
+ label='Question',
449
+ value=st.session_state['auto']['question'],
450
+ key='auto_question_field',
451
+ label_visibility='collapsed',
452
+ placeholder='Enter your question here.',
453
+ )
454
+ # Form submit button
455
+ question_submitted = st.form_submit_button("Submit")
456
+ if question_submitted:
457
+ # update question, context in session state
458
+ st.session_state['auto']['question'] = question
459
+ query = generate_query(nlp,question)
460
+ # query == '' will throw error in document retrieval
461
+ if len(query)==0:
462
+ st.session_state['auto']['response'] = 'Please include some nouns, verbs, and/or adjectives in your question.'
463
+ elif len(question)>0:
464
+ with st.spinner('Retrieving documentation...'):
465
+ # Retrieve ids from top 3 results
466
+ retriever = ContextRetriever()
467
+ retriever.get_pageids(query,topn=3)
468
+ # Retrieve pages then paragraphs
469
+ retriever.get_all_pages()
470
+ retriever.get_all_paragraphs()
471
+ # Get top 10 paragraphs, ranked by relevance to query
472
+ best_paragraphs = retriever.rank_paragraphs(retriever.paragraphs, query)
473
+ with st.spinner('Generating response...'):
474
+ # Generate a response and update the session state
475
+ response = generate_answer(
476
+ pipeline = qa_pipeline,
477
+ paragraphs = best_paragraphs,
478
+ question = st.session_state['auto']['question'],
479
+ )
480
+ st.session_state['auto']['response'] = response
481
+ ### Populate response container ###
482
+ with auto_response_container:
483
+ st.write(st.session_state['auto']['response'])
examples.csv ADDED
@@ -0,0 +1,4 @@
1
+ question,context,query
2
+ What did Robert Oppenheimer remark after the Trinity test?:When was Robert Oppenheimer born?:What did Robert Oppenheimer do after World War II?,"Oppenheimer attended Harvard University, where he earned a bachelor's degree in chemistry in 1925. He studied physics at the University of Cambridge and University of Göttingen, where he received his PhD in 1927. He held academic positions at the University of California, Berkeley, and the California Institute of Technology, and made significant contributions to theoretical physics, including in quantum mechanics and nuclear physics. During World War II, he was recruited to work on the Manhattan Project, and in 1943 was appointed as director of the Los Alamos Laboratory in New Mexico, tasked with developing the weapons. Oppenheimer's leadership and scientific expertise were instrumental in the success of the project. He was among those who observed the Trinity test on July 16, 1945, in which the first atomic bomb was successfully detonated. He later remarked that the explosion brought to his mind words from the Hindu scripture Bhagavad Gita: ""Now I am become Death, the destroyer of worlds."" In August 1945, the atomic bombs were used on the Japanese cities of Hiroshima and Nagasaki, the only use of nuclear weapons in war.",Oppenheimer
3
+ Why did Twinkies change to vanilla cream?:When were Twinkies invented?:What is the urban legend about Twinkies?,"Twinkies were invented on April 6, 1930, by Canadian-born baker James Alexander Dewar for the Continental Baking Company in Schiller Park, Illinois. Realizing that several machines used for making cream-filled strawberry shortcake sat idle when strawberries were out of season, Dewar conceived a snack cake filled with banana cream, which he dubbed the Twinkie. Ritchy Koph said he came up with the name when he saw a billboard in St. Louis for ""Twinkle Toe Shoes"". During World War II, bananas were rationed, and the company was forced to switch to vanilla cream. This change proved popular, and banana-cream Twinkies were not widely re-introduced. The original flavor was occasionally found in limited time only promotions, but the company used vanilla cream for most Twinkies. In 1988, Fruit and Cream Twinkies were introduced with a strawberry filling swirled into the cream. The product was soon dropped. Vanilla's dominance over banana flavoring was challenged in 2005, following a month-long promotion of the movie King Kong. Hostess saw its Twinkie sales rise 20 percent during the promotion, and in 2007 restored the banana-cream Twinkie to its snack lineup although they are now made with 2% banana purée.",Twinkie
4
+ When was Pinkfong founded?:What is Pinkfong's most famous song?:Who represents the company?,"""Baby Shark"" is a children's song associated with a dance involving hand movements that originated as a campfire song dating back to at least the 20th century. In 2016, ""Baby Shark"" became very popular when Pinkfong, a South Korean entertainment company, released a version of the song with a YouTube music video that went viral across social media, online video, and radio. In January 2022, it became the first YouTube video to reach 10 billion views. In November 2020, Pinkfong's version became the most-viewed YouTube video of all time, with over 12 billion views as of April 2023. ""Baby Shark"" originated as a campfire song or chant. The original song dates back to at least the 20th century, potentially created by camp counselors inspired by the movie Jaws. In the chant, each member of a family of sharks is introduced, with campers using their hands to imitate the sharks' jaws. Different versions of the song have the sharks hunting fish, eating a sailor, or killing people who then go to heaven. Various entities have copyrighted original videos and sound recordings of the song, and some have trademarked merchandise based on their versions. However, according to The New York Times, the underlying song and characters are believed to be in the public domain.",Pinkfong
lib/.DS_Store CHANGED
Binary files a/lib/.DS_Store and b/lib/.DS_Store differ
 
lib/.ipynb_checkpoints/utils-checkpoint.py CHANGED
@@ -1,42 +1,119 @@
1
- import requests, wikipedia, re, spacy
2
  from rank_bm25 import BM25Okapi
3
- import torch
4
- from datasets import Dataset
5
- from torch.utils.data import DataLoader
6
- from transformers import (
7
- AutoTokenizer,
8
- AutoModelForQuestionAnswering,
9
- pipeline,
10
- )
11
-
12
- class QueryProcessor:
13
- """
14
- Processes text into queries using a spaCy model
15
- """
16
- def __init__(self):
17
- self.keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
18
- self.nlp = spacy.load(
19
- 'en_core_web_sm',
20
- disable = ['ner','parser','textcat']
21
- )
22
-
23
- def generate_query(self,text):
24
- """
25
- Process text into a search query,
26
- only retaining nouns, proper nouns numerals, verbs, and adjectives
27
- Parameters:
28
- -----------
29
- text : str
30
- The input text to be processed
31
- Returns:
32
- --------
33
- query : str
34
- The condensed search query
35
- """
36
- tokens = self.nlp(text)
37
- query = ' '.join(token.text for token in tokens \
38
- if token.pos_ in self.keep)
39
- return query
40
 
41
  class ContextRetriever:
42
  """
@@ -49,17 +126,21 @@ class ContextRetriever:
49
  self.pages = None
50
  self.paragraphs = None
51
 
52
- def get_pageids(self,query):
53
  """
54
  Retrieve page ids corresponding to a search query
55
  Parameters:
56
  -----------
57
  query : str
58
  A query to use for Wikipedia page search
59
  Returns: None, but stores:
60
  --------
61
- self.pageids : list(int)
62
- A list of Wikipedia page ids corresponding to search results
 
63
  """
64
  params = {
65
  'action':'query',
@@ -68,39 +149,62 @@ class ContextRetriever:
68
  'format':'json',
69
  }
70
  results = requests.get(self.url, params=params).json()
71
- pageids = [page['pageid'] for page in results['query']['search']]
 
72
  self.pageids = pageids
73
-
74
- def get_pages(self):
75
  """
76
  Use MediaWiki API to retrieve page content corresponding to
77
- entries of self.pageids
78
- Parameters: None
79
  -----------
 
 
80
  Returns: None, but stores
81
  --------
82
- self.pages : list(str)
83
- Entries are content of pages corresponding to entries of self.pageid
 
84
  """
85
- assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
86
- self.pages = []
87
- for pageid in self.pageids:
88
  try:
89
- self.pages.append(wikipedia.page(pageid=pageid,auto_suggest=False).content)
90
- except wikipedia.DisambiguationError as e:
 
91
  continue
92
-
93
- def get_paragraphs(self):
 
94
  """
95
- Process self.pages into list of paragraphs from pages
 
96
  Parameters: None
97
  -----------
98
  Returns: None, but stores
99
  --------
100
- self.paragraphs : list(str)
101
102
  """
103
- assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
104
  # Content from WikiMedia has these headings. We only grab content appearing
105
  # before the first instance of any of these
106
  pattern = '|'.join([
@@ -117,38 +221,144 @@ class ContextRetriever:
117
  '=== Citations ===',
118
  ])
119
  pattern = re.compile(pattern)
120
- paragraphs = []
121
- for page in self.pages:
122
  # Truncate page to the first index of the start of a matching heading,
123
  # or the end of the page if no matches exist
124
- idx = min([match.start() for match in pattern.finditer(page)]+[len(page)])
125
- page = page[:idx]
 
126
  # Split into paragraphs, omitting lines with headings (start with '='),
127
  # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
128
- paragraphs += [
129
- p for p in page.split('\n') if p \
130
  and not p.startswith('=') \
131
- and not p.startswith('\t\t')
 
132
  ]
133
- self.paragraphs = paragraphs
134
 
135
- def rank_paragraphs(self,query,topn=10):
136
  """
137
- Ranks the elements of self.paragraphs in descending order
138
- by relevance to query using BM25F, and returns top topn results
 
139
  Parameters:
140
  -----------
141
  query : str
142
  The query to use in ranking paragraphs by relevance
143
- topn : int
144
  The number of most relevant paragraphs to return
145
- Returns: None, but stores
 
 
146
  --------
147
- self.best_paragraphs : list(str)
148
- The topn most relevant paragraphs to the query
 
 
149
  """
150
- tokenized_paragraphs = [p.split(" ") for p in self.paragraphs]
151
- bm25 = BM25Okapi(tokenized_paragraphs)
152
  tokenized_query = query.split(" ")
153
- self.best_paragraphs = bm25.get_top_n(tokenized_query,self.paragraphs,n=topn)
154
 
1
+ import requests, wikipedia, re
2
  from rank_bm25 import BM25Okapi
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import spacy
6
+
7
+ ####################################
8
+ ## Streamlit app helper functions ##
9
+ ####################################
10
+
11
+ def get_examples():
12
+ """
13
+ Function for loading example questions
14
+ and contexts from examples.csv
15
+ Parameters: None
16
+ -----------
17
+ Returns:
18
+ --------
19
+ ex_queries, ex_questions, ex_contexts : list(str), list(list(str)), list(str)
20
+ Example search query, question, and context strings
21
+ (each entry of ex_questions is a list of three question strings)
22
+ """
23
+ examples = pd.read_csv('examples.csv')
24
+ ex_questions = [q.split(':') for q in list(examples['question'])]
25
+ ex_contexts = list(examples['context'])
26
+ ex_queries = list(examples['query'])
27
+ return ex_queries, ex_questions, ex_contexts
28
+
29
+ def basic_clear_boxes():
30
+ """
31
+ Clears the question, context, response
32
+ """
33
+ for field in ['question','context','response']:
34
+ st.session_state['basic'][field] = ''
35
+
36
+ def basic_ex_click(examples, i):
37
+ """
38
+ Fills in the chosen example
39
+ """
40
+ st.session_state['basic']['question'] = examples[1][i][0]
41
+ st.session_state['basic']['context'] = examples[2][i]
42
+
43
+ def semi_clear_query():
44
+ """
45
+ Clears the search query field
46
+ and page options list
47
+ """
48
+ st.session_state['semi']['query'] = ''
49
+ for field in ['selected_pages','page_options']:
50
+ st.session_state['semi'][field] = []
51
+
52
+ def semi_clear_question():
53
+ """
54
+ Clears the question and response field
55
+ and selected pages list
56
+ """
57
+ for field in ['question','response']:
58
+ st.session_state['semi'][field] = ''
59
+
60
+ def semi_ex_query_click(examples,i):
61
+ """
62
+ Fills in the query example and
63
+ populates question examples when query
64
+ example button is clicked
65
+ """
66
+ st.session_state['semi']['query'] = examples[0][i]
67
+ st.session_state['semi']['ex_questions'] = examples[1][i]
68
+
69
+ def semi_ex_question_click(i):
70
+ """
71
+ Fills in the question example
72
+ """
73
+ st.session_state['semi']['question'] = st.session_state['semi']['ex_questions'][i]
74
+
75
+ def auto_clear_boxes():
76
+ """
77
+ Clears the response and question fields
78
+ """
79
+ for field in ['question','response']:
80
+ st.session_state['auto'][field]=''
81
+
82
+ def auto_ex_click(examples,i):
83
+ """
84
+ Fills in the chosen example question
85
+ """
86
+ st.session_state['auto']['question'] = examples[1][i][0]
87
+
88
+ ###########################
89
+ ## Query helper function ##
90
+ ###########################
91
+
92
+ def generate_query(nlp,text):
93
+ """
94
+ Process text into a search query,
95
+ only retaining nouns, proper nouns,
96
+ numerals, verbs, and adjectives
97
+ Parameters:
98
+ -----------
99
+ nlp : spacy.Pipe
100
+ spaCy pipeline for processing search query
101
+ text : str
102
+ The input text to be processed
103
+ Returns:
104
+ --------
105
+ query : str
106
+ The condensed search query
107
+ """
108
+ tokens = nlp(text)
109
+ keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
110
+ query = ' '.join(token.text for token in tokens \
111
+ if token.pos_ in keep)
112
+ return query
113
+
114
+ ##############################
115
+ ## Document retriever class ##
116
+ ##############################
117
 
118
  class ContextRetriever:
119
  """
 
126
  self.pages = None
127
  self.paragraphs = None
128
 
129
+ def get_pageids(self,query,topn = None):
130
  """
131
  Retrieve page ids corresponding to a search query
132
  Parameters:
133
  -----------
134
  query : str
135
  A query to use for Wikipedia page search
136
+ topn : int or None
137
+ If topn is provided, will only return pageids
138
+ for topn search results
139
  Returns: None, but stores:
140
  --------
141
+ self.pageids : list(tuple(int,str))
142
+ A list of Wikipedia (pageid,title) tuples resulting
143
+ from the search
144
  """
145
  params = {
146
  'action':'query',
 
149
  'format':'json',
150
  }
151
  results = requests.get(self.url, params=params).json()
152
+ pageids = [(page['pageid'],page['title']) for page in results['query']['search']]
153
+ pageids = pageids[:topn]
154
  self.pageids = pageids
155
+
156
+ def ids_to_pages(self,ids):
157
  """
158
  Use MediaWiki API to retrieve page content corresponding to
159
+ a list of pageids
160
+ Parameters:
161
  -----------
162
+ ids : list(tuple(int,str))
163
+ A list of Wikipedia (pageid,title) tuples
164
  Returns: None, but stores
165
  --------
166
+ pages : list(tuple(str,str))
167
+ The k-th entry is a tuple consisting of the title and page content
168
+ of the page corresponding to the k-th entry of ids
169
  """
170
+ pages = []
171
+ for pageid in ids:
 
172
  try:
173
+ page = wikipedia.page(pageid=pageid[0],auto_suggest=False)
174
+ pages.append((page.title, page.content))
175
+ except wikipedia.DisambiguationError:
176
  continue
177
+ return pages
178
+
179
+ def get_all_pages(self):
180
  """
181
+ Use MediaWiki API to retrieve page content corresponding to
182
+ the list of pageids in self.pageids
183
  Parameters: None
184
  -----------
185
  Returns: None, but stores
186
  --------
187
+ self.pages : list(tuple(str,str))
188
+ The k-th entry is a tuple consisting of the title and page content
189
+ of the page corresponding to the k-th entry of self.pageids
190
+ """
191
+ assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
192
+ self.pages = self.ids_to_pages(self.pageids)
193
+
194
+ def pages_to_paragraphs(self,pages):
195
+ """
196
+ Process a list of pages into a list of paragraphs from those pages
197
+ Parameters:
198
+ -----------
199
+ pages : list(tuple(str,str))
200
+ A list of (title, page content) tuples for the Wikipedia pages
201
+ Returns:
202
+ --------
203
+ paragraphs : dict
204
+ keys are titles of pages from pages (as strings)
205
+ paragraphs[page] is a list of paragraphs (as strings)
206
+ extracted from page
207
  """
 
208
  # Content from WikiMedia has these headings. We only grab content appearing
209
  # before the first instance of any of these
210
  pattern = '|'.join([
 
221
  '=== Citations ===',
222
  ])
223
  pattern = re.compile(pattern)
224
+ paragraphs = {}
225
+ for page in pages:
226
  # Truncate page to the first index of the start of a matching heading,
227
  # or the end of the page if no matches exist
228
+ title, content = page
229
+ idx = min([match.start() for match in pattern.finditer(content)]+[len(content)])
230
+ content = content[:idx]
231
  # Split into paragraphs, omitting lines with headings (start with '='),
232
  # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
233
+ paragraphs[title] = [
234
+ p for p in content.split('\n') if p \
235
  and not p.startswith('=') \
236
+ and not p.startswith('\t\t') \
237
+ and not p.startswith(' ')
238
  ]
239
+ return paragraphs
240
+
241
+ def get_all_paragraphs(self):
242
+ """
243
+ Process self.pages into list of paragraphs from pages
244
+ Parameters: None
245
+ -----------
246
+ Returns: None, but stores
247
+ --------
248
+ self.paragraphs : dict
249
+ keys are titles of pages from self.pages (as strings)
250
+ self.paragraphs[page] is a list of paragraphs (as strings)
251
+ extracted from page
252
+ """
253
+ assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
254
+ # Content from WikiMedia has these headings. We only grab content appearing
255
+ # before the first instance of any of these
256
+ self.paragraphs = self.pages_to_paragraphs(self.pages)
257
 
258
+ def rank_paragraphs(self,paragraphs,query,topn=10):
259
  """
260
+ Ranks the elements of paragraphs in descending order
261
+ by relevance to query using BM25 Okapi, and returns top
262
+ topn results
263
  Parameters:
264
  -----------
265
+ paragraphs : dict
266
+ keys are titles of pages (as strings)
267
+ paragraphs[page] is a list of paragraphs (as strings)
268
+ extracted from page
269
  query : str
270
  The query to use in ranking paragraphs by relevance
271
+ topn : int or None
272
  The number of most relevant paragraphs to return
273
+ If None, will return roughly the top 1/4 of the
274
+ paragraphs
275
+ Returns:
276
  --------
277
+ best_paragraphs : list(list(str,str))
278
+ The k-th entry is a list [title,paragraph] for the k-th
279
+ most relevant paragraph, where title is the title of the
280
+ Wikipedia article from which that paragraph was sourced
281
  """
282
+ corpus, titles, page_nums = [],[],[]
283
+ # Compile paragraphs into corpus
284
+ for i,page in enumerate(paragraphs):
285
+ titles.append(page)
286
+ paras = paragraphs[page]
287
+ corpus += paras
288
+ page_nums += len(paras)*[i]
289
+
290
+ # Tokenize corpus and query and initialize bm25 object
291
+ tokenized_corpus = [p.split(" ") for p in corpus]
292
+ bm25 = BM25Okapi(tokenized_corpus)
293
  tokenized_query = query.split(" ")
 
294
 
295
+ # Compute scores and compile tuples (paragraph number, score, page number)
296
+ # before sorting tuples by score
297
+ bm_scores = bm25.get_scores(tokenized_query)
298
+ paragraph_data = [[i,score,page_nums[i]] for i,score in enumerate(bm_scores)]
299
+ paragraph_data.sort(reverse=True,key=lambda p:p[1])
300
+
301
+ # Grab topn best [title,paragraph] pairs sorted by bm25 score
302
+ topn = len(paragraph_data)//4+1 if topn is None else min(topn,len(paragraph_data))
303
+
304
+ best_paragraphs = [[titles[p[2]],corpus[p[0]]] for p in paragraph_data[:topn]]
305
+ return best_paragraphs
306
+
307
+ def generate_answer(pipeline,paragraphs, question):
308
+ """
309
+ Generate an answer using a question-answer pipeline
310
+ Parameters:
311
+ -----------
312
+ pipeline : transformers.QuestionAnsweringPipeline
313
+ The question answering pipeline object
314
+ paragraphs : list(list(str,str))
315
+ The k-th entry is a list [title,paragraph] consisting
316
+ of a context paragraph and the title of the page from which the
317
+ paragraph was sourced
318
+ question : str
319
+ A question that is to be answered based on context given
320
+ in the entries of paragraphs
321
+
322
+ Returns:
323
+ --------
324
+ response : str
325
+ A response indicating the answer that was discovered,
326
+ or indicating that no answer could be found.
327
+ """
328
+ # For each paragraph, format input to QA pipeline...
329
+ for paragraph in paragraphs:
330
+ input = {
331
+ 'context':paragraph[1],
332
+ 'question':question,
333
+ }
334
+ # ...and pass to QA pipeline
335
+ output = pipeline(**input)
336
+ # Append answers and scores. Report score of
337
+ # zero for paragraphs without answer, so they are
338
+ # deprioritized when the max is taken below
339
+ if output['answer']!='':
340
+ paragraph += [output['answer'],output['score']]
341
+ else:
342
+ paragraph += ['',0]
343
+ # Get paragraph with max confidence score and collect data
344
+ best_paragraph = max(paragraphs,key = lambda x:x[3])
345
+ best_answer = best_paragraph[2]
346
+ best_context_page = best_paragraph[0]
347
+ best_context = best_paragraph[1]
348
+
349
+ # Update response in session state
350
+ if best_answer == "":
351
+ response = "I cannot find the answer to your question."
352
+ else:
353
+ response = f"""
354
+ My answer is: {best_answer}
355
+
356
+ ...and here's where I found it:
357
+
358
+ Page title: {best_context_page}
359
+
360
+ Paragraph containing answer:
361
+
362
+ {best_context}
363
+ """
364
+ return response
lib/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/lib/__pycache__/utils.cpython-310.pyc and b/lib/__pycache__/utils.cpython-310.pyc differ
 
lib/utils.py CHANGED
@@ -1,13 +1,119 @@
1
  import requests, wikipedia, re
2
  from rank_bm25 import BM25Okapi
3
- import torch
4
- from datasets import Dataset
5
- from torch.utils.data import DataLoader
6
- from transformers import (
7
- AutoTokenizer,
8
- AutoModelForQuestionAnswering,
9
- pipeline,
10
- )
11
 
12
  class ContextRetriever:
13
  """
@@ -20,17 +126,21 @@ class ContextRetriever:
20
  self.pages = None
21
  self.paragraphs = None
22
 
23
- def get_pageids(self,query):
24
  """
25
  Retrieve page ids corresponding to a search query
26
  Parameters:
27
  -----------
28
  query : str
29
  A query to use for Wikipedia page search
30
  Returns: None, but stores:
31
  --------
32
- self.pageids : list(int)
33
- A list of Wikipedia page ids corresponding to search results
 
34
  """
35
  params = {
36
  'action':'query',
@@ -39,39 +149,62 @@ class ContextRetriever:
39
  'format':'json',
40
  }
41
  results = requests.get(self.url, params=params).json()
42
- pageids = [page['pageid'] for page in results['query']['search']]
 
43
  self.pageids = pageids
44
-
45
- def get_pages(self):
46
  """
47
  Use MediaWiki API to retrieve page content corresponding to
48
- entries of self.pageids
49
- Parameters: None
50
  -----------
 
 
51
  Returns: None, but stores
52
  --------
53
- self.pages : list(str)
54
- Entries are content of pages corresponding to entries of self.pageid
 
55
  """
56
- assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
57
- self.pages = []
58
- for pageid in self.pageids:
59
  try:
60
- self.pages.append(wikipedia.page(pageid=pageid,auto_suggest=False).content)
61
- except wikipedia.DisambiguationError as e:
 
62
  continue
63
-
64
- def get_paragraphs(self):
 
65
  """
66
- Process self.pages into list of paragraphs from pages
 
67
  Parameters: None
68
  -----------
69
  Returns: None, but stores
70
  --------
71
- self.paragraphs : list(str)
72
- List of paragraphs from all pages in self.pages, in order of self.pages
73
  """
74
- assert self.pages is not None, "No page content exists. Get pages first using self.get_pages"
75
  # Content from WikiMedia has these headings. We only grab content appearing
76
  # before the first instance of any of these
77
  pattern = '|'.join([
@@ -88,38 +221,144 @@ class ContextRetriever:
88
  '=== Citations ===',
89
  ])
90
  pattern = re.compile(pattern)
91
- paragraphs = []
92
- for page in self.pages:
93
  # Truncate page to the first index of the start of a matching heading,
94
  # or the end of the page if no matches exist
95
- idx = min([match.start() for match in pattern.finditer(page)]+[len(page)])
96
- page = page[:idx]
 
97
  # Split into paragraphs, omitting lines with headings (start with '='),
98
  # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
99
- paragraphs += [
100
- p for p in page.split('\n') if p \
101
  and not p.startswith('=') \
102
- and not p.startswith('\t\t')
 
103
  ]
104
- self.paragraphs = paragraphs
 
105
 
106
- def rank_paragraphs(self,query,topn=10):
107
  """
108
- Ranks the elements of self.paragraphs in descending order
109
- by relevance to query using BM25F, and returns top topn results
 
110
  Parameters:
111
  -----------
112
  query : str
113
  The query to use in ranking paragraphs by relevance
114
- topn : int
115
  The number of most relevant paragraphs to return
116
- Returns: None, but stores
 
 
117
  --------
118
- self.best_paragraphs : list(str)
119
- The topn most relevant paragraphs to the query
 
 
120
  """
121
- tokenized_paragraphs = [p.split(" ") for p in self.paragraphs]
122
- bm25 = BM25Okapi(tokenized_paragraphs)
123
  tokenized_query = query.split(" ")
124
- self.best_paragraphs = bm25.get_top_n(tokenized_query,self.paragraphs,n=topn)
125
1
  import requests, wikipedia, re
2
  from rank_bm25 import BM25Okapi
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import spacy
6
+
7
+ ####################################
8
+ ## Streamlit app helper functions ##
9
+ ####################################
10
+
11
+ def get_examples():
12
+ """
13
+ Function for loading example questions
14
+ and contexts from examples.csv
15
+ Parameters: None
16
+ -----------
17
+ Returns:
18
+ --------
19
+ ex_queries, ex_questions, ex_contexts : list(str), list(list(str)), list(str)
20
+ Example search query, question, and context strings
21
+ (each entry of ex_questions is a list of three question strings)
22
+ """
23
+ examples = pd.read_csv('examples.csv')
24
+ ex_questions = [q.split(':') for q in list(examples['question'])]
25
+ ex_contexts = list(examples['context'])
26
+ ex_queries = list(examples['query'])
27
+ return ex_queries, ex_questions, ex_contexts
28
+
29
+ def basic_clear_boxes():
30
+ """
31
+ Clears the question, context, response
32
+ """
33
+ for field in ['question','context','response']:
34
+ st.session_state['basic'][field] = ''
35
+
36
+ def basic_ex_click(examples, i):
37
+ """
38
+ Fills in the chosen example
39
+ """
40
+ st.session_state['basic']['question'] = examples[1][i][0]
41
+ st.session_state['basic']['context'] = examples[2][i]
42
+
43
+ def semi_clear_query():
44
+ """
45
+ Clears the search query field
46
+ and page options list
47
+ """
48
+ st.session_state['semi']['query'] = ''
49
+ for field in ['selected_pages','page_options']:
50
+ st.session_state['semi'][field] = []
51
+
52
+ def semi_clear_question():
53
+ """
54
+ Clears the question and response field
55
+ and selected pages list
56
+ """
57
+ for field in ['question','response']:
58
+ st.session_state['semi'][field] = ''
59
+
60
+ def semi_ex_query_click(examples,i):
61
+ """
62
+ Fills in the query example and
63
+ populates question examples when query
64
+ example button is clicked
65
+ """
66
+ st.session_state['semi']['query'] = examples[0][i]
67
+ st.session_state['semi']['ex_questions'] = examples[1][i]
68
+
69
+ def semi_ex_question_click(i):
70
+ """
71
+ Fills in the question example
72
+ """
73
+ st.session_state['semi']['question'] = st.session_state['semi']['ex_questions'][i]
74
+
75
+ def auto_clear_boxes():
76
+ """
77
+ Clears the response and question fields
78
+ """
79
+ for field in ['question','response']:
80
+ st.session_state['auto'][field]=''
81
+
82
+ def auto_ex_click(examples,i):
83
+ """
84
+ Fills in the chosen example question
85
+ """
86
+ st.session_state['auto']['question'] = examples[1][i][0]
87
+
88
+ ###########################
89
+ ## Query helper function ##
90
+ ###########################
91
+
92
+ def generate_query(nlp,text):
93
+ """
94
+ Process text into a search query,
95
+ only retaining nouns, proper nouns,
96
+ numerals, verbs, and adjectives
97
+ Parameters:
98
+ -----------
99
+ nlp : spacy.Pipe
100
+ spaCy pipeline for processing search query
101
+ text : str
102
+ The input text to be processed
103
+ Returns:
104
+ --------
105
+ query : str
106
+ The condensed search query
107
+ """
108
+ tokens = nlp(text)
109
+ keep = {'PROPN', 'NUM', 'VERB', 'NOUN', 'ADJ'}
110
+ query = ' '.join(token.text for token in tokens \
111
+ if token.pos_ in keep)
112
+ return query
113
+
114
+ ##############################
115
+ ## Document retriever class ##
116
+ ##############################
117
 
118
  class ContextRetriever:
119
  """
 
126
  self.pages = None
127
  self.paragraphs = None
128
 
129
+ def get_pageids(self,query,topn = None):
130
  """
131
  Retrieve page ids corresponding to a search query
132
  Parameters:
133
  -----------
134
  query : str
135
  A query to use for Wikipedia page search
136
+ topn : int or None
137
+ If topn is provided, will only return pageids
138
+ for topn search results
139
  Returns: None, but stores:
140
  --------
141
+ self.pageids : list(tuple(int,str))
142
+ A list of Wikipedia (pageid,title) tuples resulting
143
+ from the search
144
  """
145
  params = {
146
  'action':'query',
 
149
  'format':'json',
150
  }
151
  results = requests.get(self.url, params=params).json()
152
+ pageids = [(page['pageid'],page['title']) for page in results['query']['search']]
153
+ pageids = pageids[:topn]
154
  self.pageids = pageids
155
+
156
+ def ids_to_pages(self,ids):
157
  """
158
  Use MediaWiki API to retrieve page content corresponding to
159
+ a list of pageids
160
+ Parameters:
161
  -----------
162
+ ids : list(tuple(int,str))
163
+ A list of Wikipedia (pageid,title) tuples
164
+ Returns:
165
  --------
166
+ pages : list(tuple(str,str))
167
+ The k-th entry is a tuple consisting of the title and page content
168
+ of the page corresponding to the k-th entry of ids
169
  """
170
+ pages = []
171
+ for pageid in ids:
 
172
  try:
173
+ page = wikipedia.page(pageid=pageid[0],auto_suggest=False)
174
+ pages.append((page.title, page.content))
175
+ except wikipedia.DisambiguationError:
176
  continue
177
+ return pages
178
+
179
+ def get_all_pages(self):
180
  """
181
+ Use MediaWiki API to retrieve page content corresponding to
182
+ the list of pageids in self.pageids
183
  Parameters: None
184
  -----------
185
  Returns: None, but stores
186
  --------
187
+ self.pages : list(tuple(str,str))
188
+ The k-th entry is a tuple consisting of the title and page content
189
+ of the page corresponding to the k-th entry of self.pageids
190
+ """
191
+ assert self.pageids is not None, "No pageids exist. Get pageids first using self.get_pageids"
192
+ self.pages = self.ids_to_pages(self.pageids)
193
+
194
+ def pages_to_paragraphs(self,pages):
195
+ """
196
+ Process a list of pages into a list of paragraphs from those pages
197
+ Parameters:
198
+ -----------
199
+ pages : list(tuple(str,str))
200
+ A list of (title, page content) tuples, as returned by ids_to_pages
201
+ Returns:
202
+ --------
203
+ paragraphs : dict
204
+ keys are titles of pages from pages (as strings)
205
+ paragraphs[page] is a list of paragraphs (as strings)
206
+ extracted from page
207
  """
 
208
  # Content from WikiMedia has these headings. We only grab content appearing
209
  # before the first instance of any of these
210
  pattern = '|'.join([
 
221
  '=== Citations ===',
222
  ])
223
  pattern = re.compile(pattern)
224
+ paragraphs = {}
225
+ for page in pages:
226
  # Truncate page to the first index of the start of a matching heading,
227
  # or the end of the page if no matches exist
228
+ title, content = page
229
+ idx = min([match.start() for match in pattern.finditer(content)]+[len(content)])
230
+ content = content[:idx]
231
  # Split into paragraphs, omitting lines with headings (start with '='),
232
  # empty lines, or lines like '\t\t' or '\t\t\t' which sometimes appear
233
+ paragraphs[title] = [
234
+ p for p in content.split('\n') if p \
235
  and not p.startswith('=') \
236
+ and not p.startswith('\t\t') \
237
+ and not p.startswith(' ')
238
  ]
239
+ return paragraphs
240
+
241
+ def get_all_paragraphs(self):
242
+ """
243
+ Process self.pages into list of paragraphs from pages
244
+ Parameters: None
245
+ -----------
246
+ Returns: None, but stores
247
+ --------
248
+ self.paragraphs : dict
249
+ keys are titles of pages from self.pages (as strings)
250
+ self.paragraphs[page] is a list of paragraphs (as strings)
251
+ extracted from page
252
+ """
253
+ assert self.pages is not None, "No page content exists. Get pages first using self.get_all_pages"
254
+ # Heading-aware truncation and paragraph splitting are delegated
255
+ # to pages_to_paragraphs
256
+ self.paragraphs = self.pages_to_paragraphs(self.pages)
257
 
258
+ def rank_paragraphs(self,paragraphs,query,topn=10):
259
  """
260
+ Ranks the elements of paragraphs in descending order
261
+ by relevance to query using BM25 Okapi, and returns top
262
+ topn results
263
  Parameters:
264
  -----------
265
+ paragraphs : dict
266
+ keys are titles of pages (as strings)
267
+ paragraphs[page] is a list of paragraphs (as strings)
268
+ extracted from page
269
  query : str
270
  The query to use in ranking paragraphs by relevance
271
+ topn : int or None
272
  The number of most relevant paragraphs to return
273
+ If None, will return roughly the top 1/4 of the
274
+ paragraphs
275
+ Returns:
276
  --------
277
+ best_paragraphs : list(list(str,str))
278
+ The k-th entry is a list [title,paragraph] for the k-th
279
+ most relevant paragraph, where title is the title of the
280
+ Wikipedia article from which that paragraph was sourced
281
  """
282
+ corpus, titles, page_nums = [],[],[]
283
+ # Compile paragraphs into corpus
284
+ for i,page in enumerate(paragraphs):
285
+ titles.append(page)
286
+ paras = paragraphs[page]
287
+ corpus += paras
288
+ page_nums += len(paras)*[i]
289
+
290
+ # Tokenize corpus and query and initialize bm25 object
291
+ tokenized_corpus = [p.split(" ") for p in corpus]
292
+ bm25 = BM25Okapi(tokenized_corpus)
293
  tokenized_query = query.split(" ")
 
294
 
295
+ # Compute scores and compile tuples (paragraph number, score, page number)
296
+ # before sorting tuples by score
297
+ bm_scores = bm25.get_scores(tokenized_query)
298
+ paragraph_data = [[i,score,page_nums[i]] for i,score in enumerate(bm_scores)]
299
+ paragraph_data.sort(reverse=True,key=lambda p:p[1])
300
+
301
+ # Grab topn best [title,paragraph] pairs sorted by bm25 score
302
+ topn = len(paragraph_data)//4+1 if topn is None else min(topn,len(paragraph_data))
303
+
304
+ best_paragraphs = [[titles[p[2]],corpus[p[0]]] for p in paragraph_data[:topn]]
305
+ return best_paragraphs
306
+
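# Putting the class together, the intended retrieve-then-rank flow appears to be
# as follows (a sketch, assuming the no-argument constructor above; the query
# string is illustrative):
#
#   retriever = ContextRetriever()
#   retriever.get_pageids('Eiffel Tower completed', topn=3)
#   retriever.get_all_pages()
#   retriever.get_all_paragraphs()
#   best = retriever.rank_paragraphs(retriever.paragraphs, 'Eiffel Tower completed', topn=10)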
307
+ def generate_answer(pipeline, paragraphs, question):
308
+ """
309
+ Generate an answer using a question-answer pipeline
310
+ Parameters:
311
+ -----------
312
+ pipeline : transformers.QuestionAnsweringPipeline
313
+ The question answering pipeline object
314
+ paragraphs : list(list(str,str))
315
+ The k-th entry is a list [title,paragraph] consisting
316
+ of a context paragraph and the title of the page from which the
317
+ paragraph was sourced
318
+ question : str
319
+ A question that is to be answered based on context given
320
+ in the entries of paragraphs
321
+
322
+ Returns:
323
+ --------
324
+ response : str
325
+ A response indicating the answer that was discovered,
326
+ or indicating that no answer could be found.
327
+ """
328
+ # For each paragraph, format input to QA pipeline...
329
+ for paragraph in paragraphs:
330
+ qa_input = {
331
+ 'context':paragraph[1],
332
+ 'question':question,
333
+ }
334
+ # ...and pass to QA pipeline
335
+ output = pipeline(**qa_input)
336
+ # Append answers and scores. Report score of
337
+ # zero for paragraphs without answer, so they are
338
+ # deprioritized when the max is taken below
339
+ if output['answer']!='':
340
+ paragraph += [output['answer'],output['score']]
341
+ else:
342
+ paragraph += ['',0]
343
+ # Get paragraph with max confidence score and collect data
344
+ best_paragraph = max(paragraphs,key = lambda x:x[3])
345
+ best_answer = best_paragraph[2]
346
+ best_context_page = best_paragraph[0]
347
+ best_context = best_paragraph[1]
348
+
349
+ # Compose the response to return
350
+ if best_answer == "":
351
+ response = "I cannot find the answer to your question."
352
+ else:
353
+ response = f"""
354
+ My answer is: {best_answer}
355
+
356
+ ...and here's where I found it:
357
+
358
+ Page title: {best_context_page}
359
+
360
+ Paragraph containing answer:
361
+
362
+ {best_context}
363
+ """
364
+ return response
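# End to end, the app presumably chains the helpers above roughly as follows
# (a sketch; qa_pipeline and nlp stand for the cached question-answering and
# spaCy pipelines built in app.py):
#
#   question = 'When was the Eiffel Tower completed?'
#   query = generate_query(nlp, question)
#   retriever = ContextRetriever()
#   retriever.get_pageids(query, topn=3)
#   retriever.get_all_pages()
#   retriever.get_all_paragraphs()
#   paragraphs = retriever.rank_paragraphs(retriever.paragraphs, query)
#   response = generate_answer(qa_pipeline, paragraphs, question)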