import logging

import streamlit as st
from dotenv import load_dotenv, find_dotenv
from langgraph.errors import GraphRecursionError
from langchain_groq import ChatGroq

from agent import SQLAgentRAG
from tools import retriever
from constant import GROQ_API_KEY, CONFIG

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
# NOTE(review): this runs after `constant` has been imported above; if
# `constant` reads GROQ_API_KEY / CONFIG from os.environ at import time, the
# .env file is loaded too late for it -- confirm against constant.py.
load_dotenv(find_dotenv())


@st.cache_resource
def _load_llm_and_agent():
    """Build the Groq chat model and the SQL RAG agent once.

    Streamlit re-executes this whole script on every user interaction, so
    without caching the model client and agent would be reconstructed on each
    rerun. `st.cache_resource` keeps a single instance across reruns.

    NOTE(review): the cached objects are shared by every browser session. If
    SQLAgentRAG keeps conversation memory (the "what I asked before" example
    question suggests it may), that memory is shared too -- confirm this is
    acceptable.

    Returns:
    - (ChatGroq, SQLAgentRAG): the language model and the agent built on it.
    """
    chat_model = ChatGroq(
        model="llama3-8b-8192",
        api_key=GROQ_API_KEY,
        temperature=0.1,
        verbose=True,
    )
    return chat_model, SQLAgentRAG(llm=chat_model, tools=retriever)


# Initialize the language model and the SQL Agent (module-level names kept)
llm, agent = _load_llm_and_agent()


def query_rag_agent(query: str) -> tuple[str, str]:
    """
    Run a user query through the RAG SQL agent.

    Parameters:
    - query (str): The natural-language question to process.

    Returns:
    - tuple[str, str]: The content of the agent's final message, and the last
      SQL query recorded in the agent state. The second element is
      "No SQL query generated" when the state holds no SQL query. It is an
      empty string when the graph hits its recursion limit.

    Notes:
    - GraphRecursionError is caught and turned into a fallback message rather
      than propagated. Any other exception raised by the agent propagates.
    """
    try:
        output = agent.graph.invoke({"messages": query}, CONFIG)
    except GraphRecursionError:
        logger.error("Graph recursion limit reached; query processing failed.")
        return "Graph recursion limit reached. No SQL result generated.", ""

    response = output["messages"][-1].content
    # Presumably `sql_query` is a list of generated queries (the fallback is a
    # list) -- confirm against agent.py. A missing key, None, or an empty list
    # all fall back to the placeholder instead of failing on `[-1]`.
    sql_queries = output.get("sql_query") or ["No SQL query generated"]
    sql_query = sql_queries[-1]
    logger.info("Query processed successfully: %s", query)
    return response, sql_query


def _render_message(message: dict) -> None:
    """Draw one chat message: its text and, if non-empty, its SQL query.

    Used both for replaying the stored history and for newly produced
    messages, so a reply looks the same on first display and after reruns.
    """
    with st.chat_message(message.get("role", "assistant")):
        if "output" in message:
            st.markdown(message["output"])
        if message.get("sql_query"):
            with st.expander("SQL Query", expanded=True):
                st.code(message["sql_query"])


with st.sidebar:
    st.header("About Project")
    st.markdown(
        """
        RAG (Retrieval-Augmented Generation) Agent SQL is an approach that combines retrieval techniques with text generation to create more relevant and contextualised answers from data, particularly in SQL databases. RAG-Agent SQL uses two main components:
        - Retrieval: Retrieving relevant information from the database based on a given question or input.
        - Augmented Generation: Using natural language models (e.g., LLMs such as GPT or LLaMA) to generate more detailed answers, using information from the retrieval results.
        to see the architecture can be seen here [Github](https://github.com/fahmiaziz98/sql_agent/tree/main/002sql-agent-ra)
        """
    )

    st.header("Example Question")
    st.markdown(
        """
        - How many different aircraft models are there? And what are the models?
        - What is the aircraft model with the longest range?
        - Which airports are located in the city of Basel?
        - Can you please provide information on what I asked before?
        - What are the fare conditions available on Boeing 777-300?
        - What is the total amount of bookings made in April 2024?
        - What is the scheduled arrival time of flight number QR0051?
        - Which car rental services are available in Basel?
        - Which seat was assigned to the boarding pass with ticket number 0060005435212351?
        - Which trip recommendations are related to history in Basel?
        - How many tickets were sold for Business class on flight 30625?
        - Which hotels are located in Zurich?
        """
    )

# Main Application Title
st.title("RAG SQL-Agent")

# Initialize session state for storing chat messages
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display conversation history from session state
for message in st.session_state.messages:
    _render_message(message)

# Input form for user prompt
if prompt := st.chat_input("What do you want to know?"):
    user_message = {"role": "user", "output": prompt}
    _render_message(user_message)
    st.session_state.messages.append(user_message)

    # Fetch response from RAG agent function directly
    with st.spinner("Searching for an answer..."):
        output_text, sql_query = query_rag_agent(prompt)

    # Render the reply (and its SQL expander) inside the assistant bubble,
    # exactly as the history loop will on later reruns.
    assistant_message = {
        "role": "assistant",
        "output": output_text,
        "sql_query": sql_query,
    }
    _render_message(assistant_message)

    # Append assistant response to session state
    st.session_state.messages.append(assistant_message)