# annotation/mixtral_tune.py
# Uploaded by MudeHui in commit 1fb65ae ("Add application file").
import os
import torch
import transformers
import matplotlib.pyplot as plt
from datetime import datetime
from functools import partial
from peft import LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training
from datasets import load_dataset
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
def formatting_func_QA(example):
    """Render one record as a question/answer prompt for edit generation.

    Reads the ``'input'``, ``'edit'`` and ``'output'`` fields of *example*
    and returns a single string containing a "### Question:" section that
    quotes the input prompt and a "### Answer:" section that quotes the edit
    action and the output prompt.
    """
    source_prompt = example['input']
    edit_action = example['edit']
    target_prompt = example['output']
    return (
        f"### Question: Given an image prompt {source_prompt}\n"
        " give me random Edit Action and the output prompt \n"
        f" ### Answer: Here is the edit action {edit_action}, "
        f"and here is the output {target_prompt}"
    )
def formatting_func_Edit(example, is_train=True):
    """Build the edit-classification prompt for one record.

    The prompt consists of a fixed instruction paragraph followed by a
    question about ``example['edit']`` that ends in a newline. When
    *is_train* is true, the target answer ``"### Answer: Edit Class: ..."``
    built from ``example['class']`` is appended. The same function therefore
    yields prompt+answer for training and the bare prompt for inference.

    NOTE: the instruction paragraph and the question are joined with no
    separator ("...complexities.Question: ..."). This is kept deliberately,
    because changing it would change the prompt format that models were
    trained on.
    """
    instructions = (
        "### Categorizes image editing actions, outputting classifications in the format 'Edit Class: A,B,C'. "
        "In this format, 'A' represents whether the edit is 'Global' or 'Local', "
        "and 'B' denotes the specific type of manipulation, such as 'Filter', 'Stylization', 'SceneChange', etc. "
        "'C' denotes a specified 'B' such as 'FujiFilter', 'Part' etc. "
        "This structured approach provides clear and concise information, facilitating easy understanding of the edit class. "
        "The GPT remains committed to a formal, user-friendly communication style, "
        "ensuring the classifications are accessible and precise, without delving into technical complexities."
    )
    pieces = [
        instructions,
        f"Question: Given the Edit Action {example['edit']}, what is its edit type?\n",
    ]
    if is_train:
        pieces.append(f"### Answer: Edit Class: {example['class']}")
    return "".join(pieces)
def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset,
                      output_path='./experiments/figure.png'):
    """Save a histogram of tokenized sequence lengths for both splits.

    Prints the total number of examples counted, then writes a 10-bin
    histogram of ``len(example['input_ids'])`` over the train and validation
    examples to *output_path*.

    Args:
        tokenized_train_dataset: iterable of examples, each with an
            ``'input_ids'`` entry.
        tokenized_val_dataset: same, for the validation split.
        output_path: image file to write. Its parent directory is created if
            missing. Defaults to the previously hard-coded location.

    Returns:
        The list of lengths (train examples first, then validation).
    """
    lengths = [len(x['input_ids']) for x in tokenized_train_dataset]
    lengths += [len(x['input_ids']) for x in tokenized_val_dataset]
    print(len(lengths))

    # savefig() does not create missing directories, so make sure the target exists.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(lengths, bins=10, alpha=0.7, color='blue')
        plt.xlabel('Length of input_ids')
        plt.ylabel('Frequency')
        plt.title('Distribution of Lengths of input_ids')
        plt.savefig(output_path)
    finally:
        # Release the figure; pyplot otherwise keeps every created figure alive.
        plt.close(fig)
    return lengths
