import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load dataset
logger.info("Loading the dataset")
ds = load_dataset("knowrohit07/gita_dataset")
logger.info("Dataset loaded successfully")

# Build the retrieval context once at startup instead of on every request
context = " ".join(item.get('Text', '') for item in ds['train'])
logger.info(f"Combined context length: {len(context)} characters")

# Load model and tokenizer
logger.info("Loading the model and tokenizer")
model_name = "deepset/roberta-large-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
logger.info("Model and tokenizer loaded successfully")

def clean_answer(answer):
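    """Strip tokenizer special tokens (e.g. <s>, </s>, <pad>) from a decoded span."""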
    special_tokens = set(tokenizer.all_special_tokens)
    cleaned_answer = ' '.join(token for token in answer.split() if token not in special_tokens)
    return cleaned_answer.strip()

def answer_question(question):
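    """Answer a question with extractive QA over the precomputed Gita context."""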
    logger.info(f"Received question: {question}")
    try:
        logger.info("Tokenizing input")
        # RoBERTa accepts at most 512 tokens (not 514); truncation="only_second"
        # keeps the question intact and trims the (very long) context to fit.
        inputs = tokenizer(question, context, return_tensors="pt", max_length=512, truncation="only_second")
        logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
        logger.info("Getting model output")
        with torch.no_grad():  # inference only; skip gradient bookkeeping
            outputs = model(**inputs)
        logger.info(f"Output logits shapes: start={outputs.start_logits.shape}, end={outputs.end_logits.shape}")
        logger.info("Processing output to get answer")
        answer_start = torch.argmax(outputs.start_logits)
        # Only consider end positions at or after the start so the span is never empty or reversed
        answer_end = answer_start + torch.argmax(outputs.end_logits[0, answer_start:]) + 1
        raw_answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
        answer = clean_answer(raw_answer)
        logger.info(f"Generated answer: {answer}")
        if not answer:
            logger.warning("Generated answer was empty after cleaning")
            answer = "I'm sorry, but I couldn't find a specific answer to that question based on the Bhagavad Gita. Could you please rephrase your question or ask about one of the core concepts like dharma, karma, bhakti, or the different types of yoga discussed in the Gita?"
        logger.info("Answer generated successfully")
        return answer
    except Exception as e:
        logger.error(f"Error in answer_question function: {str(e)}")
        return "I'm sorry, but an error occurred while processing your question. Please try again later."

# FastAPI setup
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
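# NOTE: allow_origins=["*"] keeps this demo easy to call from anywhere; restrict it for production use.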

class Question(BaseModel):
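    # Chat-style history, e.g. [{"role": "user", "content": "..."}]; only the last user message is answered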
    messages: list

@app.post("/predict")
async def predict(question: Question):
    try:
        last_user_message = next((msg for msg in reversed(question.messages) if msg['role'] == 'user'), None)
        if not last_user_message:
            raise HTTPException(status_code=400, detail="No user message found")
        user_question = last_user_message['content']
        answer = answer_question(user_question)
        disclaimer = "\n\n---\nPlease note: This response is generated by an AI model based on the Bhagavad Gita. For authoritative information, please consult the original text or scholarly sources."
        full_response = answer + disclaimer
        return {"response": full_response, "isTruncated": False}
    except Exception as e:
        logger.error(f"Error in predict function: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
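
# Example request (assuming the server is reachable on localhost:7860):
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "What does the Gita teach about karma?"}]}'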

# Gradio interface
iface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="text",
    title="Bhagavad Gita Q&A",
    description="Ask a question about the Bhagavad Gita, and get an answer based on the dataset."
)

# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, iface, path="/")

# For local development and testing
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
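
# Alternatively, run with the uvicorn CLI (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860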