DevBM committed on
Commit
a563a42
1 Parent(s): c1799d4

improving batch processing for better performance

Browse files
Files changed (1) hide show
  1. app.py +103 -50
app.py CHANGED
@@ -27,9 +27,13 @@ from transformers import pipeline
27
  import re
28
  import pymupdf
29
  import uuid
 
 
 
30
  print("***************************************************************")
31
 
32
  st.set_page_config(
 
33
  page_title="Question Generator",
34
  initial_sidebar_state="auto",
35
  menu_items={
@@ -38,6 +42,7 @@ st.set_page_config(
38
  )
39
 
40
  st.set_option('deprecation.showPyplotGlobalUse',False)
 
41
  # Initialize Wikipedia API with a user agent
42
  user_agent = 'QGen/1.0 (channingfisher7@gmail.com)'
43
  wiki_wiki = wikipediaapi.Wikipedia(user_agent= user_agent,language='en')
@@ -87,11 +92,16 @@ def load_qa_models():
87
  spell = SpellChecker()
88
  return similarity_model, spell
89
 
 
 
 
 
 
 
90
  nlp, s2v = load_nlp_models()
91
- model, tokenizer = load_model('DevBM/t5-large-squad')
92
  similarity_model, spell = load_qa_models()
93
  context_model = similarity_model
94
-
95
  # Info Section
96
  def display_info():
97
  st.sidebar.title("Information")
@@ -127,7 +137,7 @@ def display_info():
127
  # Text Preprocessing Function
128
  def preprocess_text(text):
129
  # Remove newlines and extra spaces
130
- text = re.sub(r'\s+', ' ', text)
131
  return text
132
 
133
  def get_pdf_text(pdf_file):
@@ -159,11 +169,11 @@ def save_feedback(question, answer,rating):
159
  # Function to clean text
160
  def clean_text(text):
161
  text = re.sub(r"[^\x00-\x7F]", " ", text)
 
162
  return text
163
 
164
  # Function to create text chunks
165
- def segment_text(text, max_segment_length=500):
166
- """Segment the text into smaller chunks."""
167
  sentences = sent_tokenize(text)
168
  segments = []
169
  current_segment = ""
@@ -177,8 +187,11 @@ def segment_text(text, max_segment_length=500):
177
 
178
  if current_segment:
179
  segments.append(current_segment.strip())
180
- print(f"\n\nSegement Chunks: {segments}\n\n")
181
- return segments
 
 
 
182
 
183
  # Function to extract keywords using combined techniques
184
  def extract_keywords(text, extract_all):
@@ -302,14 +315,82 @@ def entity_linking(keyword):
302
  return page.fullurl
303
  return None
304
 
305
- # Function to generate questions using beam search
306
- def generate_question(context, answer, num_beams):
307
  input_text = f"<context> {context} <answer> {answer}"
 
308
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
309
- outputs = model.generate(input_ids, num_beams=num_beams, early_stopping=True, max_length=150)
310
  question = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
311
  return question
312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  # Function to export questions to CSV
314
  def export_to_csv(data):
315
  # df = pd.DataFrame(data, columns=["Context", "Answer", "Question", "Options"])
@@ -375,6 +456,7 @@ def main():
375
  st.title(":blue[Question Generator System]")
376
  session_id = get_session_id()
377
  state = initialize_state(session_id)
 
378
  with st.sidebar:
379
  show_info = st.toggle('Show Info',True)
380
  if show_info:
@@ -382,24 +464,21 @@ def main():
382
  st.subheader("Customization Options")
383
  # Customization options
384
  input_type = st.radio("Select Input Preference", ("Text Input","Upload PDF"))
385
- num_beams = st.slider("Select number of beams for question generation", min_value=1, max_value=10, value=5)
386
- context_window_size = st.slider("Select context window size (number of sentences before and after)", min_value=1, max_value=5, value=1)
387
- num_questions = st.slider("Select number of questions to generate", min_value=1, max_value=1000, value=5)
388
  with st.expander("Choose the Additional Elements to show"):
389
  show_context = st.checkbox("Context",True)
390
  show_answer = st.checkbox("Answer",True)
391
  show_options = st.checkbox("Options",False)
392
  show_entity_link = st.checkbox("Entity Link For Wikipedia",True)
393
  show_qa_scores = st.checkbox("QA Score",False)
 
 
 
394
  col1, col2 = st.columns(2)
395
  with col1:
396
  extract_all_keywords = st.toggle("Extract Max Keywords",value=False)
397
  with col2:
398
  enable_feedback_mode = st.toggle("Enable Feedback Mode",False)
399
- use_t5_small = st.toggle("Use T5-Small",False)
400
- # set_state(session_id, 'generated_questions', state['generated_questions'])
401
- if use_t5_small is True:
402
- model, tokenizer = load_model('AneriThakkar/flan-t5-small-finetuned')
403
  text = None
404
  if input_type == "Text Input":
405
  text = st.text_area("Enter text here:", value="Joe Biden, the current US president is on a weak wicket going in for his reelection later this November against former President Donald Trump.")
@@ -409,45 +488,19 @@ def main():
409
  text = get_pdf_text(file)
410
  if text:
411
  text = clean_text(text)
412
- segments = segment_text(text)
413
  generate_questions_button = st.button("Generate Questions")
414
  q_count = 0
415
- if generate_questions_button:
416
- state['generated_questions'] = []
417
- # st.session_state.generated_questions = []
418
- for text in segments:
419
- keywords = extract_keywords(text, extract_all_keywords)
420
- print(f"\n\nFinal Keywords in Main Function: {keywords}\n\n")
421
- keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
422
- for i, (keyword, context) in enumerate(keyword_sentence_mapping.items()):
423
- if i >= num_questions:
424
- break
425
- if q_count>num_questions:
426
- break
427
- question = generate_question(context, keyword, num_beams=num_beams)
428
- options = generate_options(keyword,context)
429
- overall_score, relevance_score, complexity_score, spelling_correctness = assess_question_quality(context,question,keyword)
430
- if overall_score < 0.5:
431
- continue
432
- tpl = {
433
- "question" : question,
434
- "context" : context,
435
- "answer" : keyword,
436
- "options" : options,
437
- "overall_score" : overall_score,
438
- "relevance_score" : relevance_score,
439
- "complexity_score" : complexity_score,
440
- "spelling_correctness" : spelling_correctness,
441
- }
442
- print("\n\n",tpl,"\n\n")
443
- # st.session_state.generated_questions.append(tpl)
444
- state['generated_questions'].append(tpl)
445
- q_count += 1
446
  print("\n\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n")
447
  data = get_state(session_id)
448
  print(data)
 
 
449
  set_state(session_id, 'generated_questions', state['generated_questions'])
450
- a = get_state(session_id)
451
 
452
  # sort question based on their quality score
453
  state['generated_questions'] = sorted(state['generated_questions'],key = lambda x: x['overall_score'], reverse=True)
 
27
  import re
28
  import pymupdf
29
  import uuid
30
+ import time
31
+ import asyncio
32
+ import aiohttp
33
  print("***************************************************************")
34
 
35
  st.set_page_config(
36
+ page_icon='cyclone',
37
  page_title="Question Generator",
38
  initial_sidebar_state="auto",
39
  menu_items={
 
42
  )
43
 
44
  st.set_option('deprecation.showPyplotGlobalUse',False)
45
+
46
  # Initialize Wikipedia API with a user agent
47
  user_agent = 'QGen/1.0 (channingfisher7@gmail.com)'
48
  wiki_wiki = wikipediaapi.Wikipedia(user_agent= user_agent,language='en')
 
92
  spell = SpellChecker()
93
  return similarity_model, spell
94
 
95
+ with st.sidebar:
96
+ select_model = st.selectbox("Select Model", ("T5-large","T5-small"))
97
+ if select_model == "T5-large":
98
+ modelname = "DevBM/t5-large-squad"
99
+ elif select_model == "T5-small":
100
+ modelname = "AneriThakkar/flan-t5-small-finetuned"
101
  nlp, s2v = load_nlp_models()
 
102
  similarity_model, spell = load_qa_models()
103
  context_model = similarity_model
104
+ model, tokenizer = load_model(modelname)
105
  # Info Section
106
  def display_info():
107
  st.sidebar.title("Information")
 
137
  # Text Preprocessing Function
138
  def preprocess_text(text):
139
  # Remove newlines and extra spaces
140
+ text = re.sub(r'[\n]', ' ', text)
141
  return text
142
 
143
  def get_pdf_text(pdf_file):
 
169
  # Function to clean text
170
  def clean_text(text):
171
  text = re.sub(r"[^\x00-\x7F]", " ", text)
172
+ text = re.sub(f"[\n]"," ", text)
173
  return text
174
 
175
  # Function to create text chunks
176
+ def segment_text(text, max_segment_length=700, batch_size=7):
 
177
  sentences = sent_tokenize(text)
178
  segments = []
179
  current_segment = ""
 
187
 
188
  if current_segment:
189
  segments.append(current_segment.strip())
190
+
191
+ # Create batches
192
+ batches = [segments[i:i + batch_size] for i in range(0, len(segments), batch_size)]
193
+ return batches
194
+
195
 
196
  # Function to extract keywords using combined techniques
197
  def extract_keywords(text, extract_all):
 
315
  return page.fullurl
316
  return None
317
 
318
async def generate_question_async(context, answer, num_beams):
    """Generate one question for ``answer`` given ``context``.

    The prompt ``<context> ... <answer> ...`` is tokenized with the
    module-level ``tokenizer``. ``model.generate`` is then run in a worker
    thread via ``asyncio.to_thread`` so the event loop is not blocked while
    the model decodes.

    Args:
        context: Passage the question should be answerable from.
        answer: Target answer (keyword) the question should ask about.
        num_beams: Beam width forwarded to ``model.generate``.

    Returns:
        The first generated sequence decoded to text with special tokens
        removed.
    """
    prompt = f"<context> {context} <answer> {answer}"
    print(f"\n{prompt}\n")

    encoded_prompt = tokenizer.encode(prompt, return_tensors='pt')
    generation_kwargs = {
        "num_beams": num_beams,
        "early_stopping": True,
        "max_length": 250,
    }
    generated = await asyncio.to_thread(model.generate, encoded_prompt, **generation_kwargs)

    # Only the top-ranked sequence is decoded.
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f"\n{decoded}\n")
    return decoded
326
 
327
async def generate_options_async(answer, context, n=3):
    """Build a shuffled list of answer options (the answer plus distractors).

    Candidates are gathered in this order, then de-duplicated and truncated
    to ``n + 1`` entries:

    1. the correct ``answer``;
    2. up to ``n`` alphabetic words from ``context`` (excluding the answer),
       ranked by cosine similarity of their embedding to the answer's;
    3. whatever ``get_similar_words_sense2vec(answer, n)`` returns;
    4. if still short, whatever ``get_synonyms`` returns for the shortfall.

    Args:
        answer: The correct answer; always the first candidate before shuffling.
        context: Text from which context-word distractors are drawn.
        n: Number of distractors wanted in addition to the answer.

    Returns:
        A list of at most ``n + 1`` unique options in random order.
    """
    options = [answer]

    # Alphabetic context tokens other than the answer itself are the
    # embedding-ranked distractor candidates.
    answer_lower = answer.lower()
    context_words = [
        token.text
        for token in nlp(context)
        if token.is_alpha and token.text.lower() != answer_lower
    ]

    if context_words:
        answer_embedding = await asyncio.to_thread(context_model.encode, answer)
        # Encode all candidate words in ONE batched call rather than one
        # thread hop and one forward pass per word.
        word_embeddings = await asyncio.to_thread(context_model.encode, context_words)
        # pytorch_cos_sim yields a (len(context_words), 1) matrix; flatten it
        # to one score per word.
        similarity_scores = (
            util.pytorch_cos_sim(word_embeddings, answer_embedding).flatten().tolist()
        )
        sorted_context_words = [
            word
            for _, word in sorted(zip(similarity_scores, context_words), reverse=True)
        ]
        options.extend(sorted_context_words[:n])

    # Try to get similar words based on sense2vec
    similar_words = await asyncio.to_thread(get_similar_words_sense2vec, answer, n)
    options.extend(similar_words)

    # If we don't have enough options, try synonyms
    if len(options) < n + 1:
        synonyms = await asyncio.to_thread(get_synonyms, answer, n - len(options) + 1)
        options.extend(synonyms)

    # Keep first occurrences only (order-preserving) and cap at n + 1
    options = list(dict.fromkeys(options))[:n + 1]

    # Shuffle so the correct answer is not always first
    random.shuffle(options)

    return options
356
+
357
+
358
+ # Function to generate questions using beam search
359
async def generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords):
    """Generate up to ``num_questions`` question records for ``text``.

    ``text`` is split into batches by ``segment_text`` and keywords are
    extracted once from the full text by ``extract_keywords``. Batches are
    then handed to ``process_batch`` one at a time, and processing stops as
    soon as enough questions have been collected.

    Args:
        text: Source text to generate questions from.
        num_questions: Maximum number of question records to return.
        context_window_size: Forwarded to ``process_batch``.
        num_beams: Forwarded to ``process_batch``.
        extract_all_keywords: Forwarded to ``extract_keywords``.

    Returns:
        A list of at most ``num_questions`` question records, in the order
        they were produced.
    """
    segment_batches = segment_text(text)
    extracted_keywords = extract_keywords(text, extract_all_keywords)

    collected = []
    for segment_batch in segment_batches:
        collected += await process_batch(
            segment_batch, extracted_keywords, context_window_size, num_beams
        )
        if len(collected) < num_questions:
            continue
        # Quota reached: skip the remaining batches.
        break

    # The last batch may have overshot the quota, so trim to the exact count.
    return collected[:num_questions]
371
+
372
+
373
async def process_batch(batch, keywords, context_window_size, num_beams):
    """Generate scored question records for every segment in ``batch``.

    For each segment, ``map_keywords_to_sentences`` pairs keywords with a
    context. For each (keyword, context) pair a question and its answer
    options are generated, and the result is scored with
    ``assess_question_quality``. Records scoring below 0.5 overall are dropped.

    Args:
        batch: Iterable of text segments.
        keywords: Keywords forwarded to ``map_keywords_to_sentences``.
        context_window_size: Forwarded to ``map_keywords_to_sentences``.
        num_beams: Beam width forwarded to ``generate_question_async``.

    Returns:
        A list of dicts with keys ``question``, ``context``, ``answer``,
        ``options``, ``overall_score``, ``relevance_score``,
        ``complexity_score`` and ``spelling_correctness``.
    """
    questions = []
    for text in batch:
        keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
        for keyword, context in keyword_sentence_mapping.items():
            # The question and the options do not depend on each other, so
            # run both coroutines concurrently; their worker-thread calls can
            # then overlap instead of being awaited back to back.
            question, options = await asyncio.gather(
                generate_question_async(context, keyword, num_beams),
                generate_options_async(keyword, context),
            )
            overall_score, relevance_score, complexity_score, spelling_correctness = (
                assess_question_quality(context, question, keyword)
            )
            # Discard low-quality questions.
            if overall_score < 0.5:
                continue
            questions.append({
                "question": question,
                "context": context,
                "answer": keyword,
                "options": options,
                "overall_score": overall_score,
                "relevance_score": relevance_score,
                "complexity_score": complexity_score,
                "spelling_correctness": spelling_correctness,
            })
    return questions
393
+
394
  # Function to export questions to CSV
395
  def export_to_csv(data):
396
  # df = pd.DataFrame(data, columns=["Context", "Answer", "Question", "Options"])
 
456
  st.title(":blue[Question Generator System]")
457
  session_id = get_session_id()
458
  state = initialize_state(session_id)
459
+
460
  with st.sidebar:
461
  show_info = st.toggle('Show Info',True)
462
  if show_info:
 
464
  st.subheader("Customization Options")
465
  # Customization options
466
  input_type = st.radio("Select Input Preference", ("Text Input","Upload PDF"))
 
 
 
467
  with st.expander("Choose the Additional Elements to show"):
468
  show_context = st.checkbox("Context",True)
469
  show_answer = st.checkbox("Answer",True)
470
  show_options = st.checkbox("Options",False)
471
  show_entity_link = st.checkbox("Entity Link For Wikipedia",True)
472
  show_qa_scores = st.checkbox("QA Score",False)
473
+ num_beams = st.slider("Select number of beams for question generation", min_value=2, max_value=10, value=2)
474
+ context_window_size = st.slider("Select context window size (number of sentences before and after)", min_value=1, max_value=5, value=1)
475
+ num_questions = st.slider("Select number of questions to generate", min_value=1, max_value=1000, value=5)
476
  col1, col2 = st.columns(2)
477
  with col1:
478
  extract_all_keywords = st.toggle("Extract Max Keywords",value=False)
479
  with col2:
480
  enable_feedback_mode = st.toggle("Enable Feedback Mode",False)
481
+
 
 
 
482
  text = None
483
  if input_type == "Text Input":
484
  text = st.text_area("Enter text here:", value="Joe Biden, the current US president is on a weak wicket going in for his reelection later this November against former President Donald Trump.")
 
488
  text = get_pdf_text(file)
489
  if text:
490
  text = clean_text(text)
 
491
  generate_questions_button = st.button("Generate Questions")
492
  q_count = 0
493
+ # if generate_questions_button:
494
+ if generate_questions_button and text:
495
+ start_time = time.time()
496
+ with st.spinner("Generating questions..."):
497
+ state['generated_questions'] = asyncio.run(generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  print("\n\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n")
499
  data = get_state(session_id)
500
  print(data)
501
+ end_time = time.time()
502
+ print(f"Time Taken to generate: {end_time-start_time}")
503
  set_state(session_id, 'generated_questions', state['generated_questions'])
 
504
 
505
  # sort question based on their quality score
506
  state['generated_questions'] = sorted(state['generated_questions'],key = lambda x: x['overall_score'], reverse=True)