Spaces:
Running
Running
File size: 5,286 Bytes
75d38ea 01e29bf 75d38ea 01e29bf 75d38ea cbe205b 75d38ea cbe205b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import logging
import sqlite3
import requests
import streamlit as st
from langgraph.errors import GraphRecursionError
from langchain_groq import ChatGroq
from agent import SQLAgentRAG
from tools import retriever
from constant import GROQ_API_KEY, CONFIG
# Initialize logging: configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; used by the functions below.
logger = logging.getLogger(__name__)
def download_sqlite_db(url, local_filename='travel.sqlite', timeout=30):
    """Download the file at ``url`` and store it as ``local_filename``.

    The response body is streamed to a temporary ``<local_filename>.part``
    file and moved into place only after the whole body has been written,
    so a failed or interrupted download never leaves a truncated file at
    ``local_filename``.

    Parameters:
    - url (str): Address of the file to download.
    - local_filename (str): Destination path. Defaults to ``'travel.sqlite'``.
    - timeout (float | tuple | None): Passed straight to ``requests.get`` as
      its ``timeout`` argument. Defaults to 30 so a stalled server cannot
      block the caller forever.

    Returns:
    - str: ``local_filename``, unchanged.

    Raises:
    - requests.HTTPError: If ``raise_for_status()`` rejects the response.
    - Any other exception raised by ``requests`` or by the file operations
      is propagated to the caller.
    """
    # Imported locally so this function does not rely on a file-level
    # ``import os`` being present.
    import os

    partial_path = f"{local_filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
        # Swap the finished file into place in a single step.
        os.replace(partial_path, local_filename)
    finally:
        # After a successful replace the partial file is already gone; after
        # a failure it must not be left behind.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
    return local_filename
# Chat model served through Groq. The API key comes from the project's
# `constant` module; a low temperature (0.1) is requested.
llm = ChatGroq(
    model="llama3-8b-8192",
    api_key=GROQ_API_KEY,
    temperature=0.1,
    verbose=True,
)

# SQL agent built on top of the model above, using the `retriever` object
# imported from the project's `tools` module as its tools.
agent = SQLAgentRAG(llm=llm, tools=retriever)
def query_rag_agent(query: str):
    """
    Handle a query through the RAG Agent, producing an SQL response if applicable.

    Parameters:
    - query (str): The input query to process.

    Returns:
    - Tuple[str, str]: ``(response, sql_query)``.
      ``response`` is the content of the last message in the graph output.
      ``sql_query`` is the last entry of the output's ``"sql_query"`` value,
      or ``"No SQL query generated"`` when that value is missing or empty.
      If the graph hits its recursion limit, a fixed error message and an
      empty string are returned instead.

    Raises:
    - Nothing for ``GraphRecursionError``: it is caught, logged and turned
      into the error return value described above. Any other exception from
      the agent propagates to the caller.
    """
    # Only the graph invocation is guarded for the recursion limit; the
    # result handling below sits outside the ``try``.
    try:
        output = agent.graph.invoke({"messages": query}, CONFIG)
    except GraphRecursionError:
        logger.error("Graph recursion limit reached; query processing failed.")
        return "Graph recursion limit reached. No SQL result generated.", ""

    response = output["messages"][-1].content
    # ``or`` also covers a "sql_query" key that is present but empty/None,
    # which would otherwise fail on the ``[-1]`` lookup.
    sql_queries = output.get("sql_query") or ["No SQL query generated"]
    sql_query = sql_queries[-1]
    logger.info("Query processed successfully: %s", query)
    return response, sql_query
# Sidebar copy, kept at module level so the markdown text is not mixed into
# the UI logic.
_ABOUT_PROJECT_MD = """
RAG (Retrieval-Augmented Generation) Agent SQL is an approach that combines retrieval techniques with text generation to create more relevant and contextualised answers from data,
particularly in SQL databases. RAG-Agent SQL uses two main components:
- Retrieval: Retrieving relevant information from the database based on a given question or input.
- Augmented Generation: Using natural language models (e.g., LLMs such as GPT or LLaMA) to generate more detailed answers, using information from the retrieval results.
to see the architecture can be seen here [Github](https://github.com/fahmiaziz98/sql_agent/tree/main/002sql-agent-ra)
"""

_EXAMPLE_QUESTIONS_MD = """
- How many different aircraft models are there? And what are the models?
- What is the aircraft model with the longest range?
- Which airports are located in the city of Basel?
- Can you please provide information on what I asked before?
- What are the fare conditions available on Boeing 777-300?
- What is the total amount of bookings made in April 2024?
- What is the scheduled arrival time of flight number QR0051?
- Which car rental services are available in Basel?
- Which seat was assigned to the boarding pass with ticket number 0060005435212351?
- Which trip recommendations are related to history in Basel?
- How many tickets were sold for Business class on flight 30625?
- Which hotels are located in Zurich?
"""


def _render_sidebar():
    """Draw the sidebar: project description followed by example questions."""
    with st.sidebar:
        st.header("About Project")
        st.markdown(_ABOUT_PROJECT_MD)
        st.header("Example Question")
        st.markdown(_EXAMPLE_QUESTIONS_MD)


def _render_message(message):
    """Draw one chat message dict inside a chat bubble.

    ``message`` may carry ``"role"`` (defaults to ``"assistant"``),
    ``"output"`` (markdown text) and ``"sql_query"``. A non-empty
    ``"sql_query"`` is shown in an expanded "SQL Query" expander inside the
    same bubble.
    """
    with st.chat_message(message.get("role", "assistant")):
        if "output" in message:
            st.markdown(message["output"])
        if message.get("sql_query"):
            with st.expander("SQL Query", expanded=True):
                # The payload is SQL text, so request SQL highlighting.
                st.code(message["sql_query"], language="sql")


def main():
    """Render the Streamlit chat UI for the RAG SQL agent.

    Draws the sidebar and title, replays the conversation stored in
    ``st.session_state.messages``, and, when the user submits a prompt,
    sends it to ``query_rag_agent`` and shows and stores the answer together
    with its SQL query.

    Live and replayed messages go through the same rendering helper, so a
    message looks the same when it first appears as it does on later reruns.
    """
    _render_sidebar()

    # Main Application Title
    st.title("RAG SQL-Agent")

    # Initialize session state for storing chat messages
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display conversation history from session state
    for message in st.session_state.messages:
        _render_message(message)

    # Input form for user prompt
    if prompt := st.chat_input("What do you want to know?"):
        user_message = {"role": "user", "output": prompt}
        _render_message(user_message)
        st.session_state.messages.append(user_message)

        # Fetch response from RAG agent function directly
        with st.spinner("Searching for an answer..."):
            output_text, sql_query = query_rag_agent(prompt)

        # Display assistant response and SQL query, then remember them so
        # the next rerun can replay the turn.
        assistant_message = {
            "role": "assistant",
            "output": output_text,
            "sql_query": sql_query,
        }
        _render_message(assistant_message)
        st.session_state.messages.append(assistant_message)
@st.cache_resource(show_spinner=False)
def _ensure_database(url):
    """Download the SQLite database, at most once per server process.

    Streamlit re-executes this script on every user interaction. Calling
    ``download_sqlite_db`` directly from the entry point would therefore
    fetch and overwrite the database file on each rerun. Caching the call
    with ``st.cache_resource`` keeps it to a single download per URL for the
    lifetime of the process. A failed download is not cached and is retried
    on the next run.

    Returns the value returned by ``download_sqlite_db``.
    """
    return download_sqlite_db(url)


if __name__ == "__main__":
    URL = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
    _ensure_database(URL)
    main()
|