"""Fine-tune a 4-bit quantized causal LM with LoRA adapters to classify image-edit actions."""

import os
from datetime import datetime
from functools import partial

import matplotlib.pyplot as plt
import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


# Fixed instruction preamble for the edit-classification prompt.
# NOTE(review): reconstructed from a line-collapsed source; the exact whitespace at
# the original line breaks could not be recovered. Prompts used at inference time
# must match whatever text the model was trained on.
_EDIT_CLASS_INSTRUCTION = (
    "### Categorizes image editing actions, outputting classifications in the format "
    "'Edit Class: A,B,C'. In this format, 'A' represents whether the edit is 'Global' "
    "or 'Local', and 'B' denotes the specific type of manipulation, such as 'Filter', "
    "'Stylization', 'SceneChange', etc. 'C' denotes a specified 'B' such as "
    "'FujiFilter', 'Part' etc. This structured approach provides clear and concise "
    "information, facilitating easy understanding of the edit class. "
    "The GPT remains committed to a formal, user-friendly communication style, "
    "ensuring the classifications are accessible and precise, without delving into "
    "technical complexities."
)


def formatting_func_QA(example):
    """Build a question/answer prompt from ``example['input']``, ``['edit']`` and ``['output']``."""
    text = f"### Question: Given an image prompt {example['input']}\n give me random Edit Action and the output prompt \n ### Answer: Here is the edit action {example['edit']}, and here is the output {example['output']}"
    return text


def formatting_func_Edit(example, is_train=True):
    """Build the edit-classification prompt for one example.

    Args:
        example: Mapping with an ``'edit'`` key and, when ``is_train`` is true,
            a ``'class'`` key holding the target label.
        is_train: When true, append the ``### Answer`` section containing the label.

    Returns:
        The prompt string.
    """
    text = (
        f"{_EDIT_CLASS_INSTRUCTION} "
        f"Question: Given the Edit Action {example['edit']}, what is its edit type?\n"
    )
    if is_train:
        text = text + f"### Answer: Edit Class: {example['class']}"
    return text


def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset, output_path="./experiments/figure.png"):
    """Save a histogram of ``input_ids`` lengths across both datasets.

    Args:
        tokenized_train_dataset: Iterable of rows that each have an ``'input_ids'`` sequence.
        tokenized_val_dataset: Same, for the validation split.
        output_path: Where the PNG is written; its parent directory is created if missing.
    """
    lengths = [len(x["input_ids"]) for x in tokenized_train_dataset]
    lengths += [len(x["input_ids"]) for x in tokenized_val_dataset]
    print(len(lengths))

    # savefig fails if the target directory does not exist yet.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Plotting the histogram
    fig = plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=10, alpha=0.7, color="blue")
    plt.xlabel("Length of input_ids")
    plt.ylabel("Frequency")
    plt.title("Distribution of Lengths of input_ids")
    plt.savefig(output_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _resolve_tokenizer(tokenizer):
    """Return ``tokenizer``, falling back to a module-level ``tokenizer`` global if one was set."""
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
    if tokenizer is None:
        raise ValueError(
            "No tokenizer available: pass tokenizer=... (e.g. via functools.partial)."
        )
    return tokenizer


def generate_and_tokenize_prompt(prompt, formatting=None, tokenizer=None):
    """Format ``prompt`` with ``formatting`` and tokenize it without truncation or padding."""
    tokenizer = _resolve_tokenizer(tokenizer)
    return tokenizer(formatting(prompt))


def generate_and_tokenize_prompt2(prompt, max_length=512, formatting=None, tokenizer=None):
    """Format and tokenize ``prompt``, truncated/padded to ``max_length``, with labels = input_ids.

    Args:
        prompt: One dataset row, passed to ``formatting``.
        max_length: Target sequence length for truncation and ``"max_length"`` padding.
        formatting: Callable turning the row into prompt text.
        tokenizer: Tokenizer to use; falls back to a module-level ``tokenizer`` global.

    Returns:
        The tokenizer output with an added ``"labels"`` list copied from ``"input_ids"``.
    """
    tokenizer = _resolve_tokenizer(tokenizer)
    result = tokenizer(
        formatting(prompt),
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    result["labels"] = result["input_ids"].copy()
    return result


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    # Guard against a model with no parameters (would otherwise divide by zero).
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )


def _load_tokenizer(base_model_id):
    """Load the tokenizer with left padding, BOS/EOS insertion, and EOS reused as the pad token."""
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _load_lora_model(base_model_id):
    """Load ``base_model_id`` in 4-bit (bitsandbytes) and wrap it with LoRA adapters."""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id, quantization_config=bnb_config, device_map="auto"
    )
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    print(model)

    config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "w1",
            "w2",
            "w3",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.01,  # Small adapter dropout
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    print_trainable_parameters(model)
    print(model)
    return model


