File size: 5,286 Bytes
75d38ea
01e29bf
 
75d38ea
 
 
 
 
 
 
 
 
 
 
01e29bf
 
 
 
 
 
 
 
75d38ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbe205b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75d38ea
cbe205b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import logging
import os
import sqlite3

import requests
import streamlit as st
from langgraph.errors import GraphRecursionError
from langchain_groq import ChatGroq

from agent import SQLAgentRAG
from tools import retriever
from constant import GROQ_API_KEY, CONFIG

# Module-level logger; its records propagate to the root logger configured below.
logger = logging.getLogger(name=__name__)

# Show INFO-and-above records through the root logger's default stderr handler
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")

def download_sqlite_db(url, local_filename='travel.sqlite', timeout=60):
    """Download the SQLite database at *url* and save it as *local_filename*.

    The response body is streamed into a temporary ``<local_filename>.part``
    file that is moved over the destination only once the download has
    finished. A failed or interrupted download therefore never leaves a
    truncated database at *local_filename*.

    Parameters:
        - url (str): HTTP(S) URL of the database file.
        - local_filename (str): Destination path. Defaults to 'travel.sqlite'.
        - timeout (float | tuple): Passed to ``requests.get``. This is how long
          to wait for the connection or for the next chunk of data; it is not
          a cap on the total download time. Defaults to 60 seconds.

    Returns:
        - str: ``local_filename``.

    Raises:
        - requests.HTTPError: If the server answers with an error status.
        - requests.RequestException: On connection errors or timeouts.
    """
    tmp_filename = f'{local_filename}.part'
    # Without a timeout, requests can wait forever on a stalled server.
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        try:
            with open(tmp_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
            # Swap the finished file into place in one step
            # (same directory, so the same filesystem).
            os.replace(tmp_filename, local_filename)
        except BaseException:
            # Also covers KeyboardInterrupt. Remove the partial file, then re-raise.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
            raise
    return local_filename

# Chat model driving the agent; a low temperature limits sampling randomness.
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama3-8b-8192",
    temperature=0.1,
    verbose=True,
)

# SQL agent assembled from the model above and the project's retriever tool(s).
# NOTE(review): Streamlit re-executes this script on every interaction, so both
# objects are rebuilt on each rerun. Consider st.cache_resource if SQLAgentRAG
# holds no per-session state. Confirm against agent.py.
agent = SQLAgentRAG(tools=retriever, llm=llm)

def query_rag_agent(query: str):
    """
    Run a natural-language question through the RAG SQL agent.

    Parameters:
        - query (str): The user's question, passed to the agent graph as its
          ``messages`` input together with the shared ``CONFIG``.

    Returns:
        - Tuple[str, str]: The first item is the content of the last message
          in the graph output. The second is the last entry of the output's
          ``sql_query`` value (presumably a list of generated queries). If the
          graph produced no SQL query, the second item is the placeholder
          "No SQL query generated". If the graph hits its recursion limit, the
          first item is an explanatory message and the second is "".

    Note:
        GraphRecursionError is handled here and never reaches the caller.
        Any other exception raised by the agent propagates.
    """
    try:
        output = agent.graph.invoke({"messages": query}, CONFIG)
    except GraphRecursionError:
        logger.error("Graph recursion limit reached; query processing failed.")
        return "Graph recursion limit reached. No SQL result generated.", ""

    response = output["messages"][-1].content
    # `or` also covers a key that is present but empty (or None). With the
    # original .get() default, the [-1] index below would raise IndexError.
    sql_queries = output.get("sql_query") or ["No SQL query generated"]
    sql_query = sql_queries[-1]

    logger.info("Query processed successfully: %s", query)
    return response, sql_query

def _render_sidebar():
    """Render the sidebar: a project description and example questions."""
    with st.sidebar:
        st.header("About Project")
        st.markdown(
            """
            RAG (Retrieval-Augmented Generation) Agent SQL is an approach that combines retrieval techniques with text generation to create more relevant and contextualised answers from data, 
            particularly in SQL databases. RAG-Agent SQL uses two main components:
            - Retrieval: Retrieving relevant information from the database based on a given question or input.
            - Augmented Generation: Using natural language models (e.g., LLMs such as GPT or LLaMA) to generate more detailed answers, using information from the retrieval results.
    
            The architecture can be seen here: [Github](https://github.com/fahmiaziz98/sql_agent/tree/main/002sql-agent-ra)
            """
        )

        st.header("Example Question")
        st.markdown(
            """
            - How many different aircraft models are there? And what are the models?
            - What is the aircraft model with the longest range?
            - Which airports are located in the city of Basel?
            - Can you please provide information on what I asked before?
            - What are the fare conditions available on Boeing 777-300?
            - What is the total amount of bookings made in April 2024?
            - What is the scheduled arrival time of flight number QR0051?
            - Which car rental services are available in Basel?
            - Which seat was assigned to the boarding pass with ticket number 0060005435212351?
            - Which trip recommendations are related to history in Basel?
            - How many tickets were sold for Business class on flight 30625?
            - Which hotels are located in Zurich?
            """
        )


def _render_message(role, text=None, sql_query=None):
    """Render one chat turn inside a chat bubble for *role*.

    *text* is shown as markdown when it is not None. A non-empty *sql_query*
    is shown in an expanded "SQL Query" code expander inside the same bubble.
    """
    with st.chat_message(role):
        if text is not None:
            st.markdown(text)
        if sql_query:
            with st.expander("SQL Query", expanded=True):
                st.code(sql_query)


def main():
    """Streamlit entry point: sidebar, chat history and the chat input.

    Chat turns are stored in ``st.session_state.messages`` as dicts with
    ``role``, ``output`` and (for assistant turns) ``sql_query``. The stored
    history is replayed on every run of this function.
    """
    _render_sidebar()

    # Main Application Title
    st.title("RAG SQL-Agent")

    # Initialize session state for storing chat messages
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation history from session state
    for message in st.session_state.messages:
        _render_message(
            message.get("role", "assistant"),
            message.get("output"),
            message.get("sql_query"),
        )

    # Input box for the user prompt
    if prompt := st.chat_input("What do you want to know?"):
        _render_message("user", prompt)
        st.session_state.messages.append({"role": "user", "output": prompt})

        # Fetch the response from the RAG agent
        with st.spinner("Searching for an answer..."):
            output_text, sql_query = query_rag_agent(prompt)

        # Use the same helper as the history loop so the SQL expander sits
        # inside the assistant bubble now and on later reruns.
        _render_message("assistant", output_text, sql_query)

        # Append assistant response to session state
        st.session_state.messages.append(
            {
                "role": "assistant",
                "output": output_text,
                "sql_query": sql_query,
            }
        )

if __name__ == "__main__":
    URL = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
    DB_PATH = "travel.sqlite"
    # Streamlit re-executes this script from the top on every interaction.
    # Without this check, every chat message would re-download the whole
    # database before answering. Delete the file to force a fresh copy.
    if not os.path.exists(DB_PATH):
        download_sqlite_db(URL, DB_PATH)
    main()