bstraehle committed
Commit da6722c
1 Parent(s): 8f45dd8

Update app.py

Files changed (1)
  1. app.py +9 -9
app.py CHANGED
@@ -8,22 +8,22 @@ ACTION_1 = "Prompt base model"
 ACTION_2 = "Fine-tune base model"
 ACTION_3 = "Prompt fine-tuned model"
 
-SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
-USER_PROMPT = "What is the total trade value and average price for each trader and stock in the trade_history table?"
-SQL_SCHEMA = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
+SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SQL_CONTEXT. Do not generate any additional text. SQL_CONTEXT: {sql_context}"
+USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
+SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
 
 BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 FT_MODEL_NAME = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
 DATASET_NAME = "gretelai/synthetic_text_to_sql"
 
-def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_schema):
+def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
     #raise gr.Error("Please clone and bring your own credentials.")
     if action == ACTION_1:
-        result = prompt_model(base_model_name, system_prompt, user_prompt, sql_schema)
+        result = prompt_model(base_model_name, system_prompt, user_prompt, sql_context)
     elif action == ACTION_2:
         result = fine_tune_model(base_model_name, dataset_name)
     elif action == ACTION_3:
-        result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_schema)
+        result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)
     return result
 
 def fine_tune_model(base_model_name, dataset_name):
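
This hunk renames the SCHEMA placeholder to SQL_CONTEXT (and the SQL_SCHEMA constant to SQL_CONTEXT to match) and swaps in a new sample question and context. A minimal sketch, not part of the commit and with the long strings abbreviated, of how the renamed placeholder is filled in and why the format keyword has to be renamed with it:

# Sketch only: abbreviated stand-ins for the SYSTEM_PROMPT and SQL_CONTEXT constants above.
system_prompt = ("You are a text to SQL query translator. Generate a SQL query "
                 "based on the provided SQL_CONTEXT. SQL_CONTEXT: {sql_context}")
sql_context = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE);"

# str.format resolves the placeholder by name, so the keyword must now be
# sql_context; keeping the old schema=... keyword would raise KeyError.
system_message = system_prompt.format(sql_context=sql_context)
print(system_message)
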
@@ -108,14 +108,14 @@ def fine_tune_model(base_model_name, dataset_name):
     # Train model
     trainer.train()
 
-def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
+def prompt_model(model_name, system_prompt, user_prompt, sql_context):
     pipe = pipeline("text-generation",
                     model=model_name,
                     device_map="auto",
                     max_new_tokens=1000)
 
     messages = [
-        {"role": "system", "content": system_prompt.format(schema=sql_schema)},
+        {"role": "system", "content": system_prompt.format(sql_context=sql_context)},
         {"role": "user", "content": user_prompt},
         {"role": "assistant", "content": ""}
     ]
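
prompt_model passes these chat messages straight to the transformers text-generation pipeline, which applies the model's chat template. A minimal sketch, assuming transformers is installed and the model weights are reachable, of the call and of reading the completion back; the values are abbreviated here, whereas app.py uses the SYSTEM_PROMPT, USER_PROMPT, and SQL_CONTEXT constants from the first hunk:

# Sketch, not from app.py: abbreviated prompt values and a simple way to read the reply.
from transformers import pipeline

system_prompt = ("You are a text to SQL query translator. Generate a SQL query "
                 "based on the provided SQL_CONTEXT. SQL_CONTEXT: {sql_context}")
sql_context = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE);"
user_prompt = "How many users joined in the past month?"

pipe = pipeline("text-generation",
                model="bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql",
                device_map="auto",
                max_new_tokens=1000)

messages = [
    {"role": "system", "content": system_prompt.format(sql_context=sql_context)},
    {"role": "user", "content": user_prompt},
]

result = pipe(messages)
# For chat-style input the pipeline returns the message list with the generated
# assistant turn appended; its "content" field is the SQL completion.
completion = result[0]["generated_text"][-1]["content"]
print(completion)
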
@@ -144,6 +144,6 @@ demo = gr.Interface(fn=process,
                             gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
                             gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
                             gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
-                            gr.Textbox(label = "SQL Schema", value = SQL_SCHEMA, lines = 2)],
+                            gr.Textbox(label = "SQL Context", value = SQL_CONTEXT, lines = 2)],
                     outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ["OUTPUT"])])
 demo.launch()
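
In the last hunk the input textbox is relabeled to "SQL Context" and bound to the renamed SQL_CONTEXT constant. gr.Interface maps inputs onto the fn arguments positionally, so this textbox still feeds the seventh parameter of process, now called sql_context. A trimmed sketch of that wiring, assuming gradio is installed, with a stub in place of the real process and stand-in labels for the inputs not shown in the diff:

# Sketch only: stub process() with the same positional signature as app.py's.
import gradio as gr

SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE);"

def process(action, base_model_name, ft_model_name, dataset_name,
            system_prompt, user_prompt, sql_context):
    return f"{action} with context: {sql_context}"

demo = gr.Interface(fn=process,
                    inputs=[gr.Radio(["Prompt base model"], label="Action"),
                            gr.Textbox(label="Base Model Name", lines=1),
                            gr.Textbox(label="FT Model Name", lines=1),
                            gr.Textbox(label="Dataset Name", lines=1),
                            gr.Textbox(label="System Prompt", lines=2),
                            gr.Textbox(label="User Prompt", lines=2),
                            # Relabeled textbox; seventh position -> sql_context
                            gr.Textbox(label="SQL Context", value=SQL_CONTEXT, lines=2)],
                    outputs=[gr.Textbox(label="Prompt Completion")])

if __name__ == "__main__":
    demo.launch()
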
 