fahmiaziz committed on
Commit
75d38ea
1 Parent(s): 470836b

Upload 10 files

Browse files
Files changed (9) hide show
  1. agent.py +264 -0
  2. app.py +123 -0
  3. column.jsonl +55 -0
  4. constant.py +9 -0
  5. models.py +9 -0
  6. requirements.txt +14 -0
  7. state.py +9 -0
  8. table.jsonl +11 -0
  9. tools.py +38 -0
agent.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Any, Literal
3
+
4
+ from langgraph.graph import END, StateGraph
5
+ from langchain_community.document_loaders import JSONLoader
6
+ from langchain_core.prompts import PromptTemplate
7
+ from langchain_core.output_parsers import StrOutputParser
8
+
9
+ from langgraph.store.memory import InMemoryStore
10
+ from langchain.chat_models.base import BaseChatModel
11
+ from langgraph.checkpoint.memory import MemorySaver
12
+ from langchain_community.utilities import SQLDatabase
13
+
14
+ from constant import DB_PATH
15
+ from state import AgentState
16
+ import logging
17
+
18
# Module-level logger; every node of the agent reports its progress through it.
logger = logging.getLogger(__name__)
# Make INFO-level messages visible (no-op if the root logger is already configured).
logging.basicConfig(level=logging.INFO)
21
+
22
+
23
class SQLAgentRAG:
    """Retrieval-augmented SQL agent implemented as a LangGraph state machine.

    Flow of the compiled graph:

    * ``router`` classifies the latest question as ``SQL`` or ``GENERAL``.
    * ``GENERAL`` questions are answered directly by ``general_asistant``.
    * ``SQL`` questions go through ``sql_gen`` -> ``validate_sql``. While
      validation reports an error the graph loops through ``solve_error`` ->
      ``validate_sql``; once the query runs, ``response`` turns the result
      into a natural-language answer.

    The error loop has no explicit retry counter; it is bounded only by the
    graph's recursion limit.
    """

    def __init__(
        self,
        llm: BaseChatModel,
        tools: Any,
        db_uri: str = DB_PATH,
        table_json_path: str = "table.jsonl",
        column_json_path: str = "column.jsonl"
    ):
        """
        Build the agent and compile its graph.

        Parameters:
        - llm: chat model used for routing, SQL generation, repair and answers.
        - tools: retriever factory, called as ``tools(docs, k=..., search_type=...)``.
        - db_uri: SQLAlchemy URI of the database the queries run against.
        - table_json_path: JSON-lines file describing the tables.
        - column_json_path: JSON-lines file describing the columns.
        """
        self.llm = llm
        self.table_json_path = table_json_path
        self.column_json_path = column_json_path
        self.db = SQLDatabase.from_uri(db_uri)
        self.schema = None
        self.retriever = tools
        # Retrievers are built lazily on first use and then reused, so the
        # metadata files are not re-embedded for every question.
        self._table_retriever = None
        self._column_retriever = None

        # add nodes
        graph = StateGraph(AgentState)
        graph.add_node("router", self.router_node)
        graph.add_node("general_asistant", self._general_asistant)
        graph.add_node("sql_gen", self._sql_gen)
        graph.add_node("validate_sql", self._validate_sql)
        graph.add_node("solve_error", self._solve_error)
        graph.add_node("response", self._query_gen_node)

        # add edges
        graph.set_entry_point("router")
        graph.add_edge("sql_gen", "validate_sql")
        graph.add_conditional_edges(
            "router",
            self.router,
            {
                "SQL": "sql_gen",
                "GENERAL": "general_asistant"
            }
        )
        graph.add_conditional_edges(
            "validate_sql",
            self._should_continue
        )
        graph.add_edge("solve_error", "validate_sql")
        graph.add_edge("response", END)
        graph.add_edge("general_asistant", END)

        # compile
        store = InMemoryStore()
        checkpointer = MemorySaver()
        self.graph = graph.compile(checkpointer=checkpointer, store=store)

    @staticmethod
    def _load_jsonl(path: str):
        """Load a JSON-lines file as documents whose page_content is the raw JSON record."""
        return JSONLoader(
            file_path=path,
            jq_schema='.',
            text_content=False,
            json_lines=True
        ).load()

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without surrounding whitespace or a wrapping Markdown code fence."""
        cleaned = text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line (it may carry a language tag such as "sql").
            newline_at = cleaned.find("\n")
            cleaned = cleaned[newline_at + 1:] if newline_at != -1 else cleaned[3:]
            cleaned = cleaned.rstrip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
        return cleaned.strip()

    @staticmethod
    def _normalize_question_type(raw: Any) -> str:
        """
        Map the router LLM's free-text answer onto the labels the graph knows.

        Only the first word is considered, ignoring case and surrounding
        punctuation. Anything that is not ``SQL`` falls back to ``GENERAL``
        so an unexpected answer can never break the conditional edge.
        """
        words = str(raw or "").split()
        first_word = words[0].strip(".,:;'\"`*()").upper() if words else ""
        return "SQL" if first_word == "SQL" else "GENERAL"

    def _indexing_table(self, query: str):
        """
        Index and retrieve relevant tables based on the input query.

        Returns the list of table names whose descriptions best match ``query``.
        """
        logger.info("Indexing Table...")
        if self._table_retriever is None:
            self._table_retriever = self.retriever(
                self._load_jsonl(self.table_json_path),
                k=5,
                search_type='mmr',
                lambda_mult=1
            )

        matched_documents_table = self._table_retriever.invoke(query)
        return [
            json.loads(doc.page_content)["table"] for doc in matched_documents_table
        ]

    def _indexing_column(self, matched_tables, query: str):
        """
        Index and retrieve relevant columns based on the matched tables and query.

        Returns one ``table_name=...|column_name=...|data_type=...`` string per
        retrieved column that belongs to one of ``matched_tables``.
        """
        logger.info("Get matched schema...")
        if self._column_retriever is None:
            self._column_retriever = self.retriever(
                self._load_jsonl(self.column_json_path),
                k=20,
                search_type='similarity'
            )

        wanted_tables = set(matched_tables)
        # Parse every retrieved record exactly once.
        records = (
            json.loads(doc.page_content)
            for doc in self._column_retriever.invoke(query)
        )
        return [
            f'table_name={rec["table_name"]}|column_name={rec["column_name"]}|data_type={rec["data_type"]}'
            for rec in records
            if rec["table_name"] in wanted_tables
        ]

    def _sql_gen(self, state: AgentState):
        """
        Generates a SQL query based on the input provided by the user.
        This function uses the LLM to construct the query from matched tables and columns.

        Side effect: stores the matched schema on ``self.schema`` so that
        ``_solve_error`` can reuse it.
        """
        logger.info("Generate SQL Query...")
        messages = state["messages"][-1].content
        matched_table = self._indexing_table(messages)
        self.schema = self._indexing_column(matched_table, messages)

        prompt = PromptTemplate(
            template="""
            You are a SQL master expert specializing in writing complex SQL queries for SQLite. Your task is to construct a SQL query based on the provided information. Follow these strict rules:

            QUERY: {query}
            -------
            MATCHED_SCHEMA: {matched_schema}
            -------

            Please construct a SQL query using the MATCHED_SCHEMA and the QUERY provided above.
            IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this.
            IMPORTANT: Associate column_name mentioned in MATCHED_SCHEMA only to the table_name specified under MATCHED_SCHEMA.
            NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
            Generate ONLY the SQL query. Do not provide any explanations, comments, or additional text.

            """,
            input_variables=["query", "matched_schema"]
        )

        sql_gen = prompt | self.llm | StrOutputParser()
        result_sql = sql_gen.invoke({"query": messages, "matched_schema": self.schema})
        # Models sometimes wrap the statement in a Markdown fence despite the instruction.
        return {"sql_query": [self._strip_code_fence(result_sql)]}

    def _validate_sql(self, state: AgentState):
        """
        Validates the generated SQL query by attempting to execute it.
        Returns "" as ``error_str`` if no errors, otherwise the error message.
        """
        logger.info("Validate SQL...")
        query = state["sql_query"][-1]
        logger.info("Query:\n%s", query)
        # NOTE(review): this executes LLM-generated SQL as-is; consider a
        # read-only database connection so a generated write cannot alter data.
        try:
            self.db.run(query)
        except Exception as exc:
            # Deliberately broad: whatever the database rejects is fed back
            # to the LLM by the solve_error node.
            return {"error_str": f"Unexpected Error: {exc}"}
        return {"error_str": ""}

    def _solve_error(self, state: AgentState):
        """
        Asks the LLM to repair the last SQL query, given the validation error
        and the schema matched during generation.
        """
        error_string = state.get('error_str', "")
        sql_query = state.get('sql_query', [None])[-1]

        logger.info("Error: %s", error_string)
        logger.info("SQL query %s", sql_query)

        prompt = PromptTemplate(
            template="""
            First, identify the main issues with the given SQL query based on the error message.
            {error_string}

            Next, examine the schema and current SQL query to locate potential sources of the error.
            {schema}

            Then, modify the current SQL query to fix the error and avoid similar issues in the future.
            {sql_query}

            Finally, ensure the revised SQL query conforms to the requirements outlined in the original task and provide the corrected SQL query.
            Generate ONLY the SQL query. Do not provide any explanations, comments, or additional text.

            """,
            input_variables=["error_string", "schema", "sql_query"]
        )
        resolver = prompt | self.llm | StrOutputParser()
        corrected_query = resolver.invoke({
            "error_string": error_string,
            "schema": self.schema,
            "sql_query": sql_query
        })
        return {"sql_query": [self._strip_code_fence(corrected_query)]}

    def _query_gen_node(self, state: AgentState):
        """
        Generates a final response after executing the SQL query and getting the result.
        """
        logger.info("Generate Response...")
        query = state.get('sql_query', [None])[-1]
        # Use the text of the question, not the repr of the message object.
        question = state["messages"][-1].content
        prompt = PromptTemplate(
            template="""Based on the following SQL result, generate a natural language response:
            Query SQL: {user_query}
            SQL Result: {sql_response}
            """,
            input_variables=["user_query", "sql_response"]
        )

        sql_response = self.db.run(query)
        gen_llm = prompt | self.llm | StrOutputParser()
        response = gen_llm.invoke({"user_query": question, "sql_response": sql_response})
        # A bare string would be stored as a human message by add_messages;
        # the ("ai", text) form records the answer as an assistant message.
        return {"messages": [("ai", response)]}

    def _general_asistant(self, state: AgentState):
        """Answers questions that are outside the SQL/database context."""
        logger.info("Assistant...")
        messages = state["messages"]
        response = self.llm.invoke(messages)

        return {"messages": [response], "sql_query": [""]}

    def router(self, state: AgentState):
        """Conditional-edge function: returns "SQL" or "GENERAL" for the current state."""
        logger.info("Router...")
        return self._normalize_question_type(state["question_type"])

    def router_node(self, state: AgentState):
        """Router node: asks the LLM to classify the latest question."""
        question = state["messages"][-1].content
        prompt = PromptTemplate(
            template="""
            You are a senior specialist of analytical support. Classify incoming questions into one of two types:
            - SQL: Related to flight information, schedules, hotels, rentals, recommendations, and anything about vacations
            - GENERAL: General questions
            Return only one word: SQL, or GENERAL.

            {question}
            """,
            input_variables=["question"]
        )
        router = prompt | self.llm | StrOutputParser()
        question_type = router.invoke({"question": question})
        return {"question_type": self._normalize_question_type(question_type)}

    def _should_continue(self, state: AgentState) -> Literal["response", "solve_error"]:
        """
        Decides whether to proceed based on SQL validation results.
        A non-empty "error_str" in the state sends the graph to the
        error-solving node; otherwise it continues to the response node.
        """
        if state.get("error_str"):
            # If error exists, go to solve_error
            return "solve_error"
        return "response"
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import streamlit as st
3
+ from dotenv import load_dotenv, find_dotenv
4
+ from langgraph.errors import GraphRecursionError
5
+ from langchain_groq import ChatGroq
6
+ from agent import SQLAgentRAG
7
+ from tools import retriever
8
+ from constant import GROQ_API_KEY, CONFIG
9
+
10
# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv(find_dotenv())

# Initialize the language model
llm = ChatGroq(
    model="llama3-8b-8192",
    api_key=GROQ_API_KEY,
    temperature=0.1,
    verbose=True
)

# Initialize SQL Agent.
# Streamlit re-executes this script on every interaction. Building the agent
# unconditionally would recreate its in-memory checkpointer each time and
# drop the conversation history, so the instance is kept in the session
# state: one agent (and one memory) per browser session.
if "agent" not in st.session_state:
    st.session_state.agent = SQLAgentRAG(llm=llm, tools=retriever)
agent = st.session_state.agent
27
+
28
def query_rag_agent(query: str, config=None):
    """
    Handle a query through the RAG Agent, producing an SQL response if applicable.

    Parameters:
    - query (str): The input query to process.
    - config (dict, optional): LangGraph run configuration. Defaults to the
      module-wide CONFIG, whose thread id is fixed; pass a different
      configuration to keep conversations on separate threads.

    Returns:
    - Tuple[str, str]: The response text and the last SQL query. The SQL
      query is an empty string when the run failed.

    This function does not raise: a GraphRecursionError or any other failure
    inside the graph is logged and reported through the returned text.
    """
    run_config = CONFIG if config is None else config
    try:
        output = agent.graph.invoke({"messages": query}, run_config)
        response = output["messages"][-1].content
        sql_query = output.get("sql_query", ["No SQL query generated"])[-1]

        logger.info("Query processed successfully: %s", query)
        return response, sql_query

    except GraphRecursionError:
        logger.error("Graph recursion limit reached; query processing failed.")
        return "Graph recursion limit reached. No SQL result generated.", ""

    except Exception:
        # UI boundary: log the traceback and show a message instead of
        # letting a database or API failure crash the page.
        logger.exception("Agent failed while processing query: %s", query)
        return "Something went wrong while processing your question. Please try again.", ""
52
+
53
with st.sidebar:
    st.header("About Project")
    st.markdown(
        """
        RAG (Retrieval-Augmented Generation) Agent SQL is an approach that combines retrieval techniques with text generation to create more relevant and contextualised answers from data,
        particularly in SQL databases. RAG-Agent SQL uses two main components:
        - Retrieval: Retrieving relevant information from the database based on a given question or input.
        - Augmented Generation: Using natural language models (e.g., LLMs such as GPT or LLaMA) to generate more detailed answers, using information from the retrieval results.

        to see the architecture can be seen here [Github](https://github.com/fahmiaziz98/sql_agent/tree/main/002sql-agent-ra)
        """
    )

    st.header("Example Question")
    st.markdown(
        """
        - How many different aircraft models are there? And what are the models?
        - What is the aircraft model with the longest range?
        - Which airports are located in the city of Basel?
        - Can you please provide information on what I asked before?
        - What are the fare conditions available on Boeing 777-300?
        - What is the total amount of bookings made in April 2024?
        - What is the scheduled arrival time of flight number QR0051?
        - Which car rental services are available in Basel?
        - Which seat was assigned to the boarding pass with ticket number 0060005435212351?
        - Which trip recommendations are related to history in Basel?
        - How many tickets were sold for Business class on flight 30625?
        - Which hotels are located in Zurich?
        """
    )

# Main Application Title
st.title("RAG SQL-Agent")


def _render_message(role, text=None, sql_query=None):
    """Draw one chat bubble: the text and, when present, the SQL query in an expander.

    Used for both replayed history and fresh messages so that a message
    looks the same before and after Streamlit re-runs the script.
    """
    with st.chat_message(role):
        if text is not None:
            st.markdown(text)
        if sql_query:
            with st.expander("SQL Query", expanded=True):
                st.code(sql_query)


# Initialize session state for storing chat messages
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display conversation history from session state
for message in st.session_state.messages:
    _render_message(
        message.get("role", "assistant"),
        message.get("output"),
        message.get("sql_query"),
    )

# Input form for user prompt
if prompt := st.chat_input("What do you want to know?"):
    _render_message("user", prompt)
    st.session_state.messages.append({"role": "user", "output": prompt})

    # Fetch response from RAG agent function directly
    with st.spinner("Searching for an answer..."):
        output_text, sql_query = query_rag_agent(prompt)

    # Display assistant response and SQL query
    _render_message("assistant", output_text, sql_query)

    # Append assistant response to session state
    st.session_state.messages.append(
        {
            "role": "assistant",
            "output": output_text,
            "sql_query": sql_query,
        }
    )
column.jsonl ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"table_name":"aircrafts_data","column_name":"aircraft_code","description":"Unique identifier for each aircraft","data_type":"TEXT"}
2
+ {"table_name":"aircrafts_data","column_name":"model","description":"Model of the aircraft","data_type":"TEXT"}
3
+ {"table_name":"aircrafts_data","column_name":"range","description":"Maximum flight range of the aircraft in kilometers","data_type":"INTEGER"}
4
+ {"table_name":"airports_data","column_name":"airport_code","description":"Unique identifier for each airport","data_type":"TEXT"}
5
+ {"table_name":"airports_data","column_name":"airport_name","description":"Full name of the airport","data_type":"TEXT"}
6
+ {"table_name":"airports_data","column_name":"city","description":"City where the airport is located","data_type":"TEXT"}
7
+ {"table_name":"airports_data","column_name":"coordinates","description":"Geographic coordinates of the airport","data_type":"TEXT"}
8
+ {"table_name":"airports_data","column_name":"timezone","description":"Time zone of the airport location","data_type":"TEXT"}
9
+ {"table_name":"boarding_passes","column_name":"ticket_no","description":"Ticket number associated with the boarding pass","data_type":"TEXT"}
10
+ {"table_name":"boarding_passes","column_name":"flight_id","description":"Identifier of the flight","data_type":"INTEGER"}
11
+ {"table_name":"boarding_passes","column_name":"boarding_no","description":"Boarding number assigned to the passenger","data_type":"INTEGER"}
12
+ {"table_name":"boarding_passes","column_name":"seat_no","description":"Seat number assigned to the passenger","data_type":"TEXT"}
13
+ {"table_name":"bookings","column_name":"book_ref","description":"Unique booking reference","data_type":"TEXT"}
14
+ {"table_name":"bookings","column_name":"book_date","description":"Date and time when the booking was made","data_type":"TIMESTAMP"}
15
+ {"table_name":"bookings","column_name":"total_amount","description":"Total amount paid for the booking","data_type":"INTEGER"}
16
+ {"table_name":"car_rentals","column_name":"id","description":"Unique identifier for each car rental","data_type":"INTEGER"}
17
+ {"table_name":"car_rentals","column_name":"name","description":"Name of the car rental company","data_type":"TEXT"}
18
+ {"table_name":"car_rentals","column_name":"location","description":"Location of the car rental","data_type":"TEXT"}
19
+ {"table_name":"car_rentals","column_name":"price_tier","description":"Price category of the rental car","data_type":"TEXT"}
20
+ {"table_name":"car_rentals","column_name":"start_date","description":"Start date of the car rental period","data_type":"DATE"}
21
+ {"table_name":"car_rentals","column_name":"end_date","description":"End date of the car rental period","data_type":"DATE"}
22
+ {"table_name":"car_rentals","column_name":"booked","description":"Booking status of the car (0 for available, 1 for booked)","data_type":"INTEGER"}
23
+ {"table_name":"flights","column_name":"flight_id","description":"Unique identifier for each flight","data_type":"INTEGER"}
24
+ {"table_name":"flights","column_name":"flight_no","description":"Flight number","data_type":"TEXT"}
25
+ {"table_name":"flights","column_name":"scheduled_departure","description":"Scheduled departure date and time","data_type":"TIMESTAMP"}
26
+ {"table_name":"flights","column_name":"scheduled_arrival","description":"Scheduled arrival date and time","data_type":"TIMESTAMP"}
27
+ {"table_name":"flights","column_name":"departure_airport","description":"Code of the departure airport","data_type":"TEXT"}
28
+ {"table_name":"flights","column_name":"arrival_airport","description":"Code of the arrival airport","data_type":"TEXT"}
29
+ {"table_name":"flights","column_name":"status","description":"Current status of the flight","data_type":"TEXT"}
30
+ {"table_name":"flights","column_name":"aircraft_code","description":"Code of the aircraft operating the flight","data_type":"TEXT"}
31
+ {"table_name":"flights","column_name":"actual_departure","description":"Actual departure date and time","data_type":"TIMESTAMP"}
32
+ {"table_name":"flights","column_name":"actual_arrival","description":"Actual arrival date and time","data_type":"TIMESTAMP"}
33
+ {"table_name":"hotels","column_name":"id","description":"Unique identifier for each hotel booking","data_type":"INTEGER"}
34
+ {"table_name":"hotels","column_name":"name","description":"Name of the hotel","data_type":"TEXT"}
35
+ {"table_name":"hotels","column_name":"location","description":"Location of the hotel","data_type":"TEXT"}
36
+ {"table_name":"hotels","column_name":"price_tier","description":"Price category of the hotel","data_type":"TEXT"}
37
+ {"table_name":"hotels","column_name":"checkin_date","description":"Check-in date for the hotel booking","data_type":"DATE"}
38
+ {"table_name":"hotels","column_name":"checkout_date","description":"Check-out date for the hotel booking","data_type":"DATE"}
39
+ {"table_name":"hotels","column_name":"booked","description":"Booking status of the hotel room (0 for available, 1 for booked)","data_type":"INTEGER"}
40
+ {"table_name":"seats","column_name":"aircraft_code","description":"Code of the aircraft","data_type":"TEXT"}
41
+ {"table_name":"seats","column_name":"seat_no","description":"Seat number","data_type":"TEXT"}
42
+ {"table_name":"seats","column_name":"fare_conditions","description":"Fare class of the seat (e.g., Economy, Business)","data_type":"TEXT"}
43
+ {"table_name":"ticket_flights","column_name":"ticket_no","description":"Ticket number","data_type":"TEXT"}
44
+ {"table_name":"ticket_flights","column_name":"flight_id","description":"Flight identifier","data_type":"INTEGER"}
45
+ {"table_name":"ticket_flights","column_name":"fare_conditions","description":"Fare conditions for the ticket (e.g., Economy, Business)","data_type":"TEXT"}
46
+ {"table_name":"ticket_flights","column_name":"amount","description":"Cost of the ticket","data_type":"INTEGER"}
47
+ {"table_name":"tickets","column_name":"ticket_no","description":"Unique ticket number","data_type":"TEXT"}
48
+ {"table_name":"tickets","column_name":"book_ref","description":"Booking reference associated with the ticket","data_type":"TEXT"}
49
+ {"table_name":"tickets","column_name":"passenger_id","description":"Unique identifier for the passenger","data_type":"TEXT"}
50
+ {"table_name":"trip_recommendations","column_name":"id","description":"Unique identifier for each trip recommendation","data_type":"INTEGER"}
51
+ {"table_name":"trip_recommendations","column_name":"name","description":"Name of the recommended attraction or activity","data_type":"TEXT"}
52
+ {"table_name":"trip_recommendations","column_name":"location","description":"Location of the recommended attraction or activity","data_type":"TEXT"}
53
+ {"table_name":"trip_recommendations","column_name":"keywords","description":"Keywords associated with the recommendation","data_type":"TEXT"}
54
+ {"table_name":"trip_recommendations","column_name":"details","description":"Detailed description of the recommendation","data_type":"TEXT"}
55
+ {"table_name":"trip_recommendations","column_name":"booked","description":"Booking status of the recommendation (0 for available, 1 for booked)","data_type":"INTEGER"}
constant.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
import os

# SECURITY: credentials must never live in source control. The key is read
# from the environment (for example a deployment secret or a local .env
# file). Any key that was previously committed here has to be treated as
# leaked and revoked.
# None when the variable is not set.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# SQLAlchemy URI of the SQLite database queried by the agent.
DB_PATH = "sqlite:///travel.sqlite"

# Default LangGraph run configuration; the thread id selects which
# checkpointed conversation a run belongs to.
CONFIG = {
    "configurable": {
        "thread_id": "1234"
    }
}
models.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List
3
+
4
class QueryInput(BaseModel):
    # Request payload: the natural-language question to answer.
    query: str
6
+
7
class QueryOutput(BaseModel):
    # Natural-language answer text.
    response: str
    # SQL statements associated with the answer.
    sql_query: List[str]
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Main dependencies
2
+ opentelemetry-api==1.22.0
3
+ pydantic>=2.7
4
+ langchain==0.3.3
5
+ langgraph==0.2.38
6
+ langchain-community==0.3.2
7
+ langchain-groq==0.2.0
8
+ fastembed==0.4.1
9
+ onnx==1.17.0
10
+ onnxruntime==1.18.0
11
+ jq==1.8.0a2
12
+ faiss-cpu==1.8.0
13
+ requests==2.31.0
14
+ streamlit==1.34.0
state.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import TypedDict
2
+ from langgraph.graph.message import AnyMessage, add_messages
3
+ from typing import Annotated, List
4
+
5
class AgentState(TypedDict):
    """State shared by all nodes of the agent graph."""

    # Conversation history; the add_messages reducer merges new messages
    # into the existing list instead of overwriting it.
    messages: Annotated[List[AnyMessage], add_messages]
    # SQL statements produced by the agent; the current one is the last item.
    sql_query: List[str]
    # Validation error for the current query; falsy when validation passed.
    error_str: str
    # Routing label chosen for the latest question ("SQL" or "GENERAL").
    question_type: str
table.jsonl ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"table": "aircrafts_data", "description": "This table stores information about aircraft, including aircraft code, model, and flight range.", "example_questions": ["What is the flight range of the aircraft with code '773'?", "Which aircraft model has the shortest flight range?", "What is the difference in flight range between the aircraft with the longest range and the one with the shortest range?"]}
2
+ {"table": "airports_data", "description": "This table contains airport data, including airport code, name, city, coordinates, and time zone.", "example_questions": ["In which city is the airport with code 'ATL' located?", "How many airports are in the Asia time zone?", "Which airport has the northernmost location based on its coordinates?"]}
3
+ {"table": "boarding_passes", "description": "This table stores information about boarding passes, including ticket number, flight ID, boarding number, and seat number.", "example_questions": ["What is the seat number for the ticket number '0060005435212351'?", "How many passengers have seat numbers in row 2 (e.g., 2A, 2B, 2C, etc.)?", "For the flight with flight_id 30625, is there a correlation between boarding number and seat position?"]}
4
+ {"table": "bookings", "description": "This table contains booking information, including booking reference, booking date, and total amount.", "example_questions": ["What is the total amount for the booking with reference '00000F'?", "How many bookings were made in March 2024?", "What is the average booking amount for bookings made in April 2024?"]}
5
+ {"table": "car_rentals", "description": "This table stores information about car rentals, including rental company, location, price tier, rental dates, and booking status.", "example_questions": ["How many car rentals are available in Basel?", "Which rental company offers a luxury car in Basel, and for how many days?", "Calculate the average rental duration for all non-booked cars, grouped by price tier."]}
6
+ {"table": "flights", "description": "This table contains flight information, including flight numbers, schedules, airports, status, and aircraft details.", "example_questions": [ "What is the flight number for the flight departing from BSL to BKK?", "How many flights are scheduled to depart from Shanghai (SHA)?", "For flights that have both scheduled and actual departure times, what is the average delay in minutes?"]}
7
+ {"table": "hotels","description": "This table stores information about hotel bookings, including hotel name, location, price tier, check-in and check-out dates, and booking status.", "example_questions": ["How many hotels are available for booking in Basel?", "Which hotel in Zurich offers an Upscale price tier, and for how many nights?", "Calculate the average length of stay for all non-booked hotels, grouped by price tier."]}
8
+ {"table": "seats", "description": "This table contains information about aircraft seats, including the aircraft code, seat number, and fare conditions.", "example_questions": [ "How many Business class seats are there on aircraft with code '319'?", "List all unique fare conditions available across all aircraft.", "For each aircraft code, what percentage of seats are in the Business fare condition?"]}
9
+ {"table": "ticket_flights", "description": "This table links tickets to specific flights and includes fare conditions and ticket amounts.", "example_questions": ["What is the average ticket amount for Business class on flight_id 30625?", "How many different fare conditions are there across all ticket_flights?", "For each flight_id, what is the price difference between the highest and lowest ticket amount?"]}
10
+ {"table": "tickets", "description": "This table contains information about tickets, including ticket number, booking reference, and passenger ID.", "example_questions": ["How many tickets are associated with the booking reference '06B046'?", "What is the passenger ID for the ticket number '9880005432000987'?", "How many unique passengers (based on passenger_id) have tickets in the system?"]}
11
+ {"table": "trip_recommendations", "description": "This table stores trip recommendations, including name, location, keywords, details, and booking status.", "example_questions": [ "How many trip recommendations are available for Basel?", "What are the keywords associated with the 'Kunstmuseum Basel' recommendation?", "List all unique locations that have trip recommendations, along with the count of recommendations for each location."]}
tools.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import FAISS
2
+ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
3
+
4
# Single embedding model instance, shared by every index built below.
embeddings = FastEmbedEmbeddings()


def retriever(docs, k=5, search_type='mmr', lambda_mult=None):
    """
    Build a FAISS index over ``docs`` and expose it as a retriever.

    Parameters:
    -----------
    docs : List[Document]
        Documents to embed and index.

    k : int, optional, default=5
        Number of results returned per query.

    search_type : str, optional, default='mmr'
        Search strategy, e.g. 'mmr' or 'similarity'.

    lambda_mult : float, optional, default=None
        Lambda multiplier for Maximal Marginal Relevance (MMR) search.
        Forwarded to the search only when a value is given.

    Returns:
    --------
    retriever : Retriever
        Retriever object that can be queried for relevant documents.
    """
    index = FAISS.from_documents(docs, embedding=embeddings)

    # lambda_mult is optional: leave it out entirely unless the caller set it.
    optional_kwargs = {} if lambda_mult is None else {'lambda_mult': lambda_mult}

    return index.as_retriever(
        search_type=search_type,
        search_kwargs={'k': k, **optional_kwargs},
    )