Jen Ben Arye committed on
Commit
6f11489
·
1 Parent(s): 669afda

trained model using kto - sanity check

Browse files
Files changed (1) hide show
  1. kto_pipeline.py +122 -0
kto_pipeline.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # os.environ['CUDA_VISIBLE_DEVICES'] = "3"
3
+
4
+ import torch
5
+ from dataclasses import dataclass
6
+
7
+ from accelerate import PartialState
8
+ from datasets import load_dataset, DatasetDict
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
10
+
11
+ from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
12
+
13
+
14
+
15
# Report the CUDA device index this process currently uses. The call to
# `torch.cuda.current_device()` is guarded so that a purely diagnostic print
# cannot crash the script on a machine where CUDA is not available.
if torch.cuda.is_available():
    print(f'GPU number: {torch.cuda.current_device()}')
else:
    print('GPU number: none (CUDA is not available)')
16
+
17
+
18
# Define the script arguments (they are set in code below, not parsed from the CLI).
@dataclass
class ScriptArguments:
    """
    The arguments for the KTO training script.
    """

    # Dataset identifier; the script later hands it to `load_dataset`.
    dataset_name: str = "trl-lib/kto-mix-14k"
26
+
27
+
28
# Initialize the arguments directly instead of parsing them from the command line.
script_args = ScriptArguments(
    dataset_name="trl-lib/kto-mix-14k",
)
32
+
33
# Training hyperparameters for the KTO run, fixed in code.
training_args = KTOConfig(
    output_dir="/raid/lingo/jen_ben/HF-RLHF/kto_oct_26",  # MODIFY
    num_train_epochs=100,
    per_device_train_batch_size=32,
    learning_rate=5e-7,
    lr_scheduler_type="cosine",
    gradient_accumulation_steps=1,
    logging_steps=10,
    # `eval_steps` is only honoured when the evaluation strategy is "steps".
    # With the default strategy ("no") the value below was silently ignored
    # and no evaluation ran during training, so the strategy is set explicitly.
    eval_strategy="steps",
    eval_steps=500,
    warmup_ratio=0.1,
    bf16=True,
    logging_first_step=True,
)
46
+
47
# Model selection; the same checkpoint name is used below for the policy
# model, the reference model and the tokenizer.
model_args = ModelConfig(
    model_name_or_path="trl-lib/qwen1.5-1.8b-sft",
    # any additional model-specific arguments
)
51
+
52
+
53
# Load the pretrained model twice from the same checkpoint: `model` is passed
# to the trainer as the model to train, `ref_model` as its reference model.
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
)
ref_model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
)
print('loaded model')
61
+
62
# Load a tokenizer from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
)
# Fall back to the EOS token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# If we are aligning a base model, we use ChatML as the default template
# NOTE(review): only `model` is passed through `setup_chat_format`; `ref_model`
# is left untouched. If this branch is ever taken and the helper changes the
# vocabulary size, the two models may disagree -- confirm before relying on it.
if tokenizer.chat_template is None:
    model, tokenizer = setup_chat_format(model, tokenizer)
print('loaded tokenizer')
73
+
74
# Load the dataset
dataset = load_dataset(script_args.dataset_name)

# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
dataset = maybe_unpair_preference_dataset(
    dataset,
    num_proc=training_args.dataset_num_proc,
)
print('loaded dataset')
80
+
81
+
82
# Apply chat template
def format_dataset(example):
    """Render the prompt and completion of one example with the chat template.

    Both ``example["prompt"]`` and ``example["completion"]`` are passed to
    ``tokenizer.apply_chat_template`` with ``tokenize=False`` and the results
    are written back to the same keys. The module-level ``tokenizer`` is used.

    Assumes both fields hold chat-message lists accepted by the tokenizer's
    chat template -- TODO confirm against the dataset schema.

    Args:
        example: A single dataset row with "prompt" and "completion" keys.

    Returns:
        The same ``example`` mapping, modified in place.
    """
    for field_name in ("prompt", "completion"):
        example[field_name] = tokenizer.apply_chat_template(
            example[field_name], tokenize=False
        )
    return example
87
+
88
+
89
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
    dataset = dataset.map(
        format_dataset,
        num_proc=training_args.dataset_num_proc,
    )
93
+
94
+
95
+
96
# Initialize the KTO trainer with the policy model, the reference model and
# the "train" / "test" splits of the formatted dataset.
trainer = KTOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    peft_config=get_peft_config(model_args),
)
106
+
107
print('start training')

trainer.train()

print('finished training')

# Run one evaluation pass after training, then print, log and save its metrics.
metrics = trainer.evaluate()
print(f'metrics: \n {metrics}')
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
117
+
118
+
119
# Save the model to the output directory; push to the hub only when the
# training arguments request it.
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
    trainer.push_to_hub(dataset_name=script_args.dataset_name)