# Hugging Face Spaces status: Runtime error
import os | |
import torch | |
import transformers | |
import matplotlib.pyplot as plt | |
from datetime import datetime | |
from functools import partial | |
from peft import LoraConfig, get_peft_model | |
from peft import prepare_model_for_kbit_training | |
from datasets import load_dataset | |
from accelerate import FullyShardedDataParallelPlugin, Accelerator | |
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
def formatting_func_QA(example):
    """Render one record as a question/answer style training prompt.

    Args:
        example: mapping with ``'input'``, ``'edit'`` and ``'output'`` entries.

    Returns:
        The prompt string: a question built from ``example['input']`` followed
        by an answer that quotes ``example['edit']`` and ``example['output']``.
    """
    question = (
        f"### Question: Given an image prompt {example['input']}\n"
        " give me random Edit Action and the output prompt \n "
    )
    answer = (
        f"### Answer: Here is the edit action {example['edit']}, "
        f"and here is the output {example['output']}"
    )
    return question + answer
def formatting_func_Edit(example, is_train=True):
    """Build the edit-classification prompt for one record.

    Args:
        example: mapping with an ``'edit'`` entry, plus ``'class'`` when
            ``is_train`` is true.
        is_train: when true, the ground-truth answer line is appended so the
            text can be used as a supervised training target; when false the
            prompt stops after the question.

    Returns:
        The instruction text, the question, and (for training) the answer.
    """
    instructions = (
        "### Categorizes image editing actions, outputting classifications in the "
        "format 'Edit Class: A,B,C'. In this format, 'A' represents whether the "
        "edit is 'Global' or 'Local', and 'B' denotes the specific type of "
        "manipulation, such as 'Filter', 'Stylization', 'SceneChange', etc. 'C' "
        "denotes a specified 'B' such as 'FujiFilter', 'Part' etc. This structured "
        "approach provides clear and concise information, facilitating easy "
        "understanding of the edit class. The GPT remains committed to a formal, "
        "user-friendly communication style, ensuring the classifications are "
        "accessible and precise, without delving into technical complexities."
    )
    # NOTE(review): the instruction text runs straight into "Question:" with no
    # separator. It is kept byte-for-byte so the prompt format does not change;
    # confirm whether a newline was intended.
    question = f"Question: Given the Edit Action {example['edit']}, what is its edit type?\n"
    pieces = [instructions, question]
    if is_train:
        pieces.append(f"### Answer: Edit Class: {example['class']}")
    return "".join(pieces)
def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset,
                      save_path="./experiments/figure.png"):
    """Save a histogram of tokenized sequence lengths across both splits.

    Args:
        tokenized_train_dataset: iterable of records, each with an
            ``'input_ids'`` sequence.
        tokenized_val_dataset: same, for the validation split.
        save_path: destination of the PNG. Missing parent directories are
            created (the original crashed when ``./experiments`` did not exist).

    Prints the total number of records before plotting.
    """
    lengths = [len(x['input_ids']) for x in tokenized_train_dataset]
    lengths += [len(x['input_ids']) for x in tokenized_val_dataset]
    print(len(lengths))

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(lengths, bins=10, alpha=0.7, color='blue')
        plt.xlabel('Length of input_ids')
        plt.ylabel('Frequency')
        plt.title('Distribution of Lengths of input_ids')

        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(save_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def generate_and_tokenize_prompt(prompt, formatting=None, tokenizer=None):
    """Format ``prompt`` and tokenize it without truncation or padding.

    Args:
        prompt: one dataset record, or already-formatted text when
            ``formatting`` is None.
        formatting: optional callable turning ``prompt`` into the text to
            tokenize (e.g. ``formatting_func_Edit``).
        tokenizer: tokenizer callable to use. When omitted, a module-level
            ``tokenizer`` is used; the original always read that global, which
            raised NameError when it only existed as a local of ``train()``.

    Returns:
        Whatever ``tokenizer(text)`` returns.

    Raises:
        NameError: if no tokenizer is passed and no module-level one exists.
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "generate_and_tokenize_prompt: pass tokenizer=... or define a "
                "module-level 'tokenizer' before calling"
            )
    text = prompt if formatting is None else formatting(prompt)
    return tokenizer(text)
def generate_and_tokenize_prompt2(prompt, max_length=512, formatting=None, tokenizer=None):
    """Format ``prompt``, tokenize it to a fixed length and attach LM labels.

    Args:
        prompt: one dataset record, or already-formatted text when
            ``formatting`` is None.
        max_length: the text is truncated and padded (``padding="max_length"``)
            to this many tokens.
        formatting: optional callable turning ``prompt`` into the text to
            tokenize (e.g. ``formatting_func_Edit``).
        tokenizer: tokenizer callable to use. When omitted, a module-level
            ``tokenizer`` is used; the original always read that global, which
            raised NameError when it only existed as a local of ``train()``.

    Returns:
        The tokenizer output with an extra ``"labels"`` entry holding a copy of
        ``"input_ids"``.

    Raises:
        NameError: if no tokenizer is passed and no module-level one exists.
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "generate_and_tokenize_prompt2: pass tokenizer=... or define a "
                "module-level 'tokenizer' before calling"
            )
    text = prompt if formatting is None else formatting(prompt)
    result = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # Copy so later in-place edits to labels do not alias input_ids.
    result["labels"] = result["input_ids"].copy()
    return result
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Walks ``model.named_parameters()`` and prints the trainable count, the
    total count and the trainable percentage on one line. A model with no
    parameters reports 0% instead of raising ZeroDivisionError.

    bitsandbytes ``Params4bit`` tensors pack two 4-bit values per storage byte,
    so their ``numel()`` under-reports the logical parameter count. They are
    scaled by ``2 * element_size()``, the same correction peft applies in
    ``get_nb_trainable_parameters``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        if type(param).__name__ == "Params4bit":
            num_params *= 2 * param.element_size()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    percent = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percent}"
    )
