Spaces:
Sleeping
Sleeping
ajaynagotha
commited on
Commit
•
5f84801
1
Parent(s):
3e027f8
Update app.py
Browse files
app.py
CHANGED
@@ -8,11 +8,9 @@ from fastapi import FastAPI, HTTPException
|
|
8 |
from pydantic import BaseModel
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
|
11 |
-
# Set up logging
|
12 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
15 |
-
# Add a handler to write logs to a file
|
16 |
file_handler = logging.FileHandler('app.log')
|
17 |
file_handler.setLevel(logging.INFO)
|
18 |
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
|
@@ -22,7 +20,7 @@ logger.info("Starting the application")
|
|
22 |
|
23 |
try:
|
24 |
logger.info("Loading the dataset")
|
25 |
-
ds = load_dataset("
|
26 |
logger.info("Dataset loaded successfully")
|
27 |
except Exception as e:
|
28 |
logger.error(f"Error loading dataset: {str(e)}")
|
@@ -30,7 +28,7 @@ except Exception as e:
|
|
30 |
|
31 |
try:
|
32 |
logger.info("Loading the model and tokenizer")
|
33 |
-
model_name = "
|
34 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
35 |
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
36 |
logger.info("Model and tokenizer loaded successfully")
|
@@ -39,56 +37,45 @@ except Exception as e:
|
|
39 |
sys.exit(1)
|
40 |
|
41 |
def clean_answer(answer):
|
42 |
-
# Remove special tokens and leading/trailing whitespace
|
43 |
special_tokens = set(tokenizer.all_special_tokens)
|
44 |
cleaned_answer = ' '.join(token for token in answer.split() if token not in special_tokens)
|
45 |
return cleaned_answer.strip()
|
46 |
|
47 |
def answer_question(question):
|
48 |
logger.info(f"Received question: {question}")
|
49 |
-
|
50 |
try:
|
51 |
logger.info("Combining text from dataset")
|
52 |
context = " ".join([item['Text'] for item in ds['train']])
|
53 |
logger.info(f"Combined context length: {len(context)} characters")
|
54 |
-
|
55 |
logger.info("Tokenizing input")
|
56 |
inputs = tokenizer.encode_plus(question, context, return_tensors="pt", max_length=512, truncation=True)
|
57 |
logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
|
58 |
-
|
59 |
logger.info("Getting model output")
|
60 |
outputs = model(**inputs)
|
61 |
logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
|
62 |
-
|
63 |
logger.info("Processing output to get answer")
|
64 |
answer_start = torch.argmax(outputs.start_logits)
|
65 |
answer_end = torch.argmax(outputs.end_logits) + 1
|
66 |
raw_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
|
67 |
answer = clean_answer(raw_answer)
|
68 |
logger.info(f"Generated answer: {answer}")
|
69 |
-
|
70 |
if not answer:
|
71 |
logger.warning("Generated answer was empty after cleaning")
|
72 |
answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
|
73 |
-
|
74 |
logger.info("Answer generated successfully")
|
75 |
-
|
76 |
return answer
|
77 |
-
|
78 |
except Exception as e:
|
79 |
logger.error(f"Error in answer_question function: {str(e)}")
|
80 |
return "I'm sorry, but an error occurred while processing your question. Please try again later."
|
81 |
|
82 |
-
# FastAPI setup
|
83 |
app = FastAPI()
|
84 |
|
85 |
-
# Add CORS middleware
|
86 |
app.add_middleware(
|
87 |
CORSMiddleware,
|
88 |
-
allow_origins=["*"],
|
89 |
allow_credentials=True,
|
90 |
-
allow_methods=["*"],
|
91 |
-
allow_headers=["*"],
|
92 |
)
|
93 |
|
94 |
class Question(BaseModel):
|
@@ -98,24 +85,17 @@ class Question(BaseModel):
|
|
98 |
async def predict(question: Question):
|
99 |
try:
|
100 |
last_user_message = next((msg for msg in reversed(question.messages) if msg['role'] == 'user'), None)
|
101 |
-
|
102 |
if not last_user_message:
|
103 |
raise HTTPException(status_code=400, detail="No user message found")
|
104 |
-
|
105 |
user_question = last_user_message['content']
|
106 |
-
|
107 |
answer = answer_question(user_question)
|
108 |
-
|
109 |
-
disclaimer = "\n\nPlease note: This response is generated by an AI model based on the Bhagavad Gita. For authoritative information, please consult the original text or scholarly sources."
|
110 |
full_response = answer + disclaimer
|
111 |
-
|
112 |
return {"response": full_response, "isTruncated": False}
|
113 |
-
|
114 |
except Exception as e:
|
115 |
logger.error(f"Error in predict function: {str(e)}")
|
116 |
raise HTTPException(status_code=500, detail=str(e))
|
117 |
|
118 |
-
# Gradio interface
|
119 |
iface = gr.Interface(
|
120 |
fn=answer_question,
|
121 |
inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
|
@@ -124,10 +104,8 @@ iface = gr.Interface(
|
|
124 |
description="Ask a question about the Bhagavad Gita, and get an answer based on the dataset."
|
125 |
)
|
126 |
|
127 |
-
# Mount Gradio app to FastAPI
|
128 |
app = gr.mount_gradio_app(app, iface, path="/")
|
129 |
|
130 |
-
# Run the FastAPI app
|
131 |
if __name__ == "__main__":
|
132 |
import uvicorn
|
133 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
8 |
from pydantic import BaseModel
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
|
|
|
11 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
|
|
14 |
file_handler = logging.FileHandler('app.log')
|
15 |
file_handler.setLevel(logging.INFO)
|
16 |
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
|
|
|
20 |
|
21 |
try:
|
22 |
logger.info("Loading the dataset")
|
23 |
+
ds = load_dataset("adarshxs/gita")
|
24 |
logger.info("Dataset loaded successfully")
|
25 |
except Exception as e:
|
26 |
logger.error(f"Error loading dataset: {str(e)}")
|
|
|
28 |
|
29 |
try:
|
30 |
logger.info("Loading the model and tokenizer")
|
31 |
+
model_name = "microsoft/deberta-v3-large"
|
32 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
33 |
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
34 |
logger.info("Model and tokenizer loaded successfully")
|
|
|
37 |
sys.exit(1)
|
38 |
|
39 |
def clean_answer(answer):
|
|
|
40 |
special_tokens = set(tokenizer.all_special_tokens)
|
41 |
cleaned_answer = ' '.join(token for token in answer.split() if token not in special_tokens)
|
42 |
return cleaned_answer.strip()
|
43 |
|
44 |
def answer_question(question):
|
45 |
logger.info(f"Received question: {question}")
|
|
|
46 |
try:
|
47 |
logger.info("Combining text from dataset")
|
48 |
context = " ".join([item['Text'] for item in ds['train']])
|
49 |
logger.info(f"Combined context length: {len(context)} characters")
|
|
|
50 |
logger.info("Tokenizing input")
|
51 |
inputs = tokenizer.encode_plus(question, context, return_tensors="pt", max_length=512, truncation=True)
|
52 |
logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
|
|
|
53 |
logger.info("Getting model output")
|
54 |
outputs = model(**inputs)
|
55 |
logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
|
|
|
56 |
logger.info("Processing output to get answer")
|
57 |
answer_start = torch.argmax(outputs.start_logits)
|
58 |
answer_end = torch.argmax(outputs.end_logits) + 1
|
59 |
raw_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
|
60 |
answer = clean_answer(raw_answer)
|
61 |
logger.info(f"Generated answer: {answer}")
|
|
|
62 |
if not answer:
|
63 |
logger.warning("Generated answer was empty after cleaning")
|
64 |
answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
|
|
|
65 |
logger.info("Answer generated successfully")
|
|
|
66 |
return answer
|
|
|
67 |
except Exception as e:
|
68 |
logger.error(f"Error in answer_question function: {str(e)}")
|
69 |
return "I'm sorry, but an error occurred while processing your question. Please try again later."
|
70 |
|
|
|
71 |
app = FastAPI()
|
72 |
|
|
|
73 |
app.add_middleware(
|
74 |
CORSMiddleware,
|
75 |
+
allow_origins=["*"],
|
76 |
allow_credentials=True,
|
77 |
+
allow_methods=["*"],
|
78 |
+
allow_headers=["*"],
|
79 |
)
|
80 |
|
81 |
class Question(BaseModel):
|
|
|
85 |
async def predict(question: Question):
|
86 |
try:
|
87 |
last_user_message = next((msg for msg in reversed(question.messages) if msg['role'] == 'user'), None)
|
|
|
88 |
if not last_user_message:
|
89 |
raise HTTPException(status_code=400, detail="No user message found")
|
|
|
90 |
user_question = last_user_message['content']
|
|
|
91 |
answer = answer_question(user_question)
|
92 |
+
disclaimer = "\n\n---Please note: This response is generated by an AI model based on the Bhagavad Gita. For authoritative information, please consult the original text or scholarly sources."
|
|
|
93 |
full_response = answer + disclaimer
|
|
|
94 |
return {"response": full_response, "isTruncated": False}
|
|
|
95 |
except Exception as e:
|
96 |
logger.error(f"Error in predict function: {str(e)}")
|
97 |
raise HTTPException(status_code=500, detail=str(e))
|
98 |
|
|
|
99 |
iface = gr.Interface(
|
100 |
fn=answer_question,
|
101 |
inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
|
|
|
104 |
description="Ask a question about the Bhagavad Gita, and get an answer based on the dataset."
|
105 |
)
|
106 |
|
|
|
107 |
app = gr.mount_gradio_app(app, iface, path="/")
|
108 |
|
|
|
109 |
if __name__ == "__main__":
|
110 |
import uvicorn
|
111 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|