Update app.py
Browse files
app.py
CHANGED
@@ -8,12 +8,14 @@ ACTION_1 = "Prompt base model"
|
|
8 |
ACTION_2 = "Fine-tune base model"
|
9 |
ACTION_3 = "Prompt fine-tuned model"
|
10 |
|
|
|
|
|
11 |
SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. Do not generate any additional text. SQL_CONTEXT: {sql_context}"
|
12 |
USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
|
13 |
SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
|
14 |
|
15 |
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
16 |
-
FT_MODEL_NAME = "
|
17 |
DATASET_NAME = "gretelai/synthetic_text_to_sql"
|
18 |
|
19 |
def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
|
@@ -77,11 +79,11 @@ def fine_tune_model(base_model_name, dataset_name):
|
|
77 |
# Configure training arguments
|
78 |
|
79 |
training_args = Seq2SeqTrainingArguments(
|
80 |
-
output_dir="./
|
81 |
logging_dir="./logs",
|
82 |
num_train_epochs=1,
|
83 |
-
max_steps=
|
84 |
-
push_to_hub=True,
|
85 |
#per_device_train_batch_size=16,
|
86 |
#per_device_eval_batch_size=64,
|
87 |
#eval_strategy="steps",
|
@@ -114,9 +116,9 @@ def fine_tune_model(base_model_name, dataset_name):
|
|
114 |
|
115 |
trainer.train()
|
116 |
|
117 |
-
#
|
118 |
|
119 |
-
tokenizer.push_to_hub(
|
120 |
|
121 |
def prompt_model(model_name, system_prompt, user_prompt, sql_context):
|
122 |
pipe = pipeline("text-generation",
|
@@ -150,7 +152,7 @@ def load_model(model_name):
|
|
150 |
demo = gr.Interface(fn=process,
|
151 |
inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
|
152 |
gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
|
153 |
-
gr.Textbox(label = "Fine-Tuned Model Name", value = FT_MODEL_NAME, lines = 1),
|
154 |
gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
|
155 |
gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
|
156 |
gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
|
|
|
8 |
ACTION_2 = "Fine-tune base model"
|
9 |
ACTION_3 = "Prompt fine-tuned model"
|
10 |
|
11 |
+
HF_ACCOUNT = "bstraehle"
|
12 |
+
|
13 |
SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. Do not generate any additional text. SQL_CONTEXT: {sql_context}"
|
14 |
USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
|
15 |
SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
|
16 |
|
17 |
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
18 |
+
FT_MODEL_NAME = "Meta-Llama-3.1-8B-Instruct-text-to-sql"
|
19 |
DATASET_NAME = "gretelai/synthetic_text_to_sql"
|
20 |
|
21 |
def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
|
|
|
79 |
# Configure training arguments
|
80 |
|
81 |
training_args = Seq2SeqTrainingArguments(
|
82 |
+
output_dir=f"./{FT_MODEL_NAME}",
|
83 |
logging_dir="./logs",
|
84 |
num_train_epochs=1,
|
85 |
+
max_steps=1, # overrides num_train_epochs
|
86 |
+
push_to_hub=True, # only model, also need to push tokenizer
|
87 |
#per_device_train_batch_size=16,
|
88 |
#per_device_eval_batch_size=64,
|
89 |
#eval_strategy="steps",
|
|
|
116 |
|
117 |
trainer.train()
|
118 |
|
119 |
+
# Push tokenizer to the Hugging Face Hub
|
120 |
|
121 |
+
tokenizer.push_to_hub(FT_MODEL_NAME)
|
122 |
|
123 |
def prompt_model(model_name, system_prompt, user_prompt, sql_context):
|
124 |
pipe = pipeline("text-generation",
|
|
|
152 |
demo = gr.Interface(fn=process,
|
153 |
inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
|
154 |
gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
|
155 |
+
gr.Textbox(label = "Fine-Tuned Model Name", value = f"{HF_ACCOUNT}/{FT_MODEL_NAME}", lines = 1),
|
156 |
gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
|
157 |
gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
|
158 |
gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
|