Spaces:
Running
Running
import logging | |
import sqlite3 | |
import requests | |
import streamlit as st | |
from langgraph.errors import GraphRecursionError | |
from langchain_groq import ChatGroq | |
from agent import SQLAgentRAG | |
from tools import retriever | |
from constant import GROQ_API_KEY, CONFIG | |
# Module-level logger, named after this module; the root logger is configured
# to emit records at INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def download_sqlite_db(url, local_filename='travel.sqlite', timeout=30):
    """
    Download the file at `url` and store it as `local_filename`.

    The response body is streamed in 8 KiB chunks into a temporary
    "<local_filename>.part" file, which is moved over `local_filename` only
    after the whole body has been written. A failed or interrupted download
    therefore never leaves a truncated database at `local_filename`.

    Parameters:
    - url (str): Location of the SQLite file to download.
    - local_filename (str): Destination path. Defaults to 'travel.sqlite'.
    - timeout (float | tuple): Timeout in seconds handed to `requests.get`,
      so a stalled server cannot block the caller forever. Defaults to 30.

    Returns:
    - str: `local_filename`, unchanged.

    Raises:
    - requests.RequestException: On connection errors, timeouts, or a
      non-success HTTP status (via `raise_for_status`).
    - OSError: If the temporary or destination file cannot be written.
    """
    import os  # function-scope import: the file's import block does not provide os

    partial_filename = f"{local_filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(partial_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
        # Replace the destination in one step, only once the body is complete.
        os.replace(partial_filename, local_filename)
    finally:
        # After a successful replace the partial file is gone; after a failure
        # this removes the incomplete leftover.
        if os.path.exists(partial_filename):
            os.remove(partial_filename)
    return local_filename
# Chat model hosted on Groq that the agent uses for generation.
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama3-8b-8192",
    verbose=True,
    temperature=0.1,
)

# SQL agent built on top of the model, with the retriever supplied as its tools.
agent = SQLAgentRAG(tools=retriever, llm=llm)
def query_rag_agent(query: str):
    """
    Run a query through the RAG agent graph and return its answer.

    Parameters:
    - query (str): The input query to process.

    Returns:
    - Tuple[str, str]: The content of the last message in the graph output,
      and the most recent entry of the output's "sql_query" value. If that
      value is missing, None, or empty, the second item is
      "No SQL query generated". If the graph hits its recursion limit, a
      fixed explanatory message and an empty string are returned instead.

    Raises:
    - Nothing for the recursion limit: GraphRecursionError is caught and
      logged here. Any other exception from the agent propagates unchanged.
    """
    # Only the graph invocation is guarded; result handling stays outside the
    # try so unrelated errors are not attributed to the recursion limit.
    try:
        output = agent.graph.invoke({"messages": query}, CONFIG)
    except GraphRecursionError:
        logger.error("Graph recursion limit reached; query processing failed.")
        return "Graph recursion limit reached. No SQL result generated.", ""

    response = output["messages"][-1].content

    # A present-but-empty (or None) "sql_query" would make a bare [-1] raise,
    # so fall back to the same placeholder used when the key is absent.
    sql_queries = output.get("sql_query") or ["No SQL query generated"]
    if isinstance(sql_queries, str):
        # A plain string is already the query; indexing it would return one character.
        sql_query = sql_queries
    else:
        sql_query = sql_queries[-1]

    logger.info("Query processed successfully: %s", query)
    return response, sql_query
# Sidebar text describing the project (rendered as Markdown).
_ABOUT_MARKDOWN = """
RAG (Retrieval-Augmented Generation) Agent SQL is an approach that combines retrieval techniques with text generation to create more relevant and contextualised answers from data,
particularly in SQL databases. RAG-Agent SQL uses two main components:
- Retrieval: Retrieving relevant information from the database based on a given question or input.
- Augmented Generation: Using natural language models (e.g., LLMs such as GPT or LLaMA) to generate more detailed answers, using information from the retrieval results.
to see the architecture can be seen here [Github](https://github.com/fahmiaziz98/sql_agent/tree/main/002sql-agent-ra)
"""

# Sidebar list of sample questions (rendered as Markdown).
_EXAMPLE_QUESTIONS_MARKDOWN = """
- How many different aircraft models are there? And what are the models?
- What is the aircraft model with the longest range?
- Which airports are located in the city of Basel?
- Can you please provide information on what I asked before?
- What are the fare conditions available on Boeing 777-300?
- What is the total amount of bookings made in April 2024?
- What is the scheduled arrival time of flight number QR0051?
- Which car rental services are available in Basel?
- Which seat was assigned to the boarding pass with ticket number 0060005435212351?
- Which trip recommendations are related to history in Basel?
- How many tickets were sold for Business class on flight 30625?
- Which hotels are located in Zurich?
"""


def _render_sidebar():
    """Draw the sidebar: project description followed by example questions."""
    with st.sidebar:
        st.header("About Project")
        st.markdown(_ABOUT_MARKDOWN)
        st.header("Example Question")
        st.markdown(_EXAMPLE_QUESTIONS_MARKDOWN)


def _render_message(role, output=None, sql_query=None):
    """
    Draw one chat message.

    The text and, when `sql_query` is non-empty, an expanded "SQL Query"
    section are placed inside the same chat bubble. Both the history replay
    and the live reply go through this function so they look the same.
    """
    with st.chat_message(role):
        if output is not None:
            st.markdown(output)
        if sql_query:
            with st.expander("SQL Query", expanded=True):
                st.code(sql_query)


def main():
    """
    Render the Streamlit chat page for the RAG SQL agent.

    Draws the sidebar and title, replays the conversation kept in
    `st.session_state.messages`, and, when the user submits a prompt, sends
    it to `query_rag_agent`, shows the answer with its SQL query, and appends
    both the prompt and the answer to the stored conversation.
    """
    _render_sidebar()

    # Main application title
    st.title("RAG SQL-Agent")

    # Conversation history lives in session state; create it on first use.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay everything said so far.
    for message in st.session_state.messages:
        _render_message(
            message.get("role", "assistant"),
            message.get("output"),
            message.get("sql_query"),
        )

    prompt = st.chat_input("What do you want to know?")
    if not prompt:
        return

    # Show and record the user's prompt.
    _render_message("user", prompt)
    st.session_state.messages.append({"role": "user", "output": prompt})

    # Ask the agent while a spinner is displayed.
    with st.spinner("Searching for an answer..."):
        output_text, sql_query = query_rag_agent(prompt)

    # Show and record the assistant's reply.
    _render_message("assistant", output_text, sql_query)
    st.session_state.messages.append(
        {
            "role": "assistant",
            "output": output_text,
            "sql_query": sql_query,
        }
    )
@st.cache_resource(show_spinner=False)
def _ensure_database(url):
    """
    Download the travel database once per server process and return its path.

    Streamlit re-executes the script on every user interaction; caching this
    call keeps the database file from being re-downloaded and overwritten on
    each rerun. A download that raises is not cached, so it is retried on the
    next run.
    """
    return download_sqlite_db(url)


if __name__ == "__main__":
    URL = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
    _ensure_database(URL)
    main()