ajaynagotha committed
Commit 74d109e • 1 Parent(s): 738d0f3

Update app.py
app.py CHANGED
@@ -18,17 +18,10 @@ logger.info("Dataset loaded successfully")
 
 # Load model and tokenizer
 logger.info("Loading the model and tokenizer")
-model_name = "
+model_name = "deepset/roberta-large-squad2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForQuestionAnswering.from_pretrained(model_name)
-
-model.to(device)
-logger.info(f"Model and tokenizer loaded successfully. Using device: {device}")
-
-# Preprocess the dataset
-logger.info("Preprocessing the dataset")
-context = " ".join([item.get('Text', '') for item in ds['train']])
-logger.info(f"Combined context length: {len(context)} characters")
+logger.info("Model and tokenizer loaded successfully")
 
 def clean_answer(answer):
     special_tokens = set(tokenizer.all_special_tokens)
@@ -38,56 +31,26 @@ def clean_answer(answer):
 def answer_question(question):
     logger.info(f"Received question: {question}")
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            with torch.no_grad():
-                outputs = model(**inputs)
-
-            answer_start = torch.argmax(outputs.start_logits)
-            answer_end = torch.argmax(outputs.end_logits) + 1
-
-            ans = tokenizer.convert_tokens_to_string(
-                tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end])
-            )
-
-            score = torch.max(outputs.start_logits) + torch.max(outputs.end_logits)
-            answers.append((ans, score.item()))
-
-            # Break if we have a good answer
-            if score > 10:  # Adjust this threshold as needed
-                break
-
-        # Select best answer
-        best_answer = max(answers, key=lambda x: x[1])[0]
-
-        # Post-processing
-        best_answer = clean_answer(best_answer)
-        best_answer = best_answer.capitalize()
-
-        logger.info(f"Generated answer: {best_answer}")
-        if not best_answer or len(best_answer) < 5:
-            logger.warning("Generated answer was empty or too short after cleaning")
-            best_answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
-
+        logger.info("Combining text from dataset")
+        context = " ".join([item.get('Text', '') for item in ds['train']])
+        logger.info(f"Combined context length: {len(context)} characters")
+        logger.info("Tokenizing input")
+        inputs = tokenizer.encode_plus(question, context, return_tensors="pt", max_length=514, truncation=True)
+        logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
+        logger.info("Getting model output")
+        outputs = model(**inputs)
+        logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
+        logger.info("Processing output to get answer")
+        answer_start = torch.argmax(outputs.start_logits)
+        answer_end = torch.argmax(outputs.end_logits) + 1
+        raw_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
+        answer = clean_answer(raw_answer)
+        logger.info(f"Generated answer: {answer}")
+        if not answer:
+            logger.warning("Generated answer was empty after cleaning")
+            answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
         logger.info("Answer generated successfully")
-        return
+        return answer
     except Exception as e:
         logger.error(f"Error in answer_question function: {str(e)}")
         return "I'm sorry, but an error occurred while processing your question. Please try again later."
@@ -134,7 +97,7 @@ iface = gr.Interface(
 # Mount Gradio app to FastAPI
 app = gr.mount_gradio_app(app, iface, path="/")
 
-# For
+# For local development and testing
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
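
Review note: deepset/roberta-large-squad2 is a RoBERTa checkpoint, and although its position-embedding table has 514 entries, two of them are reserved for RoBERTa's padding offset, so the usable sequence length is 512 tokens. The new call's max_length=514 can therefore overrun the position embeddings once a real question plus the combined Gita text fills the window, and encode_plus is deprecated in recent transformers releases in favor of calling the tokenizer directly. A minimal sketch of the safer call (variable names match app.py; the truncation strategy is an assumption, not what the commit does):

# Sketch, not the committed code: cap at RoBERTa's usable 512 tokens and
# truncate only the context so the question always survives intact.
inputs = tokenizer(
    question,
    context,
    return_tensors="pt",
    max_length=512,            # 514 counts RoBERTa's two reserved positions
    truncation="only_second",  # assumption: never truncate the question away
)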
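A second observation: the commit removes the old chunk-scanning loop and truncates the combined Gita text to a single window, so any answer beyond roughly the first 512 tokens becomes unreachable; it also rebuilds context from the dataset on every request instead of once at startup. A hedged sketch of how the lost coverage could be restored with the tokenizer's built-in overflow support (the function name and parameter values are illustrative assumptions, not part of the commit; app.py's existing tokenizer, model, and torch are assumed):

import torch

def answer_from_long_context(question: str, context: str) -> str:
    # Split the context into overlapping 512-token windows; the question is
    # repeated in every window and never truncated.
    inputs = tokenizer(
        question,
        context,
        truncation="only_second",
        max_length=512,
        stride=128,  # overlap so answers straddling a window boundary survive
        return_overflowing_tokens=True,
        padding="max_length",
        return_tensors="pt",
    )
    inputs.pop("overflow_to_sample_mapping", None)  # bookkeeping, not a model input
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"])
    # Score each window by its best start+end logits and keep the overall
    # winner, mirroring the scoring the removed loop used.
    best_score, best = float("-inf"), None
    for i in range(inputs["input_ids"].shape[0]):
        start = int(torch.argmax(outputs.start_logits[i]))
        end = int(torch.argmax(outputs.end_logits[i])) + 1
        score = float(outputs.start_logits[i].max() + outputs.end_logits[i].max())
        if end > start and score > best_score:
            best_score, best = score, (i, start, end)
    if best is None:
        return ""
    i, start, end = best
    return tokenizer.decode(inputs["input_ids"][i][start:end], skip_special_tokens=True)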