def train():
    """Fine-tune the configured base model on the edit-classification JSONL data."""
    # Padded/truncated sequence length. The fixed instruction preamble alone is
    # estimated at roughly 140+ tokens, so the previous value of 128 truncated the
    # "### Answer" target away entirely. Check ./experiments/figure.png (unpadded
    # lengths) before lowering this.
    max_length = 512

    # configs here latter change
    model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
    # Alternative output root: "/mlx/users/peng.wang/playground/data/chat_edit/models/llm"
    output_root = "/opt/tiger/llm"
    os.makedirs(output_root, exist_ok=True)

    ######### Tune model with Mixtral MoE #########
    # Base (non-instruct) alternative: os.path.join(model_root, "Mixtral-8x7B-v0.1")
    base_model_id = os.path.join(model_root, "Mixtral-8x7B-Instruct-v0.1")
    base_model_name = "mixtral-8x7b"

    # ######### Tune model with Mixtral Instruct 7B #########
    # base_model_id = os.path.join(model_root, "Mistral-7B-Instruct-v0.2")
    # base_model_name = "mixtral-7b"

    ######### Instructions #########
    train_json = "./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
    # NOTE(review): validation currently reuses the training file.
    val_json = train_json
    project = "edit-finetune"
    run_name = base_model_name + "-" + project
    output_dir = f"{output_root}/{run_name}"

    fsdp_plugin = FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
    )
    # NOTE(review): `accelerator` is never used below (no accelerator.prepare_model);
    # kept so construction-time behavior is unchanged — confirm whether FSDP is intended.
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

    train_dataset = load_dataset("json", data_files=train_json, split="train")
    eval_dataset = load_dataset("json", data_files=val_json, split="train")

    tokenizer = _load_tokenizer(base_model_id)

    # Plot *unpadded* lengths; padded sequences would all have length max_length.
    measure = partial(
        generate_and_tokenize_prompt, formatting=formatting_func_Edit, tokenizer=tokenizer
    )
    plot_data_lengths(train_dataset.map(measure), eval_dataset.map(measure))

    generate_and_tokenize = partial(
        generate_and_tokenize_prompt2,
        max_length=max_length,
        formatting=formatting_func_Edit,
        tokenizer=tokenizer,
    )
    tokenized_train_dataset = train_dataset.map(generate_and_tokenize)
    tokenized_val_dataset = eval_dataset.map(generate_and_tokenize)
    if len(tokenized_train_dataset) > 0:
        print(tokenized_train_dataset[0]["input_ids"])

    # load model and do finetune
    model = _load_lora_model(base_model_id)

    ## RUN training ##
    if torch.cuda.device_count() > 1:  # If more than 1 GPU
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        args=transformers.TrainingArguments(
            output_dir=output_dir,
            warmup_steps=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            gradient_checkpointing=True,
            max_steps=100,
            learning_rate=2.5e-5,  # Want a small lr for finetuning
            # Match bnb_4bit_compute_dtype=torch.bfloat16 (fp16 AMP conflicts with bf16 compute).
            bf16=True,
            optim="paged_adamw_8bit",
            logging_steps=25,  # Log training loss every 25 steps
            logging_dir="./experiments/logs",  # Directory for storing logs
            save_strategy="steps",  # Save checkpoints on a step schedule
            save_steps=100,  # Save a checkpoint every 100 steps
            evaluation_strategy="steps",  # Evaluate on a step schedule
            eval_steps=25,  # Evaluate every 25 steps
            do_eval=True,  # Cadence is set by evaluation_strategy/eval_steps
            report_to="wandb",  # Comment this out if you don't want to use weights & biases
            run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",  # Name of the W&B run (optional)
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()


if __name__ == "__main__":
    train()