ajaynagotha committed
Commit 5f84801
1 Parent(s): 3e027f8

Update app.py

Files changed (1)
  1. app.py +6 -28
app.py CHANGED
@@ -8,11 +8,9 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 
-# Set up logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-# Add a handler to write logs to a file
 file_handler = logging.FileHandler('app.log')
 file_handler.setLevel(logging.INFO)
 file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
@@ -22,7 +20,7 @@ logger.info("Starting the application")
 
 try:
     logger.info("Loading the dataset")
-    ds = load_dataset("knowrohit07/gita_dataset")
+    ds = load_dataset("adarshxs/gita")
     logger.info("Dataset loaded successfully")
 except Exception as e:
     logger.error(f"Error loading dataset: {str(e)}")
@@ -30,7 +28,7 @@ except Exception as e:
 
 try:
     logger.info("Loading the model and tokenizer")
-    model_name = "deepset/roberta-base-squad2"
+    model_name = "microsoft/deberta-v3-large"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForQuestionAnswering.from_pretrained(model_name)
     logger.info("Model and tokenizer loaded successfully")
@@ -39,56 +37,45 @@ except Exception as e:
     sys.exit(1)
 
 def clean_answer(answer):
-    # Remove special tokens and leading/trailing whitespace
     special_tokens = set(tokenizer.all_special_tokens)
     cleaned_answer = ' '.join(token for token in answer.split() if token not in special_tokens)
     return cleaned_answer.strip()
 
 def answer_question(question):
     logger.info(f"Received question: {question}")
-
     try:
         logger.info("Combining text from dataset")
         context = " ".join([item['Text'] for item in ds['train']])
         logger.info(f"Combined context length: {len(context)} characters")
-
         logger.info("Tokenizing input")
         inputs = tokenizer.encode_plus(question, context, return_tensors="pt", max_length=512, truncation=True)
         logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
-
         logger.info("Getting model output")
         outputs = model(**inputs)
         logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
-
         logger.info("Processing output to get answer")
         answer_start = torch.argmax(outputs.start_logits)
         answer_end = torch.argmax(outputs.end_logits) + 1
         raw_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
         answer = clean_answer(raw_answer)
         logger.info(f"Generated answer: {answer}")
-
         if not answer:
             logger.warning("Generated answer was empty after cleaning")
             answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
-
         logger.info("Answer generated successfully")
-
         return answer
-
     except Exception as e:
         logger.error(f"Error in answer_question function: {str(e)}")
         return "I'm sorry, but an error occurred while processing your question. Please try again later."
 
-# FastAPI setup
 app = FastAPI()
 
-# Add CORS middleware
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # Allows all origins
+    allow_origins=["*"],
     allow_credentials=True,
-    allow_methods=["*"],  # Allows all methods
-    allow_headers=["*"],  # Allows all headers
+    allow_methods=["*"],
+    allow_headers=["*"],
 )
 
 class Question(BaseModel):
@@ -98,24 +85,17 @@ class Question(BaseModel):
 async def predict(question: Question):
     try:
         last_user_message = next((msg for msg in reversed(question.messages) if msg['role'] == 'user'), None)
-
         if not last_user_message:
             raise HTTPException(status_code=400, detail="No user message found")
-
         user_question = last_user_message['content']
-
         answer = answer_question(user_question)
-
-        disclaimer = "\n\nPlease note: This response is generated by an AI model based on the Bhagavad Gita. For authoritative information, please consult the original text or scholarly sources."
+        disclaimer = "\n\n---Please note: This response is generated by an AI model based on the Bhagavad Gita. For authoritative information, please consult the original text or scholarly sources."
         full_response = answer + disclaimer
-
         return {"response": full_response, "isTruncated": False}
-
     except Exception as e:
         logger.error(f"Error in predict function: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
-# Gradio interface
 iface = gr.Interface(
     fn=answer_question,
     inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
@@ -124,10 +104,8 @@ iface = gr.Interface(
     description="Ask a question about the Bhagavad Gita, and get an answer based on the dataset."
 )
 
-# Mount Gradio app to FastAPI
 app = gr.mount_gradio_app(app, iface, path="/")
 
-# Run the FastAPI app
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
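Note on the model change: deepset/roberta-base-squad2 ships with a QA head fine-tuned on SQuAD 2.0, whereas microsoft/deberta-v3-large is a base checkpoint, so AutoModelForQuestionAnswering attaches a randomly initialized span-prediction head (transformers logs a warning about newly initialized weights). Below is a minimal standalone sketch of the extractive-QA path the app uses; the question string and the "..." context are placeholders standing in for the concatenated dataset text:

import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

model_name = "microsoft/deberta-v3-large"  # base checkpoint; the QA head starts untrained
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

question = "What does the Gita say about karma?"  # placeholder question
context = "..."  # placeholder for the concatenated dataset text

inputs = tokenizer(question, context, return_tensors="pt", max_length=512, truncation=True)
with torch.no_grad():
    outputs = model(**inputs)

# Same span decoding as app.py: argmax over start/end logits, then detokenize.
start = torch.argmax(outputs.start_logits)
end = torch.argmax(outputs.end_logits) + 1
print(tokenizer.decode(inputs["input_ids"][0][start:end], skip_special_tokens=True))

A SQuAD-fine-tuned DeBERTa checkpoint (e.g. deepset/deberta-v3-large-squad2) would keep the same API while producing trained span predictions out of the box.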
 
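For reference, a hedged sketch of exercising the chat endpoint once the server is up on port 7860. The route decorator falls outside the diff context, so the /predict path below is an assumption to verify against the @app.post(...) line in app.py; the payload shape follows the Question model's messages list of role/content dicts:

import requests

# NOTE: "/predict" is assumed; confirm the actual route in app.py
resp = requests.post(
    "http://localhost:7860/predict",
    json={"messages": [{"role": "user", "content": "What is dharma?"}]},
)
resp.raise_for_status()
print(resp.json()["response"])  # answer text plus the appended disclaimer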