burtenshaw committed
Commit aac30ac · 1 Parent(s): 9cc6120

add trl script

Files changed (2)
  1. ml/kto.py +117 -0
  2. ml/train.sh +15 -0
ml/kto.py ADDED
@@ -0,0 +1,117 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
+
+# Full training:
+python examples/scripts/kto.py \
+    --dataset_name trl-lib/kto-mix-14k \
+    --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
+    --per_device_train_batch_size 16 \
+    --num_train_epochs 1 \
+    --learning_rate 5e-7 \
+    --lr_scheduler_type=cosine \
+    --gradient_accumulation_steps 1 \
+    --logging_steps 10 \
+    --eval_steps 500 \
+    --output_dir=kto-aligned-model \
+    --warmup_ratio 0.1 \
+    --report_to wandb \
+    --bf16 \
+    --logging_first_step
+
+# QLoRA:
+python examples/scripts/kto.py \
+    --dataset_name trl-lib/kto-mix-14k \
+    --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
+    --per_device_train_batch_size 8 \
+    --num_train_epochs 1 \
+    --learning_rate 5e-7 \
+    --lr_scheduler_type=cosine \
+    --gradient_accumulation_steps 1 \
+    --logging_steps 10 \
+    --eval_steps 500 \
+    --output_dir=kto-aligned-model-lora \
+    --warmup_ratio 0.1 \
+    --report_to wandb \
+    --bf16 \
+    --logging_first_step \
+    --use_peft \
+    --load_in_4bit \
+    --lora_target_modules=all-linear \
+    --lora_r=16 \
+    --lora_alpha=16
+"""
+
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
+
+from trl import (
+    KTOConfig,
+    KTOTrainer,
+    ModelConfig,
+    ScriptArguments,
+    get_peft_config,
+    setup_chat_format,
+)
+
+
+if __name__ == "__main__":
+    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
+    script_args, training_args, model_args = parser.parse_args_into_dataclasses()
+
+    # Load a pretrained model
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+    )
+    ref_model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+    )
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # If we are aligning a base model, we use ChatML as the default template
+    if tokenizer.chat_template is None:
+        model, tokenizer = setup_chat_format(model, tokenizer)
+
+    # Load the dataset
+    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
+
+    # Initialize the KTO trainer
+    trainer = KTOTrainer(
+        model,
+        ref_model,
+        args=training_args,
+        train_dataset=dataset[script_args.dataset_train_split],
+        eval_dataset=(
+            dataset[script_args.dataset_test_split]
+            if training_args.eval_strategy != "no"
+            else None
+        ),
+        processing_class=tokenizer,
+        peft_config=get_peft_config(model_args),
+    )
+
+    # Train and push the model to the Hub
+    trainer.train()
+
+    # Save and push to hub
+    trainer.save_model(training_args.output_dir)
+    if training_args.push_to_hub:
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
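
Note on the training data: KTOTrainer works on unpaired preference data, so each row in trl-lib/kto-mix-14k carries one prompt, one completion, and a desirability flag rather than a chosen/rejected pair. A minimal sketch of that format, assuming the conversational "prompt"/"completion"/"label" columns used by TRL (the toy prompts below are illustrative and not part of this commit):

from datasets import Dataset

# Toy unpaired-preference dataset: a boolean "label" marks each completion
# as desirable (True) or undesirable (False) for the same prompt.
toy_kto_dataset = Dataset.from_dict(
    {
        "prompt": [
            [{"role": "user", "content": "What color is the sky?"}],
            [{"role": "user", "content": "What color is the sky?"}],
        ],
        "completion": [
            [{"role": "assistant", "content": "It is blue."}],
            [{"role": "assistant", "content": "It is green."}],
        ],
        "label": [True, False],
    }
)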
ml/train.sh ADDED
@@ -0,0 +1,15 @@
+python kto.py \
+    --dataset_name trl-lib/kto-mix-14k \
+    --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
+    --per_device_train_batch_size 16 \
+    --num_train_epochs 1 \
+    --learning_rate 5e-7 \
+    --lr_scheduler_type=cosine \
+    --gradient_accumulation_steps 1 \
+    --logging_steps 10 \
+    --eval_steps 500 \
+    --output_dir=kto-aligned-model \
+    --warmup_ratio 0.1 \
+    --report_to wandb \
+    --bf16 \
+    --logging_first_step
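
Once train.sh finishes, the aligned checkpoint lives in kto-aligned-model (the --output_dir above, written by trainer.save_model). A quick smoke-test sketch for loading it back and generating a reply; the prompt text is illustrative, and the path assumes the default output_dir from the command above:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the checkpoint saved at the end of kto.py.
model = AutoModelForCausalLM.from_pretrained("kto-aligned-model")
tokenizer = AutoTokenizer.from_pretrained("kto-aligned-model")

# Format a single-turn chat prompt and generate a short answer.
messages = [{"role": "user", "content": "Give me one tip for writing clear code."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)
output_ids = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))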