import asyncio

import streamlit as st

from text_processing import segment_text
from keyword_extraction import extract_keywords
from utils import QuestionGenerationError
from mapping_keywords import map_keywords_to_sentences
from option_generation import gen_options, generate_options_async
from fill_in_the_blanks_generation import generate_fill_in_the_blank_questions
from load_models import load_nlp_models, load_qa_models, load_model
# Shared models: a spaCy pipeline and sense2vec vectors for similarity,
# plus a sentence-similarity model and spell checker for quality scoring.
nlp, s2v = load_nlp_models()
similarity_model, spell = load_qa_models()
def assess_question_quality(context, question, answer):
    # Note: the `answer` argument is currently unused in the scoring.
    # Relevance: cosine similarity between context and question (spaCy).
    context_doc = nlp(context)
    question_doc = nlp(question)
    relevance_score = context_doc.similarity(question_doc)

    # Complexity: token length as a simple proxy, normalized to 0-1.
    complexity_score = min(len(question_doc) / 20, 1)

    # Spelling correctness: fraction of correctly spelled words (0-1),
    # guarding against an empty question.
    words = question.split()
    misspelled = spell.unknown(words)
    spelling_correctness = (1 - len(misspelled) / len(words)) if words else 0

    # Weighted overall score (adjust weights as needed).
    overall_score = (
        0.4 * relevance_score +
        0.4 * complexity_score +
        0.2 * spelling_correctness
    )

    return overall_score, relevance_score, complexity_score, spelling_correctness
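
# Worked example with hypothetical component scores: relevance 0.8,
# complexity 0.6, spelling 1.0 gives an overall score of
# 0.4*0.8 + 0.4*0.6 + 0.2*1.0 = 0.76, which clears the 0.5 threshold
# applied in process_batch below.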
async def process_batch(batch, keywords, context_window_size, num_beams, num_questions, modelname):
    questions = []
    print("Inside process_batch function")
    flag = False  # Set once enough questions have been collected
    for text in batch:
        if flag:
            break
        keyword_sentence_mapping = map_keywords_to_sentences(text, keywords, context_window_size)
        print(keyword_sentence_mapping)
        for keyword, context in keyword_sentence_mapping.items():
            print("Length of questions list from process_batch:", len(questions))
            if len(questions) >= num_questions:
                flag = True
                break
            question = await generate_question_async(context, keyword, num_beams, modelname)
            options = await generate_options_async(keyword, context)
            # options = gen_options(keyword, context, question)
            blank_question = await generate_fill_in_the_blank_questions(context, keyword)
            overall_score, relevance_score, complexity_score, spelling_correctness = assess_question_quality(context, question, keyword)
            # Keep only questions that clear the quality threshold.
            if overall_score >= 0.5:
                questions.append({
                    "question": question,
                    "context": context,
                    "answer": keyword,
                    "options": options,
                    "overall_score": overall_score,
                    "relevance_score": relevance_score,
                    "complexity_score": complexity_score,
                    "spelling_correctness": spelling_correctness,
                    "blank_question": blank_question,
                })
    return questions
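
# Each returned item is a dict with keys "question", "context", "answer",
# "options", "overall_score", "relevance_score", "complexity_score",
# "spelling_correctness", and "blank_question".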
async def generate_question_async(context, answer, num_beams, modelname):
    model, tokenizer = load_model(modelname)
    try:
        # Prompt format expected by the question-generation model.
        input_text = f"<context> {context} <answer> {answer}"
        print(f"\n{input_text}\n")
        input_ids = tokenizer.encode(input_text, return_tensors='pt')
        # Run the blocking generate() call in a worker thread so the
        # event loop stays responsive.
        outputs = await asyncio.to_thread(model.generate, input_ids, num_beams=num_beams, early_stopping=True, max_length=250)
        question = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"\n{question}\n")
        return question
    except Exception as e:
        raise QuestionGenerationError(f"Error in question generation: {str(e)}")
# Generate questions using beam search, batching the input text and
# reporting progress in the Streamlit UI.
async def generate_questions_async(text, num_questions, context_window_size, num_beams, extract_all_keywords, modelname):
    try:
        batches = segment_text(text.lower())
        keywords = extract_keywords(text, extract_all_keywords)
        all_questions = []
        progress_bar = st.progress(0)
        status_text = st.empty()
        print("Final keywords:", keywords)
        print("Number of questions to generate:", num_questions)
        print("Total number of batches:", len(batches))
        for i, batch in enumerate(batches):
            print("Batch number:", i + 1)
            status_text.text(f"Processing batch {i+1} of {len(batches)}...")
            batch_questions = await process_batch(batch, keywords, context_window_size, num_beams, num_questions, modelname)
            all_questions.extend(batch_questions)
            progress_bar.progress((i + 1) / len(batches))
            print("Length of the all_questions list:", len(all_questions))
            # Stop early once enough questions have been collected.
            if len(all_questions) >= num_questions:
                break
        progress_bar.empty()
        status_text.empty()
        return all_questions[:num_questions]
    except QuestionGenerationError as e:
        st.error(f"An error occurred during question generation: {str(e)}")
        return []
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        return []
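
# Minimal driver sketch (assumption: this module is imported by a Streamlit
# page; the widget labels and model name below are hypothetical, not part
# of this module):
#
# text = st.text_area("Source text")
# if st.button("Generate questions"):
#     questions = asyncio.run(
#         generate_questions_async(
#             text,
#             num_questions=5,
#             context_window_size=1,
#             num_beams=5,
#             extract_all_keywords=False,
#             modelname="t5-base",  # hypothetical model name
#         )
#     )
#     for q in questions:
#         st.write(q["question"], q["options"])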