Update app.py
app.py CHANGED
@@ -3,7 +3,6 @@ import os, torch
 from datasets import load_dataset
 from huggingface_hub import HfApi, login
 from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
-qTrainingArguments

 hf_profile = "bstraehle"

@@ -12,140 +11,134 @@ action_2 = "Prompt fine-tuned model"
 
 system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
 user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
-
+sql_schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
 
-
-
+model_name = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
+dataset_name = "gretelai/synthetic_text_to_sql"
 
-def process(action,
+def process(action, model_name, dataset_name, system_prompt, user_prompt, sql_schema):
     #raise gr.Error("Please clone and bring your own credentials.")
     if action == action_1:
-        result = fine_tune_model(
+        result = fine_tune_model(model_name, dataset_name)
     elif action == action_2:
-
-        result = prompt_model(fine_tuned_model_id, system_prompt, user_prompt, schema)
+        result = prompt_model(model_name, system_prompt, user_prompt, sql_schema)
     return result
 
-def fine_tune_model(
-    #
+def fine_tune_model(model_name, dataset_name):
+    # Load dataset
+    dataset = load_dataset(dataset_name)
+
+    print("### Dataset")
+    print(dataset)
+    print("###")
 
-
-
-    dataset
-
-
-    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
+    # Load model
+    model, tokenizer = load_model(model_name)
+
+    print("### Model")
     print(model)
-
-
-
-    tokenizer.pad_token = tokenizer.eos_token
+    print("### Tokenizer")
+    print(tokenizer)
+    print("###")
 
-    #
+    # Pre-process dataset
     def preprocess(examples):
         model_inputs = tokenizer(examples["sql_prompt"], text_target=examples["sql"], max_length=512, padding="max_length", truncation=True)
         return model_inputs
-
     dataset = dataset.map(preprocess, batched=True)
-
-    # Split dataset to training and validation sets
-    train_dataset = dataset["train"].shuffle(seed=42).select(range(1000)) # Adjust the range as needed
-    val_dataset = dataset["test"].shuffle(seed=42).select(range(100)) # Adjust the range as needed
 
+    print("### Pre-processed dataset")
+    print(dataset)
+    print("###")
+
+    # Split dataset into training and validation sets
+    train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
+    test_dataset = dataset["test"].shuffle(seed=42).select(range(100))
+
+    print("### Training dataset")
+    print(train_dataset)
+    print("### Validation dataset")
+    print(test_dataset)
+    print("###")
-
+
+    # Configure training arguments
     training_args = Seq2SeqTrainingArguments(
         output_dir="./results",
-
+        logging_dir="./logs",
+        num_train_epochs=1,
         per_device_train_batch_size=16,
         per_device_eval_batch_size=64,
-
-        weight_decay=0.01,
-        logging_dir="./logs",
+        eval_strategy="steps",
         save_total_limit=2,
         save_steps=500,
         eval_steps=500,
+        warmup_steps=500,
+        weight_decay=0.01,
         metric_for_best_model="accuracy",
         greater_is_better=True,
-        save_on_each_node=True,
         load_best_model_at_end=True,
-        eval_strategy="steps",
         push_to_hub=True,
+        save_on_each_node=True,
     )
+
+    print("### Training arguments")
+    print(training_args)
+    print("###")
 
-    # Create
+    # Create trainer
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset,
-        eval_dataset=
+        eval_dataset=test_dataset,
         compute_metrics=lambda pred: {"accuracy": torch.sum(pred.label_ids == pred.predictions.argmax(-1))},
     )
-
-    # Train the model
-    trainer.train()
+
+    print("### Trainer")
+    print(trainer)
+    print("###")
 
-
-    trainer
-
-    # Create a repository object
-    repo = Repository(
-        local_dir="./fine_tuned_model",
-        repo_type="model",
-        repo_id="bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql",
-    )
+    # Train model
+    #trainer.train()
 
-    #
-
-
-    # Push the model to the hub
-    repo.push_to_hub(commit_message="Initial commit")
 
-def prompt_model(
+def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
     pipe = pipeline("text-generation",
-                    model=
-                    model_kwargs={"torch_dtype": torch.bfloat16},
+                    model=model_name,
+                    #model_kwargs={"torch_dtype": torch.bfloat16},
                     device_map="auto",
                     max_new_tokens=1000)
+
     messages = [
-        {"role": "system", "content": system_prompt.format(schema=schema)},
+        {"role": "system", "content": system_prompt.format(schema=sql_schema)},
         {"role": "user", "content": user_prompt},
         {"role": "assistant", "content": ""}
     ]
+
     output = pipe(messages)
+
     result = output[0]["generated_text"][-1]["content"]
-    print(result)
-    return result
 
-
-
-
-    model.save_pretrained(base_model_id)
-    return tokenizer
+    print("###")
+    print(result)
+    print("###")
+
+    return result
 
-
-    fine_tuned_model_id = replace_hf_profile(base_model_id)
-    login(token=os.environ["HF_TOKEN"])
-    api = HfApi()
-    #api.delete_repo(repo_id=fine_tuned_model_id, repo_type="model")
-    api.create_repo(repo_id=fine_tuned_model_id)
-    api.upload_folder(
-        folder_path=base_model_id,
-        repo_id=fine_tuned_model_id
-    )
-    tokenizer.push_to_hub(fine_tuned_model_id)
-    return fine_tuned_model_id
 
-def
-
-
+def load_model(model_name):
+    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    if not tokenizer.pad_token:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    return model, tokenizer
+
 demo = gr.Interface(fn=process,
                     inputs=[gr.Radio([action_1, action_2], label = "Action", value = action_1),
-                            gr.Textbox(label = "
-                            gr.Textbox(label = "Dataset", value =
+                            gr.Textbox(label = "Model Name", value = model_name, lines = 1),
+                            gr.Textbox(label = "Dataset Name", value = dataset_name, lines = 1),
                             gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
                             gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
-                            gr.Textbox(label = "Schema", value =
-                    outputs=[gr.Textbox(label = "Completion", value = os.environ["OUTPUT"])])
+                            gr.Textbox(label = "SQL Schema", value = sql_schema, lines = 2)],
+                    outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ["OUTPUT"])])
 demo.launch()
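
For reference, the prompting path (action_2) of the updated app reduces to the standalone sketch below. This is a minimal example, not part of the commit: it assumes the bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql checkpoint is accessible and that the installed transformers version accepts chat-style message lists in the text-generation pipeline.

from transformers import pipeline

# Prompt template and schema copied from app.py above
system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
sql_schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"

# Chat-style pipeline over the fine-tuned checkpoint (assumed accessible)
pipe = pipeline("text-generation",
                model="bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql",
                device_map="auto",
                max_new_tokens=1000)

messages = [
    {"role": "system", "content": system_prompt.format(schema=sql_schema)},
    {"role": "user", "content": "What is the total trade value and average price for each trader and stock in the trade_history table?"},
]

# The pipeline returns the full chat transcript; the last turn holds the generated SQL
output = pipe(messages)
print(output[0]["generated_text"][-1]["content"])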