def generate_and_tokenize_prompt(prompt, formatting=None, tokenizer=None):
    """Format *prompt* into text and tokenize it (no truncation or padding).

    Args:
        prompt: one dataset record; passed unchanged to *formatting*.
        formatting: callable turning the record into prompt text. Required.
        tokenizer: callable tokenizer applied to the formatted text. If
            omitted, the module-level ``tokenizer`` global is used, matching
            the original behavior.

    Returns:
        Whatever the tokenizer returns for the formatted text.

    Raises:
        TypeError: if *formatting* is not given.
        NameError: if no tokenizer is passed and no module-level
            ``tokenizer`` is defined.
    """
    if formatting is None:
        raise TypeError("generate_and_tokenize_prompt requires a `formatting` callable")
    if tokenizer is None:
        # Backward-compatible default: fall back to the module-level global.
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "no tokenizer was passed and no module-level `tokenizer` is defined")
    return tokenizer(formatting(prompt))
def generate_and_tokenize_prompt2(prompt, max_length=512, formatting=None, tokenizer=None):
    """Format *prompt*, tokenize it to a fixed length, and attach LM labels.

    The formatted text is tokenized with ``truncation=True``,
    ``max_length=max_length`` and ``padding="max_length"``. A ``"labels"``
    entry is then added as a copy of ``"input_ids"``. It is a copy, so later
    in-place edits to one list do not affect the other.

    Args:
        prompt: one dataset record; passed unchanged to *formatting*.
        max_length: target sequence length for truncation and padding.
        formatting: callable turning the record into prompt text. Required.
        tokenizer: callable tokenizer. If omitted, the module-level
            ``tokenizer`` global is used, matching the original behavior.

    Returns:
        The tokenizer's output with the added ``"labels"`` entry.

    Raises:
        TypeError: if *formatting* is not given.
        NameError: if no tokenizer is passed and no module-level
            ``tokenizer`` is defined.
    """
    if formatting is None:
        raise TypeError("generate_and_tokenize_prompt2 requires a `formatting` callable")
    if tokenizer is None:
        # Backward-compatible default: fall back to the module-level global.
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "no tokenizer was passed and no module-level `tokenizer` is defined")
    result = tokenizer(
        formatting(prompt),
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    result["labels"] = result["input_ids"].copy()
    return result
def print_trainable_parameters(model):
    """Print how many of *model*'s parameters are trainable.

    Sums ``param.numel()`` over ``model.named_parameters()`` and prints the
    trainable subset (``requires_grad``), the total, and the trainable
    percentage. A model with no parameters reports ``0.0`` percent instead
    of dividing by zero.

    Returns:
        ``(trainable_params, all_param)``, so callers can reuse the counts.
    """
    # NOTE(review): for bitsandbytes 4-bit weights, numel() may count packed
    # storage rather than logical elements, so totals can under-report.
    # peft's PeftModel.print_trainable_parameters() adjusts for this; verify.
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    percent = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percent}"
    )
    return trainable_params, all_param
def _load_tokenizer(base_model_id):
    """Load the tokenizer for *base_model_id* configured for fine-tuning.

    Requests left padding and automatic BOS/EOS insertion, and reuses the
    EOS token as the pad token.
    """
    tok = AutoTokenizer.from_pretrained(
        base_model_id,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
    )
    tok.pad_token = tok.eos_token
    return tok


