File size: 4,582 Bytes
75d38ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import logging
import streamlit as st
from dotenv import load_dotenv, find_dotenv
from langgraph.errors import GraphRecursionError
from langchain_groq import ChatGroq
from agent import SQLAgentRAG
from tools import retriever
from constant import GROQ_API_KEY, CONFIG

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from the nearest .env file.
# NOTE(review): this runs after `constant` has already been imported at the top
# of the file; if constant.py reads GROQ_API_KEY from os.environ at import time
# without loading .env itself, the key may be picked up too late — confirm.
load_dotenv(find_dotenv())


@st.cache_resource
def _init_llm_and_agent():
    """Build the Groq chat model and the SQL RAG agent once per server process.

    Streamlit re-executes this whole script on every user interaction. Without
    caching, every rerun would construct a fresh ChatGroq client and a fresh
    SQLAgentRAG, discarding whatever in-memory state the previous agent held
    (presumably its conversation memory — confirm in agent.py). The cached
    objects are shared by every session served by this process.

    Returns:
        tuple: ``(llm, agent)`` — the ChatGroq model and the SQLAgentRAG built on it.
    """
    chat_model = ChatGroq(
        model="llama3-8b-8192",
        api_key=GROQ_API_KEY,
        temperature=0.1,
        verbose=True,
    )
    return chat_model, SQLAgentRAG(llm=chat_model, tools=retriever)


# Keep the module-level names the rest of the script (and any importers) use.
llm, agent = _init_llm_and_agent()

def query_rag_agent(query: str) -> tuple[str, str]:
    """
    Run a query through the RAG agent graph and return its answer and SQL.

    Parameters:
        - query (str): The user's natural-language question.

    Returns:
        - tuple[str, str]: ``(response, sql_query)``. ``response`` is the
          content of the last message in the graph output. ``sql_query`` is
          the last entry of the output's ``"sql_query"`` list, or
          ``"No SQL query generated"`` when that entry is missing or empty.
          If the graph hits its recursion limit, an explanatory message and
          an empty SQL string are returned instead.

    Notes:
        - GraphRecursionError is caught and logged here, not propagated; any
          other exception raised by the agent propagates to the caller.
    """
    try:
        output = agent.graph.invoke({"messages": query}, CONFIG)
    except GraphRecursionError:
        logger.error("Graph recursion limit reached; query processing failed.")
        return "Graph recursion limit reached. No SQL result generated.", ""

    response = output["messages"][-1].content
    # `or` also covers a present-but-empty (or None) value, which previously
    # crashed on the `[-1]` lookup instead of falling back to the default.
    sql_queries = output.get("sql_query") or ["No SQL query generated"]
    sql_query = sql_queries[-1]

    logger.info("Query processed successfully: %s", query)
    return response, sql_query

# Sidebar: a short project description plus sample questions users can try.
with st.sidebar:
    st.header("About Project")
    st.markdown(
        """
        RAG (Retrieval-Augmented Generation) Agent SQL is an approach that combines retrieval techniques with text generation to create more relevant and contextualised answers from data, 
        particularly in SQL databases. RAG-Agent SQL uses two main components:
        - Retrieval: Retrieving relevant information from the database based on a given question or input.
        - Augmented Generation: Using natural language models (e.g., LLMs such as GPT or LLaMA) to generate more detailed answers, using information from the retrieval results.

        You can see the architecture on [GitHub](https://github.com/fahmiaziz98/sql_agent/tree/main/002sql-agent-ra).
        """
    )

    st.header("Example Question")
    st.markdown(
        """
        - How many different aircraft models are there? And what are the models?
        - What is the aircraft model with the longest range?
        - Which airports are located in the city of Basel?
        - Can you please provide information on what I asked before?
        - What are the fare conditions available on Boeing 777-300?
        - What is the total amount of bookings made in April 2024?
        - What is the scheduled arrival time of flight number QR0051?
        - Which car rental services are available in Basel?
        - Which seat was assigned to the boarding pass with ticket number 0060005435212351?
        - Which trip recommendations are related to history in Basel?
        - How many tickets were sold for Business class on flight 30625?
        - Which hotels are located in Zurich?
        """
    )

# Main application heading, rendered above the chat transcript.
st.title("RAG SQL-Agent")

# Streamlit reruns the whole script on each interaction, so the transcript
# lives in session state and is only created on a session's first run.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Replay every stored turn; assistant turns may carry the SQL that produced them.
for entry in st.session_state["messages"]:
    with st.chat_message(entry.get("role", "assistant")):
        if "output" in entry:
            st.markdown(entry["output"])
        if entry.get("sql_query"):
            with st.expander("SQL Query", expanded=True):
                st.code(entry["sql_query"])

# Input form for user prompt: handle one new chat turn per submission.
if prompt := st.chat_input("What do you want to know?"):
    # Echo the user's turn immediately and record it in the transcript.
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "output": prompt})

    # Render the whole assistant turn inside a single chat bubble so the SQL
    # expander sits with its answer, the same way the history loop draws
    # stored assistant turns on later reruns.
    with st.chat_message("assistant"):
        with st.spinner("Searching for an answer..."):
            output_text, sql_query = query_rag_agent(prompt)
        st.markdown(output_text)
        if sql_query:
            with st.expander("SQL Query", expanded=True):
                st.code(sql_query)

    # Persist the assistant turn so it is replayed on the next rerun.
    st.session_state.messages.append(
        {
            "role": "assistant",
            "output": output_text,
            "sql_query": sql_query,
        }
    )