Update app.py
app.py
CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
 import os, torch
 from datasets import load_dataset
 from huggingface_hub import HfApi, login
-
+from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training
 from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
 
 ACTION_1 = "Prompt base model"
@@ -46,12 +46,14 @@ def fine_tune_model(base_model_name, dataset_name):
 
     model, tokenizer = load_model(base_model_name)
 
+    model.print_trainable_parameters()
+
     print("### Model")
     print(model)
     print("### Tokenizer")
     print(tokenizer)
     print("###")
-
+
     # Pre-process dataset
 
     def preprocess(examples):
@@ -91,14 +93,22 @@ def fine_tune_model(base_model_name, dataset_name):
     print(training_args)
     print("###")
 
-    # PEFT
+    # PEFT https://www.philschmid.de/fine-tune-flan-t5-peft
 
-
-
-
-
-
+    lora_config = LoraConfig(
+        r=16,
+        # TODO
+        #bias="none",
+        #lora_alpha=32,
+        #lora_dropout=0.05,
+        #target_modules=["q", "v"],
+        task_type=TaskType.SEQ_2_SEQ_LM,
+    )
+
+    model = prepare_model_for_int8_training(model)
+    model = get_peft_model(model, lora_config)
+    model.print_trainable_parameters()
+
     # Create trainer
 
     trainer = Seq2SeqTrainer(
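For reference, the LoRA wiring introduced in this hunk follows the pattern from the linked guide (philschmid's FLAN-T5 PEFT post). The sketch below shows how those pieces fit together under stated assumptions: the checkpoint name, the 8-bit loading flag, and the hyperparameters that are still commented out in the commit (lora_alpha, lora_dropout, bias, target_modules) are illustrative, not values taken from app.py.

from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

base_model_name = "google/flan-t5-base"  # illustrative checkpoint, not from this commit
model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name, load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

lora_config = LoraConfig(
    r=16,                       # rank of the low-rank update matrices
    lora_alpha=32,              # scaling on the LoRA updates (assumed value)
    lora_dropout=0.05,          # dropout on the LoRA layers (assumed value)
    bias="none",                # keep bias terms frozen (assumed)
    target_modules=["q", "v"],  # T5 attention projections to adapt (assumed)
    task_type=TaskType.SEQ_2_SEQ_LM,
)

model = prepare_model_for_int8_training(model)  # freeze base weights, cast norms for stable int8 training
model = get_peft_model(model, lora_config)      # wrap the base model with LoRA adapters
model.print_trainable_parameters()              # reports the small fraction of weights left trainable

Note that print_trainable_parameters() is a PeftModel method, so it only becomes available after get_peft_model has wrapped the base model.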
@@ -106,17 +116,16 @@ def fine_tune_model(base_model_name, dataset_name):
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        #peft_config=peft_config,
         # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
     )
 
     # Train model
 
-    trainer.train()
+    #trainer.train()
 
     # Push tokenizer to HF
 
-    tokenizer.push_to_hub(FT_MODEL_NAME)
+    #tokenizer.push_to_hub(FT_MODEL_NAME)
 
 def prompt_model(model_name, system_prompt, user_prompt, sql_context):
     pipe = pipeline("text-generation",
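The commit leaves trainer.train() and tokenizer.push_to_hub(FT_MODEL_NAME) commented out while the PEFT setup is worked on. Once re-enabled, a typical flow with these APIs looks roughly like the sketch below; it continues from the sketch above, FT_MODEL_NAME stands for the fine-tuned repo id defined elsewhere in app.py, and the reload step is illustrative rather than part of the commit.

trainer.train()                       # fine-tune; only the LoRA adapter weights receive gradients
model.push_to_hub(FT_MODEL_NAME)      # uploads just the adapter weights and adapter config
tokenizer.push_to_hub(FT_MODEL_NAME)  # uploads the tokenizer files alongside the adapter

# Re-attaching the adapter to the base model later (illustrative, not in this commit):
from peft import PeftModel
base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
ft_model = PeftModel.from_pretrained(base, FT_MODEL_NAME)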