aryachakraborty commited on
Commit
31d6ed6
1 Parent(s): 863596d

Upload 4 files

Browse files
Files changed (4) hide show
  1. My_SQL_Connection.py +101 -0
  2. app.py +155 -0
  3. model_functions.py +19 -0
  4. requirements.txt +11 -0
My_SQL_Connection.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### In this file we will store all the codes related to connection to my sql server.
2
+
3
+ import mysql.connector
4
+ import pandas as pd
5
+ ###======================================================================database details-=======================================================
6
+ def database_details(host,user,password):
7
+ connection = mysql.connector.connect(
8
+ host = host,
9
+ user = user,
10
+ password = password,
11
+ buffered = True
12
+ )
13
+ cursor = connection.cursor()
14
+ databases = ("Show databases")
15
+ cursor.execute(databases)
16
+ db = []
17
+ for (databases) in cursor:
18
+ db.append(databases[0])
19
+
20
+ cursor.close()
21
+ connection.close()
22
+ return db, len(db)
23
+
24
+ #### =========================================================================retrieving the tables==========================================================
25
+ def tables_in_this_DB(host,user,password,db_name):
26
+ db_config = {
27
+ 'host':host,
28
+ 'user': user,
29
+ 'password': password,
30
+ 'database': db_name,
31
+ }
32
+ connection = mysql.connector.connect(**db_config)
33
+ cursor = connection.cursor()
34
+ query1 = "SHOW TABLES"
35
+ cursor.execute(query1)
36
+ tables = cursor.fetchall()
37
+
38
+ cursor.close()
39
+ connection.close()
40
+ return tables, len(tables)
41
+
42
+ #### ==================================================Printing the tables=======================================================================
43
+ def printing_tables(host,user,password,db_name):
44
+ db_config = {
45
+ 'host':host,
46
+ 'user': user,
47
+ 'password': password,
48
+ 'database': db_name,
49
+ }
50
+ connection = mysql.connector.connect(**db_config)
51
+ cursor = connection.cursor()
52
+ cursor.execute("SHOW TABLES")
53
+ table_names = [table[0] for table in cursor.fetchall()]
54
+
55
+ tables_data = {}
56
+
57
+ for table_name in table_names:
58
+ query = f"SELECT * FROM {table_name}"
59
+ cursor.execute(query)
60
+ rows = cursor.fetchall()
61
+
62
+ col_names = [desc[0] for desc in cursor.description]
63
+ df = pd.DataFrame(rows, columns=col_names)
64
+
65
+ tables_data[table_name] = df
66
+ cursor.close()
67
+ connection.close()
68
+ return tables_data
69
+
70
+
71
+
72
+ def create_table_command(host,user,password,db_name):
73
+ db_config = {
74
+ 'host': host,
75
+ 'user': user,
76
+ 'password': password,
77
+ 'database': db_name,
78
+ }
79
+
80
+ connection = mysql.connector.connect(**db_config)
81
+ cursor = connection.cursor()
82
+ query = "SHOW TABLES"
83
+ cursor.execute(query)
84
+ table_names = [table[0] for table in cursor.fetchall()]
85
+
86
+ create_table_statements = {}
87
+ for table_name in table_names:
88
+ query = f"SHOW CREATE TABLE {table_name}"
89
+ cursor.execute(query)
90
+ create_table_data = cursor.fetchone()
91
+
92
+ if create_table_data:
93
+ # The CREATE TABLE statement is in the second element of the tuple
94
+ create_table_statement = create_table_data[1]
95
+ create_table_statement = create_table_statement.split("ENGINE=")[0].strip()
96
+ create_table_statements[table_name] = create_table_statement
97
+
98
+ cursor.close()
99
+ connection.close()
100
+
101
+ return create_table_statements
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from My_SQL_Connection import database_details, tables_in_this_DB, printing_tables, create_table_command
4
+ from streamlit_option_menu import option_menu
5
+ from model_functions import LOAD_GEMMA
6
+ import torch
7
+
8
+
9
+ user_name = 'arya'
10
+
11
+
12
+
13
+ if 'localhost' not in st.session_state:
14
+ st.session_state.localhost = ''
15
+ st.session_state.user = ''
16
+ st.session_state.password = ''
17
+ st.session_state.table_commands = """ """
18
+
19
+ with st.sidebar:
20
+ selected = option_menu("Querio Lingua", ["Log In", 'main functionalities','Chat with AI'],
21
+ icons=['house', 'gear', 'robot'], menu_icon="cast", default_index=1)
22
+
23
+ if selected == 'Log In':
24
+ st.title(f'welcome to our web application :green[{user_name}]')
25
+ st.subheader('welcome to our MY SQL Database Explorer ~ ')
26
+
27
+ st.write('First Check how many databases you have in your server ~')
28
+ st.session_state.localhost = st.text_input("what is your host, (localhost if in local) or give the url", 'localhost',help='host')
29
+ st.session_state.user = st.text_input("what is your user name (usually root)", 'root')
30
+ st.session_state.password = st.text_input('Password', type='password')
31
+
32
+ elif selected == 'main functionalities':
33
+ st.title(f'welcome to our web application :green[{user_name}]')
34
+ st.subheader('welcome to our MY SQL Database Explorer ~ ')
35
+ if st.button('All your databases ~ '):
36
+ db, l = database_details(st.session_state.localhost, st.session_state.user, st.session_state.password)
37
+ st.table(db)
38
+
39
+ st.subheader('Now we will see details of any database~ ')
40
+
41
+ st.session_state.db_name = st.text_input('Which Database you want')
42
+
43
+ if st.button('All tables present in that particular database'):
44
+ if not st.session_state.db_name:
45
+ st.warning('Input database name first')
46
+ else:
47
+ tables, l = tables_in_this_DB(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
48
+ st.write(f'There is only {l} tables present in this database')
49
+ st.write(tables)
50
+
51
+ st.subheader('check out tables~ ')
52
+
53
+ if st.button('Print the tables~'):
54
+ tables_data = printing_tables(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
55
+ for table_name, table_data in tables_data.items():
56
+ st.write(f"Table: {table_name}")
57
+ st.table(table_data)
58
+
59
+ st.subheader('Retrieve the CREATE TABLE Statements')
60
+
61
+ if st.button('Generate statements'):
62
+ statements = create_table_command(st.session_state.localhost, st.session_state.user, st.session_state.password, st.session_state.db_name)
63
+ for table_name, table_statements in statements.items():
64
+ st.write(f'{table_name}')
65
+ st.session_state.table_commands = table_statements
66
+ st.code(table_statements)
67
+
68
+
69
+ elif selected == 'Chat with AI':
70
+ #st.set_page_config(page_title='🧠MemoryBot🤖', layout='wide')
71
+ # Initialize session states
72
+ if "generated" not in st.session_state:
73
+ st.session_state["generated"] = []
74
+ if "past" not in st.session_state:
75
+ st.session_state["past"] = []
76
+ if "input" not in st.session_state:
77
+ st.session_state["input"] = ""
78
+ if "stored_session" not in st.session_state:
79
+ st.session_state["stored_session"] = []
80
+
81
+ def get_text():
82
+ """
83
+ Get the user input text.
84
+
85
+ Returns:
86
+ (str): The text entered by the user
87
+ """
88
+ input_text = st.text_input("You: ", st.session_state["input"], key="input",
89
+ placeholder="Your AI assistant here! Ask me anything ...",
90
+ label_visibility='hidden')
91
+ return input_text
92
+
93
+ def new_chat():
94
+ """
95
+ Clears session state and starts a new chat.
96
+ """
97
+ save = []
98
+ for i in range(len(st.session_state['generated'])-1, -1, -1):
99
+ save.append("User:" + st.session_state["past"][i])
100
+ save.append("Bot:" + st.session_state["generated"][i])
101
+ st.session_state["stored_session"].append(save)
102
+ st.session_state["generated"] = []
103
+ st.session_state["past"] = []
104
+ st.session_state["input"] = ""
105
+
106
+ with st.sidebar.expander("🛠️ ", expanded=False):
107
+ MODEL = st.selectbox(label='Model', options=['GEMMA-2B FINE TUNED'])
108
+
109
+
110
+ st.title("🤖 Chat Bot with 🧠")
111
+ st.subheader(" Powered by 🚀 GEMMA")
112
+
113
+
114
+ st.sidebar.button("New Chat", on_click = new_chat, type='primary')
115
+ user_input = get_text()
116
+
117
+ if user_input:
118
+ tokenizer,model = LOAD_GEMMA()
119
+ device = torch.device("cpu")
120
+ alpeca_prompt = f"""Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables.
121
+ ### Instruction: {user_input}. ### Input: {st.session_state.table_commands}
122
+ ### Response:
123
+ """
124
+ with st.status('Generating Result',expanded=False) as status:
125
+ inputs = tokenizer([alpeca_prompt], return_tensors="pt").to(device)
126
+ outputs = model.generate(**inputs, max_new_tokens=30)
127
+ output = tokenizer.decode(outputs[0], skip_special_tokens=False)
128
+ st.session_state.past.append(user_input)
129
+ st.session_state.generated.append(output)
130
+ print(output)
131
+ status.update(label="Result Generated!", state="complete", expanded=False)
132
+
133
+ download_str = []
134
+ # Display the conversation history using an expander, and allow the user to download it
135
+ with st.expander("Conversation", expanded=True):
136
+ for i in range(len(st.session_state['generated'])-1, -1, -1):
137
+ st.info(st.session_state["past"][i],icon="🧐")
138
+ st.success(st.session_state["generated"][i], icon="🤖")
139
+ download_str.append(st.session_state["past"][i])
140
+ download_str.append(st.session_state["generated"][i])
141
+
142
+ # Can throw error - requires fix
143
+ download_str = '\n'.join(download_str)
144
+ if download_str:
145
+ st.download_button('Download',download_str)
146
+
147
+ # Display stored conversation sessions in the sidebar
148
+ for i, sublist in enumerate(st.session_state.stored_session):
149
+ with st.sidebar.expander(label= f"Conversation-Session:{i}"):
150
+ st.write(sublist)
151
+
152
+ # Allow the user to clear all stored conversation sessions
153
+ if st.session_state.stored_session:
154
+ if st.sidebar.button("Clear-all"):
155
+ del st.session_state.stored_session
model_functions.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM,AutoTokenizer
2
+ import streamlit as st
3
+
4
+ @st.cache_resource(show_spinner='Loading the Gemma model. Be patient🙏')
5
+ def LOAD_GEMMA():
6
+ model_id = "aryachakraborty/GEMMA-2B-NL-SQL"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
8
+ model = AutoModelForCausalLM.from_pretrained(model_id,
9
+ low_cpu_mem_usage = True
10
+ ).cpu()
11
+ return tokenizer,model
12
+
13
+
14
+ def LOAD_MISTRAL():
15
+ model_id=''
16
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
17
+ model = AutoModelForCausalLM.from_pretrained(model_id,
18
+ low_cpu_usage=True,
19
+ ).cpu()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bitsandbytes==0.42.0
2
+ peft==0.8.2
3
+ trl==0.7.10
4
+ accelerate==0.27.1
5
+ datasets==2.17.0
6
+ transformers==4.38.0
7
+ streamlit
8
+ streamlit_option_menu
9
+ torch
10
+ mysql.connector
11
+ pandas