bstraehle committed
Commit 7be2c23
1 Parent(s): c99016f

Update app.py

Files changed (1):
  1. app.py +20 -11
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
 import os, torch
 from datasets import load_dataset
 from huggingface_hub import HfApi, login
-#from peft import LoraConfig
+from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training
 from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline

 ACTION_1 = "Prompt base model"
@@ -46,12 +46,14 @@ def fine_tune_model(base_model_name, dataset_name):

     model, tokenizer = load_model(base_model_name)

+    model.print_trainable_parameters()
+
     print("### Model")
     print(model)
     print("### Tokenizer")
     print(tokenizer)
     print("###")
-
+
     # Pre-process dataset

     def preprocess(examples):
@@ -91,14 +93,22 @@ def fine_tune_model(base_model_name, dataset_name):
     print(training_args)
     print("###")

-    # PEFT
+    # PEFT https://www.philschmid.de/fine-tune-flan-t5-peft

-    #peft_config = LoraConfig(
-    #    r=8,
-    #    bias="none",
-    #    task_type="CAUSAL_LM",
-    #)
+    lora_config = LoraConfig(
+        r=16,
+        # TODO
+        #bias="none",
+        #lora_alpha=32,
+        #lora_dropout=0.05,
+        #target_modules=["q", "v"],
+        task_type=TaskType.SEQ_2_SEQ_LM,
+    )

+    model = prepare_model_for_int8_training(model)
+    model = get_peft_model(model, lora_config)
+    model.print_trainable_parameters()
+
     # Create trainer

     trainer = Seq2SeqTrainer(
@@ -106,17 +116,16 @@ def fine_tune_model(base_model_name, dataset_name):
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        #peft_config=peft_config,
         # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
     )

     # Train model

-    trainer.train()
+    #trainer.train()

     # Push tokenizer to HF

-    tokenizer.push_to_hub(FT_MODEL_NAME)
+    #tokenizer.push_to_hub(FT_MODEL_NAME)

 def prompt_model(model_name, system_prompt, user_prompt, sql_context):
     pipe = pipeline("text-generation",
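
For readers following the PEFT change, the pieces this commit introduces (LoraConfig, TaskType, get_peft_model, print_trainable_parameters) typically fit together as in the minimal sketch below. This is an illustration under assumptions, not the app's code: the "google/flan-t5-base" checkpoint and the lora_alpha/lora_dropout values are placeholders, and print_trainable_parameters() is only available on the PeftModel wrapper returned by get_peft_model, not on a freshly loaded base model.

# Minimal sketch (assumptions noted above), not the code in app.py.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM

# Illustrative seq2seq checkpoint; the app loads its own base model via load_model().
base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

lora_config = LoraConfig(
    r=16,                             # rank of the LoRA update matrices
    lora_alpha=32,                    # scaling applied to the LoRA updates (assumed value)
    lora_dropout=0.05,                # dropout on the LoRA layers (assumed value)
    task_type=TaskType.SEQ_2_SEQ_LM,  # matches the Seq2SeqTrainer used in app.py
)

# get_peft_model returns a PeftModel wrapper; only this wrapper exposes
# print_trainable_parameters(), which reports how few weights LoRA actually trains.
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()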