ajaynagotha committed
Commit 702d4ed
1 Parent(s): 0f56688

Update app.py

Files changed (1)
  1. app.py +17 -25
app.py CHANGED
@@ -3,38 +3,25 @@ from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 import torch
 import logging
-import sys
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 
+# Set up logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-file_handler = logging.FileHandler('app.log')
-file_handler.setLevel(logging.INFO)
-file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
-logger.addHandler(file_handler)
+# Load dataset
+logger.info("Loading the dataset")
+ds = load_dataset("adarshxs/gita")
+logger.info("Dataset loaded successfully")
 
-logger.info("Starting the application")
-
-try:
-    logger.info("Loading the dataset")
-    ds = load_dataset("adarshxs/gita")
-    logger.info("Dataset loaded successfully")
-except Exception as e:
-    logger.error(f"Error loading dataset: {str(e)}")
-    sys.exit(1)
-
-try:
-    logger.info("Loading the model and tokenizer")
-    model_name = "deepset/roberta-large-squad2"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
-    logger.info("Model and tokenizer loaded successfully")
-except Exception as e:
-    logger.error(f"Error loading model or tokenizer: {str(e)}")
-    sys.exit(1)
+# Load model and tokenizer
+logger.info("Loading the model and tokenizer")
+model_name = "deepset/roberta-large-squad2"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForQuestionAnswering.from_pretrained(model_name)
+logger.info("Model and tokenizer loaded successfully")
 
 def clean_answer(answer):
     special_tokens = set(tokenizer.all_special_tokens)
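
The diff cuts off inside clean_answer. A minimal sketch of what the elided body presumably does, i.e. strip the tokenizer's special tokens from a decoded span; the loop is a guess at the committed code, not a copy of it:

    def clean_answer(answer):
        # Drop special tokens such as <s>, </s>, <pad> that can leak into the decoded span
        special_tokens = set(tokenizer.all_special_tokens)
        for token in special_tokens:
            answer = answer.replace(token, "")
        return answer.strip()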
@@ -45,7 +32,7 @@ def answer_question(question):
     logger.info(f"Received question: {question}")
     try:
         logger.info("Combining text from dataset")
-        context = " ".join([item['Text'] for item in ds['train']])
+        context = " ".join([item.get('Text', '') for item in ds['train']])
         logger.info(f"Combined context length: {len(context)} characters")
         logger.info("Tokenizing input")
         inputs = tokenizer.encode_plus(question, context, return_tensors="pt", max_length=512, truncation=True)
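
The body of answer_question between tokenization and the except clause is elided from the diff. With AutoModelForQuestionAnswering, the conventional extractive-QA step looks roughly like the sketch below (standard start/end-logits span selection, not necessarily the exact committed code). Worth flagging: truncation=True with max_length=512 silently discards everything past the first ~512 tokens of the concatenated Gita text, so most of the dataset never reaches the model.

    # Hypothetical sketch of the elided span-selection step
    with torch.no_grad():
        outputs = model(**inputs)
    answer_start = int(torch.argmax(outputs.start_logits))  # most likely start position
    answer_end = int(torch.argmax(outputs.end_logits)) + 1  # most likely end (exclusive)
    span_ids = inputs["input_ids"][0][answer_start:answer_end]
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(span_ids))
    answer = clean_answer(answer)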
@@ -68,8 +55,10 @@ def answer_question(question):
         logger.error(f"Error in answer_question function: {str(e)}")
         return "I'm sorry, but an error occurred while processing your question. Please try again later."
 
+# FastAPI setup
 app = FastAPI()
 
+# Add CORS middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
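
The Question model and the route decorator for predict are also elided from the diff. Given the pydantic import and the predict signature, the missing pieces are presumably along these lines; the field name and route path are assumptions:

    class Question(BaseModel):
        question: str  # field name assumed; the actual model definition is not shown

    @app.post("/predict")  # path assumed from the handler name
    async def predict(question: Question):
        ...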
@@ -96,6 +85,7 @@ async def predict(question: Question):
         logger.error(f"Error in predict function: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
+# Gradio interface
 iface = gr.Interface(
     fn=answer_question,
     inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
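
With the server running, the JSON endpoint can be exercised in a few lines of Python; the endpoint path and field name follow the assumptions above:

    import requests

    resp = requests.post(
        "http://localhost:7860/predict",
        json={"question": "What does Krishna teach Arjuna about duty?"},
    )
    print(resp.json())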
@@ -104,8 +94,10 @@ iface = gr.Interface(
     description="Ask a question about the Bhagavad Gita, and get an answer based on the dataset."
 )
 
+# Mount Gradio app to FastAPI
 app = gr.mount_gradio_app(app, iface, path="/")
 
+# For local development and testing
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
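
gr.mount_gradio_app serves the Gradio UI at the root path while the FastAPI routes stay available on the same server, so a single uvicorn process on port 7860 (the default port for Hugging Face Spaces) covers both the browser UI and the REST API.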
 