import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

training_config = {
    "bf16": True,
    "do_eval": False,
    "learning_rate": 5.0e-06,
    "log_level": "info",
    "logging_steps": 20,
    "logging_strategy": "steps",
    "lr_scheduler_type": "cosine",
    "num_train_epochs": 1,
    "max_steps": -1,
    "output_dir": "./instruct_chk_dir",
    "overwrite_output_dir": True,
    "per_device_eval_batch_size": 4,
    "per_device_train_batch_size": 4,
    "remove_unused_columns": True,
    "save_steps": 100,
    "save_total_limit": 1,
    "seed": 0,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
    "gradient_accumulation_steps": 1,
    "warmup_ratio": 0.2,
}

peft_config = {
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "target_modules": "all-linear",
    "modules_to_save": None,
}

config = {
    "max_len": 4096,
}

train_conf = TrainingArguments(**training_config)
peft_conf = LoraConfig(**peft_config)

# Model Init
checkpoint_path = "microsoft/Phi-3-mini-128k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # load the model with flash-attention support
    torch_dtype=torch.bfloat16,
    device_map="sequential",
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.model_max_length = config['max_len']
tokenizer.pad_token = tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'

dataset_id = "BAAI/Infinity-Instruct"
raw_dataset = load_dataset(dataset_id, "0625", split="train")
dataset = raw_dataset.select(range(10000))

# Preprocess dataset: rename the ShareGPT-style keys ('from'/'value') to the
# chat-template schema ('role'/'content') and render each conversation into a
# single 'text' field with the model's chat template.
def preproc(example, tokenizer):
    convo = example['conversations']
    for dic in convo:
        dic['role'] = dic.pop('from')
        dic['content'] = dic.pop('value')
        if dic['role'] == 'gpt':
            dic['role'] = 'assistant'
        elif dic['role'] == 'human':
            dic['role'] = 'user'
    example['text'] = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
    return example

column_names = list(dataset.features)
train_dataset = dataset.map(
    preproc,
    fn_kwargs={"tokenizer": tokenizer},
    num_proc=10,
    remove_columns=column_names,
)

# eval_dataset = dataset.select(range(9000, 10000))  # note: dataset[9000:] returns a dict, not a Dataset
# eval_dataset = eval_dataset.map(
#     preproc,
#     fn_kwargs={"tokenizer": tokenizer},
#     num_proc=10,
#     remove_columns=column_names,
# )

# Train Model
trainer = SFTTrainer(
    model=model,
    args=train_conf,
    peft_config=peft_conf,
    train_dataset=train_dataset,
    # eval_dataset=eval_dataset,
    max_seq_length=config['max_len'],
    dataset_text_field="text",
    tokenizer=tokenizer,
    packing=True,
)
train_result = trainer.train()
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Eval Model: disabled because no eval_dataset is passed to SFTTrainer above.
# Enable the eval_dataset block and the eval_dataset argument before
# uncommenting; trainer.evaluate() raises otherwise.
# tokenizer.padding_side = 'left'
# metrics = trainer.evaluate()
# metrics["eval_samples"] = len(eval_dataset)
# trainer.log_metrics("eval", metrics)
# trainer.save_metrics("eval", metrics)

# Save model
trainer.save_model(train_conf.output_dir)
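# --- Inference sketch (an addition, not part of the original training flow) ---
# A minimal, hedged example of loading the LoRA adapter that save_model() wrote
# to train_conf.output_dir and generating one reply. The prompt is illustrative,
# and loading a second model copy assumes enough free GPU memory.
from peft import AutoPeftModelForCausalLM

inference_model = AutoPeftModelForCausalLM.from_pretrained(
    train_conf.output_dir,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
inference_model.eval()

messages = [{"role": "user", "content": "Summarize what LoRA fine-tuning does."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(inference_model.device)
with torch.no_grad():
    output_ids = inference_model.generate(input_ids, max_new_tokens=128)
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))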
# Alternative dataset pipeline (HuggingFaceH4/ultrachat_200k), kept for reference.
# def apply_chat_template(
#     example,
#     tokenizer,
# ):
#     messages = example["messages"]
#     example["text"] = tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=False)
#     return example
#
# raw_dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
# train_dataset = raw_dataset["train_sft"].select(range(10000))
# test_dataset = raw_dataset["test_sft"].select(range(1000))
# column_names = list(train_dataset.features)
# processed_train_dataset = train_dataset.map(
#     apply_chat_template,
#     fn_kwargs={"tokenizer": tokenizer},
#     num_proc=10,
#     remove_columns=column_names,
#     desc="Applying chat template to train_sft",
# )
# processed_test_dataset = test_dataset.map(
#     apply_chat_template,
#     fn_kwargs={"tokenizer": tokenizer},
#     num_proc=10,
#     remove_columns=column_names,
#     desc="Applying chat template to test_sft",
# )