def _load_lora_model(base_model_id, compute_dtype):
    """Load *base_model_id* 4-bit quantized and wrap it with LoRA adapters.

    Prints the model before and after wrapping, plus the trainable-parameter
    summary. When more than one GPU is visible, sets the model-parallel
    flags.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=compute_dtype,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id, quantization_config=bnb_config, device_map="auto")
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    print(model)
    lora_config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=[
            # attention projections
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            # presumably the Mixtral expert MLP weights; confirm against print(model)
            "w1",
            "w2",
            "w3",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.01,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    print_trainable_parameters(model)
    print(model)
    if torch.cuda.device_count() > 1:  # more than one GPU visible
        model.is_parallelizable = True
        model.model_parallel = True
    return model


def train():
    """Fine-tune the base model with 4-bit LoRA on the edit-classification data.

    Pipeline:
    - Load the JSONL instruction file(s).
    - Tokenize each record with ``formatting_func_Edit`` (fixed length 128,
      labels = input_ids).
    - Plot the token-length histogram.
    - Load the model 4-bit quantized with LoRA adapters.
    - Run ``transformers.Trainer`` for 100 steps, writing checkpoints under
      ``{output_root}/{run_name}`` and reporting to Weights & Biases.
    """
    # The tokenize helpers look up `tokenizer` at module scope (none is passed
    # to them below). As a plain local of train(), the dataset.map calls
    # would fail with NameError, so bind it globally.
    global tokenizer

    generate_and_tokenize = partial(generate_and_tokenize_prompt2,
                                    max_length=128,
                                    formatting=formatting_func_Edit)

    # ---- Paths (hard-coded for now; make configurable later) ----
    model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
    # Alternative output root: "/mlx/users/peng.wang/playground/data/chat_edit/models/llm"
    output_root = "/opt/tiger/llm"
    os.makedirs(output_root, exist_ok=True)
    # plot_data_lengths() writes ./experiments/figure.png and the Trainer logs
    # to ./experiments/logs, so make sure that tree exists.
    os.makedirs("./experiments/logs", exist_ok=True)

    ######### Tune model with Mixtral MoE #########
    # Base (non-instruct) variant: f"{model_root}/Mixtral-8x7B-v0.1"
    base_model_id = f"{model_root}/Mixtral-8x7B-Instruct-v0.1"
    base_model_name = "mixtral-8x7b"
    # ######### Tune model with Mistral Instruct 7B #########
    # base_model_id = f"{model_root}/Mistral-7B-Instruct-v0.2"
    # base_model_name = "mixtral-7b"

    ######### Instructions #########
    train_json = "./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
    # NOTE(review): validation reuses the training file, so eval loss is not
    # measured on held-out data.
    val_json = train_json
    project = "edit-finetune"
    run_name = base_model_name + "-" + project
    output_dir = f"{output_root}/{run_name}"

    fsdp_plugin = FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
    )
    # NOTE(review): `accelerator` is not referenced again in this function.
    # It is kept only because constructing an Accelerator may set up
    # process/distributed state; confirm whether it can be dropped.
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

    train_dataset = load_dataset('json', data_files=train_json, split='train')
    eval_dataset = load_dataset('json', data_files=val_json, split='train')

    # Loaded once and reused for dataset tokenization and the data collator.
    tokenizer = _load_tokenizer(base_model_id)
    tokenized_train_dataset = train_dataset.map(generate_and_tokenize)
    tokenized_val_dataset = eval_dataset.map(generate_and_tokenize)
    if len(tokenized_train_dataset) > 1:  # sanity-print one encoded sample
        print(tokenized_train_dataset[1]['input_ids'])
    plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)

    # ---- Load the quantized model and attach LoRA adapters ----
    model = _load_lora_model(base_model_id, compute_dtype=torch.bfloat16)

    ## RUN training ##
    # NOTE(review): pad_token == eos_token here. DataCollatorForLanguageModeling(mlm=False)
    # typically sets labels to -100 wherever input_ids equals the pad id, which
    # would also mask the appended EOS token. Verify if the model should learn to stop.
    trainer = transformers.Trainer(
        model=model,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        args=transformers.TrainingArguments(
            output_dir=output_dir,
            warmup_steps=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            gradient_checkpointing=True,
            max_steps=100,
            learning_rate=2.5e-5,  # small lr for fine-tuning
            # NOTE(review): the 4-bit compute dtype is bfloat16; confirm fp16 AMP
            # (rather than bf16=True) is the intended pairing.
            fp16=True,
            optim="paged_adamw_8bit",
            logging_steps=25,  # log training loss every 25 steps
            logging_dir="./experiments/logs",  # directory for storing logs
            save_strategy="steps",
            save_steps=100,  # save a checkpoint every 100 steps
            evaluation_strategy="steps",
            eval_steps=25,  # evaluate every 25 steps
            do_eval=True,
            report_to="wandb",  # comment out to stop reporting to Weights & Biases
            run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",  # W&B run name
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()
# Script entry point: run the fine-tuning job only when executed directly, not on import.
if __name__ == '__main__':
    train()