def _load_tokenizer(base_model_id):
    """Load the tokenizer for ``base_model_id`` with left padding and BOS/EOS added.

    The EOS token is reused as the pad token.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _load_lora_model(base_model_id):
    """Load ``base_model_id`` quantized to 4 bits and wrap it with LoRA adapters."""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id, quantization_config=bnb_config, device_map="auto")
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    print(model)

    config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "w1",
            "w2",
            "w3",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.01,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    print_trainable_parameters(model)
    print(model)
    return model


def train():
    """Fine-tune the configured base model with 4-bit LoRA on the edit-class data.

    Loads the JSONL instructions, tokenizes them with ``formatting_func_Edit``
    (128-token fixed length), saves a length histogram, wraps the quantized
    base model with LoRA adapters and runs a 100-step ``transformers.Trainer``
    job that reports to Weights & Biases.
    """
    # The tokenize helpers look up a module-level ``tokenizer``; binding it as a
    # global here is what lets ``Dataset.map`` run instead of raising NameError.
    global tokenizer

    generate_and_tokenize = partial(generate_and_tokenize_prompt2,
                                    max_length=128,
                                    formatting=formatting_func_Edit)

    # ---- configuration (adjust per environment) ----
    model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
    # Previously used output root: /mlx/users/peng.wang/playground/data/chat_edit/models/llm
    output_root = "/opt/tiger/llm"
    os.makedirs(output_root, exist_ok=True)

    # Mixtral MoE instruct checkpoint. Other checkpoints tried under model_root:
    # Mixtral-8x7B-v0.1, and Mistral-7B-Instruct-v0.2 (base_model_name "mixtral-7b").
    base_model_id = f"{model_root}/Mixtral-8x7B-Instruct-v0.1"
    base_model_name = "mixtral-8x7b"

    train_json = "./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
    # NOTE(review): validation reuses the training file, so eval loss is not held out.
    val_json = train_json
    project = "edit-finetune"
    run_name = base_model_name + "-" + project
    output_dir = f"{output_root}/{run_name}"

    # NOTE(review): this Accelerator is never handed to the Trainer below; it is
    # kept in case constructing it configures shared accelerate state — confirm.
    fsdp_plugin = FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
    )
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

    # ---- data ----
    train_dataset = load_dataset('json', data_files=train_json, split='train')
    eval_dataset = load_dataset('json', data_files=val_json, split='train')

    # Loaded once and reused for both dataset mapping and the data collator.
    tokenizer = _load_tokenizer(base_model_id)
    tokenized_train_dataset = train_dataset.map(generate_and_tokenize)
    tokenized_val_dataset = eval_dataset.map(generate_and_tokenize)
    print(tokenized_train_dataset[0]['input_ids'])

    # plot_data_lengths saves into ./experiments; make sure the folder exists.
    os.makedirs("./experiments", exist_ok=True)
    plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)

    # ---- model ----
    model = _load_lora_model(base_model_id)
    if torch.cuda.device_count() > 1:  # If more than 1 GPU
        # device_map="auto" may spread layers over several GPUs; these flags
        # presumably make Trainer treat the model as model-parallel.
        model.is_parallelizable = True
        model.model_parallel = True

    # ---- training ----
    trainer = transformers.Trainer(
        model=model,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        args=transformers.TrainingArguments(
            output_dir=output_dir,
            warmup_steps=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            gradient_checkpointing=True,
            max_steps=100,
            learning_rate=2.5e-5,  # Small LR for fine-tuning
            # Mixed precision matches bnb_4bit_compute_dtype=torch.bfloat16
            # (the original fp16=True contradicted it).
            bf16=True,
            optim="paged_adamw_8bit",
            logging_steps=25,  # Log the training loss every 25 steps
            logging_dir="./experiments/logs",  # Directory for storing logs
            save_strategy="steps",
            save_steps=100,  # Checkpoint every 100 steps (once, at max_steps)
            # NOTE(review): newer transformers releases rename this to
            # `eval_strategy`; confirm against the installed version.
            evaluation_strategy="steps",
            eval_steps=25,  # Evaluate every 25 steps
            do_eval=True,
            report_to="wandb",  # Remove if you don't want to use Weights & Biases
            run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",  # W&B run name
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()
# Script entry point: run the fine-tuning job only when executed directly,
# not when this module is imported for its helpers.
if __name__ == "__main__":
    train()