bstraehle committed on
Commit
3d0bfc5
1 Parent(s): e6625f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -8,12 +8,14 @@ ACTION_1 = "Prompt base model"
8
  ACTION_2 = "Fine-tune base model"
9
  ACTION_3 = "Prompt fine-tuned model"
10
 
 
 
11
  SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. Do not generate any additional text. SQL_CONTEXT: {sql_context}"
12
  USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
13
  SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
14
 
15
  BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
16
- FT_MODEL_NAME = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
17
  DATASET_NAME = "gretelai/synthetic_text_to_sql"
18
 
19
  def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
@@ -77,11 +79,11 @@ def fine_tune_model(base_model_name, dataset_name):
77
  # Configure training arguments
78
 
79
  training_args = Seq2SeqTrainingArguments(
80
- output_dir="./Meta-Llama-3.1-8B-Instruct-text-to-sql",
81
  logging_dir="./logs",
82
  num_train_epochs=1,
83
- max_steps=2, # overwrites num_train_epochs
84
- push_to_hub=True,
85
  #per_device_train_batch_size=16,
86
  #per_device_eval_batch_size=64,
87
  #eval_strategy="steps",
@@ -114,9 +116,9 @@ def fine_tune_model(base_model_name, dataset_name):
114
 
115
  trainer.train()
116
 
117
- # Save tokenizer to HF
118
 
119
- tokenizer.push_to_hub("Meta-Llama-3.1-8B-Instruct-text-to-sql")
120
 
121
  def prompt_model(model_name, system_prompt, user_prompt, sql_context):
122
  pipe = pipeline("text-generation",
@@ -150,7 +152,7 @@ def load_model(model_name):
150
  demo = gr.Interface(fn=process,
151
  inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
152
  gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
153
- gr.Textbox(label = "Fine-Tuned Model Name", value = FT_MODEL_NAME, lines = 1),
154
  gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
155
  gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
156
  gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
 
8
  ACTION_2 = "Fine-tune base model"
9
  ACTION_3 = "Prompt fine-tuned model"
10
 
11
+ HF_ACCOUNT = "bstraehle"
12
+
13
  SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. Do not generate any additional text. SQL_CONTEXT: {sql_context}"
14
  USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
15
  SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
16
 
17
  BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
18
+ FT_MODEL_NAME = "Meta-Llama-3.1-8B-Instruct-text-to-sql"
19
  DATASET_NAME = "gretelai/synthetic_text_to_sql"
20
 
21
  def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
 
79
  # Configure training arguments
80
 
81
  training_args = Seq2SeqTrainingArguments(
82
+ output_dir=f"./{FT_MODEL_NAME}",
83
  logging_dir="./logs",
84
  num_train_epochs=1,
85
+ max_steps=1, # overwrites num_train_epochs
86
+ push_to_hub=True, # only model, also need to push tokenizer
87
  #per_device_train_batch_size=16,
88
  #per_device_eval_batch_size=64,
89
  #eval_strategy="steps",
 
116
 
117
  trainer.train()
118
 
119
+ # Push tokenizer to HF
120
 
121
+ tokenizer.push_to_hub(FT_MODEL_NAME)
122
 
123
  def prompt_model(model_name, system_prompt, user_prompt, sql_context):
124
  pipe = pipeline("text-generation",
 
152
  demo = gr.Interface(fn=process,
153
  inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
154
  gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
155
+ gr.Textbox(label = "Fine-Tuned Model Name", value = f"{HF_ACCOUNT}/{FT_MODEL_NAME}", lines = 1),
156
  gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
157
  gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
158
  gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),