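# Page header captured with this file (Hugging Face file viewer):
#   author: ajaynagotha (avatar: "ajaynagotha's picture")
#   commit: "Update app.py" — 738d0f3 (verified)
#   links: raw · history · blame
#   size: 5.27 kB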
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
# Logging: configure the root logger once (INFO level, timestamped lines) and
# create the module-scoped logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# One-time startup work: dataset, model/tokenizer, and the search context.
# ---------------------------------------------------------------------------
logger.info("Loading the dataset")
ds = load_dataset("knowrohit07/gita_dataset")
logger.info("Dataset loaded successfully")

logger.info("Loading the model and tokenizer")
# NOTE(review): "facebook/bart-large-cnn" is a summarization checkpoint. Loading
# it through AutoModelForQuestionAnswering presumably leaves the span-prediction
# head freshly initialized (transformers usually logs a "newly initialized"
# warning in that case) — confirm, and consider an extractive-QA fine-tuned
# checkpoint. answer_question truncates to 1024 tokens, so a model with fewer
# position embeddings would also need that limit lowered.
model_name = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
logger.info(f"Model and tokenizer loaded successfully. Using device: {device}")

logger.info("Preprocessing the dataset")
# Flatten the train split into one space-separated string of every row's
# 'Text' field (rows without that key contribute an empty string).
context = " ".join(row.get('Text', '') for row in ds['train'])
logger.info(f"Combined context length: {len(context)} characters")
def clean_answer(answer: str, special_tokens=None) -> str:
    """Strip tokenizer special tokens from an answer and tidy its whitespace.

    Args:
        answer: Text obtained by decoding a span of model token ids.
        special_tokens: Iterable of strings to remove. Defaults to the module
            tokenizer's ``all_special_tokens``; pass an explicit iterable to
            use the function without that global (e.g. in tests).

    Returns:
        ``answer`` with every occurrence of each special token removed and all
        whitespace runs collapsed to single spaces, stripped at both ends.
    """
    if special_tokens is None:
        special_tokens = tokenizer.all_special_tokens
    cleaned = answer
    # Remove substrings rather than whole whitespace-separated words: decoded
    # text can glue special tokens onto neighbouring words (e.g. "<s>Hello"),
    # which a split()-based filter would leave untouched. Longest tokens go
    # first so a token containing another one is removed whole.
    for token in sorted(set(special_tokens), key=len, reverse=True):
        if token:
            cleaned = cleaned.replace(token, " ")
    return " ".join(cleaned.split())
def _best_span(start_logits, end_logits, allowed=None, max_answer_tokens=64):
    """Pick the highest-scoring answer span from 1-D start/end logit vectors.

    Only spans with ``start <= end < start + max_answer_tokens`` whose two
    endpoints are both marked True in ``allowed`` (a bool vector of the same
    length; ``None`` allows every position) are considered. The span score is
    ``start_logits[start] + end_logits[end]``.

    Returns:
        ``(start, end_exclusive, score)``. When no span qualifies, returns an
        empty span ``(0, 0, -inf)``.
    """
    seq_len = start_logits.shape[-1]
    positions = torch.arange(seq_len, device=start_logits.device)
    # span_len[s, e] == e - s for candidate start s and end e.
    span_len = positions[None, :] - positions[:, None]
    valid = (span_len >= 0) & (span_len < max_answer_tokens)
    if allowed is not None:
        valid &= allowed[:, None] & allowed[None, :]
    if not bool(valid.any()):
        return 0, 0, float("-inf")
    scores = (start_logits[:, None] + end_logits[None, :]).masked_fill(~valid, float("-inf"))
    start, end = divmod(int(torch.argmax(scores)), seq_len)
    return start, end + 1, float(scores[start, end])


def _capitalize_first(text):
    """Upper-case only the first character of ``text``.

    ``str.capitalize`` would also lower-case the remainder, mangling proper
    nouns such as "Krishna" or "Arjuna".
    """
    return text[:1].upper() + text[1:]


def answer_question(question):
    """Answer ``question`` by extractive QA over the concatenated dataset text.

    The module-level ``context`` string is scanned in overlapping character
    windows (``max_length`` characters, advancing ``stride`` characters). Each
    window is paired with the question, run through ``model``, and its best
    valid span is decoded. Scanning stops early once a span scores above the
    threshold; otherwise the best-scoring span across all windows is used.

    Args:
        question: The user's question text.

    Returns:
        The cleaned answer; a fixed fallback message if the answer is empty or
        shorter than 5 characters; or a fixed apology message if anything
        raises (the exception is logged with its traceback, not propagated).
    """
    logger.info(f"Received question: {question}")
    try:
        # Character-based sliding window. max_length is also passed to the
        # tokenizer as a token limit, which a 1024-character window stays
        # well under.
        max_length = 1024
        stride = 512
        score_threshold = 10  # early-exit threshold; adjust as needed
        answers = []
        for i in range(0, len(context), stride):
            chunk = context[i:i + max_length]
            # No padding: a single sequence needs none, and padding to
            # max_length both wasted compute and made pad positions
            # eligible as answer tokens.
            encoding = tokenizer(
                question,
                chunk,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
            )
            # Restrict answers to passage tokens (sequence id 1) when the
            # tokenizer can report which tokens belong to which sequence.
            allowed = None
            if encoding.is_fast:
                allowed = torch.tensor(
                    [sid == 1 for sid in encoding.sequence_ids(0)],
                    dtype=torch.bool,
                    device=device,
                )
            inputs = {k: v.to(device) for k, v in encoding.items()}
            with torch.no_grad():
                outputs = model(**inputs)
            start, end, score = _best_span(
                outputs.start_logits[0], outputs.end_logits[0], allowed
            )
            ans = tokenizer.decode(
                inputs["input_ids"][0][start:end], skip_special_tokens=True
            )
            answers.append((ans, score))
            # Break if we have a good answer
            if score > score_threshold:
                break
        # Select best answer (an empty context yields no windows at all).
        best_answer = max(answers, key=lambda x: x[1])[0] if answers else ""
        # Post-processing
        best_answer = _capitalize_first(clean_answer(best_answer))
        logger.info(f"Generated answer: {best_answer}")
        if not best_answer or len(best_answer) < 5:
            logger.warning("Generated answer was empty or too short after cleaning")
            best_answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
        logger.info("Answer generated successfully")
        return best_answer
    except Exception as e:
        logger.exception(f"Error in answer_question function: {str(e)}")
        return "I'm sorry, but an error occurred while processing your question. Please try again later."
# ---------------------------------------------------------------------------
# FastAPI application with permissive CORS
# ---------------------------------------------------------------------------
app = FastAPI()

# Any origin, method and header may call the API, and credentials are allowed.
# NOTE(review): with allow_credentials=True, Starlette's CORSMiddleware may
# echo the caller's Origin back instead of "*" — confirm that credentialed
# cross-origin access from any site is really intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class Question(BaseModel):
    """Request body for the /predict endpoint: a chat-style transcript.

    ``messages`` is an untyped list. The endpoint reads ``'role'`` and
    ``'content'`` from its entries, so they are presumably dicts such as
    ``{"role": "user", "content": "..."}`` — TODO confirm against the client.
    """

    messages: list
@app.post("/predict")
def predict(question: Question):
    """Answer the most recent user message of a chat transcript.

    Declared with plain ``def`` (not ``async def``) on purpose: answer_question
    performs blocking model inference, and FastAPI runs sync endpoints in a
    worker thread instead of stalling the event loop shared with the rest of
    the app.

    Args:
        question: Request body whose ``messages`` list is scanned from the end
            for the last dict whose ``'role'`` is ``'user'``.

    Returns:
        ``{"response": <answer + disclaimer>, "isTruncated": False}``.

    Raises:
        HTTPException: 400 when no usable user message is present; 500 for any
            unexpected error (logged with its traceback).
    """
    try:
        # Skip malformed entries instead of failing on msg['role'].
        last_user_message = next(
            (
                msg
                for msg in reversed(question.messages)
                if isinstance(msg, dict) and msg.get("role") == "user"
            ),
            None,
        )
        if last_user_message is None:
            raise HTTPException(status_code=400, detail="No user message found")
        user_question = last_user_message.get("content")
        if not isinstance(user_question, str) or not user_question.strip():
            raise HTTPException(status_code=400, detail="User message has no text content")
        answer = answer_question(user_question)
        disclaimer = "\n\n---Please note: This response is generated by an AI model based on the Bhagavad Gita. For authoritative information, please consult the original text or scholarly sources."
        full_response = answer + disclaimer
        return {"response": full_response, "isTruncated": False}
    except HTTPException:
        # Client errors raised above must keep their status code rather than
        # be rewrapped as 500s by the generic handler below.
        raise
    except Exception as e:
        logger.exception(f"Error in predict function: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# ---------------------------------------------------------------------------
# Gradio UI: a single textbox wired to answer_question, mounted at "/"
# ---------------------------------------------------------------------------
iface = gr.Interface(
    fn=answer_question,
    title="Bhagavad Gita Q&A",
    description="Ask a question about the Bhagavad Gita, and get an answer based on the dataset.",
    inputs=gr.Textbox(placeholder="Enter your question here...", lines=2),
    outputs="text",
)

# Mount the UI on the FastAPI app and rebind `app` to the returned application.
# NOTE(review): /predict was registered before this mount, so it presumably
# still matches first — confirm if routes are reordered.
app = gr.mount_gradio_app(app, iface, path="/")
# Script entry point (e.g. on Hugging Face Spaces): serve the combined
# FastAPI + Gradio app on every interface at port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")