Update app.py
app.py CHANGED
@@ -8,22 +8,22 @@ ACTION_1 = "Prompt base model"
 ACTION_2 = "Fine-tune base model"
 ACTION_3 = "Prompt fine-tuned model"
 
-SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided …
-USER_PROMPT = "…
-…
+SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. Do not generate any additional text. SQL_CONTEXT: {sql_context}"
+USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
+SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
 
 BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 FT_MODEL_NAME = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
 DATASET_NAME = "gretelai/synthetic_text_to_sql"
 
-def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, …
+def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
     #raise gr.Error("Please clone and bring your own credentials.")
     if action == ACTION_1:
-        result = prompt_model(base_model_name, system_prompt, user_prompt, …
+        result = prompt_model(base_model_name, system_prompt, user_prompt, sql_context)
     elif action == ACTION_2:
         result = fine_tune_model(base_model_name, dataset_name)
     elif action == ACTION_3:
-        result = prompt_model(ft_model_name, system_prompt, user_prompt, …
+        result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)
     return result
 
 def fine_tune_model(base_model_name, dataset_name):
@@ -108,14 +108,14 @@ def fine_tune_model(base_model_name, dataset_name):
     # Train model
     trainer.train()
 
-def prompt_model(model_name, system_prompt, user_prompt, …
+def prompt_model(model_name, system_prompt, user_prompt, sql_context):
     pipe = pipeline("text-generation",
                     model=model_name,
                     device_map="auto",
                     max_new_tokens=1000)
 
     messages = [
-        {"role": "system", "content": system_prompt.format(…
+        {"role": "system", "content": system_prompt.format(sql_context=sql_context)},
         {"role": "user", "content": user_prompt},
         {"role": "assistant", "content": ""}
     ]
@@ -144,6 +144,6 @@ demo = gr.Interface(fn=process,
                             gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
                             gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
                             gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
-                            gr.Textbox(label = "SQL …
+                            gr.Textbox(label = "SQL Context", value = SQL_CONTEXT, lines = 2)],
                    outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ["OUTPUT"])])
 demo.launch()
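
For context on what the new sql_context parameter does at inference time, here is a minimal standalone sketch (not part of the commit) of the updated prompt_model flow. The pipeline construction, the message layout, and the system_prompt.format(sql_context=...) call mirror the diff; how the pipeline output is unpacked is an assumption, since that part of the function is not shown.

from transformers import pipeline

SYSTEM_PROMPT = (
    "You are a text to SQL query translator. Given a question in English, "
    "generate a SQL query based on the provided SQL_CONTEXT. "
    "Do not generate any additional text. SQL_CONTEXT: {sql_context}"
)

def prompt_model(model_name, system_prompt, user_prompt, sql_context):
    # Text-generation pipeline configured as in the diff.
    pipe = pipeline("text-generation",
                    model=model_name,
                    device_map="auto",
                    max_new_tokens=1000)

    # The table schema (SQL_CONTEXT) is injected into the system turn,
    # so the user turn carries only the natural-language question.
    messages = [
        {"role": "system", "content": system_prompt.format(sql_context=sql_context)},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": ""},
    ]

    # Assumption: the chat messages are passed straight to the pipeline and the
    # last assistant turn is returned; the diff does not show this part.
    output = pipe(messages)
    return output[0]["generated_text"][-1]["content"]

Putting the schema in the system prompt rather than appending it to the user question keeps the user-facing textbox readable and lets the same SYSTEM_PROMPT template be reused for both the base and the fine-tuned model.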
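
Outside the Gradio UI, the new argument can be exercised by calling process directly with the module constants; this is only an illustration of the wiring, as the Space itself drives process through gr.Interface:

# Prompts the fine-tuned checkpoint with the example question and schema defined above.
result = process(ACTION_3, BASE_MODEL_NAME, FT_MODEL_NAME, DATASET_NAME,
                 SYSTEM_PROMPT, USER_PROMPT, SQL_CONTEXT)
print(result)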