import re |
|
import os |
|
import copy |
|
import json |
|
import random |
|
import pathlib |
|
import traceback |
|
from dataclasses import dataclass, field |
|
from typing import Dict, Optional, Sequence, List |
|
|
|
|
|
|
|
import torch |
|
from torch.utils.data import Dataset |
|
|
|
import transformers |
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock |
|
|
|
import sys |
|
sys.path.append('./') |
|
from videollama2.model import * |
|
from videollama2.constants import NUM_FRAMES, IGNORE_INDEX, MODAL_INDEX_MAP |
|
from videollama2.mm_utils import tokenizer_multimodal_token, process_video, process_image |
|
from videollama2.videollama2_trainer import (VideoLLaMA2Trainer, |
|
get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, |
|
find_all_linear_names, safe_save_model_for_hf_trainer |
|
) |
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "true" |
|
|
|
local_rank = None |
|
|
|
|
|
def rank0_print(*args): |
|
if local_rank == 0: |
|
print(*args) |
|
|
|
|
|
def set_seed(seed=42): |
|
""" |
|
Set the random seed for reproducible results. |
|
|
|
:param seed: An integer value to be used as the random seed. |
|
""" |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = False |
|
|
|
|
|
@dataclass |
|
class ModelArguments: |
|
|
|
model_type: Optional[str] = field(default="videollama2", metadata={"help": "Model type selected in the list: " + ", ".join(VLLMs.keys())}) |
|
model_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.5") |
|
version: Optional[str] = field(default="v1", metadata={"help": "Version of the conversation template."}) |
|
freeze_backbone: bool = field(default=False, metadata={"help": "Whether to freeze the LLM backbone."}) |
|
|
|
mm_projector_type: Optional[str] = field(default='linear') |
|
tune_mm_mlp_adapter: bool = field(default=False) |
|
pretrain_mm_mlp_adapter: Optional[str] = field(default=None) |
|
|
|
vision_tower: Optional[str] = field(default=None) |
|
mm_vision_select_layer: Optional[int] = field(default=-1) |
|
mm_vision_select_feature: Optional[str] = field(default="patch") |
|
|
|
|
|
@dataclass |
|
class DataArguments: |
|
|
|
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
|
|
|
|
|
data_folder: Optional[str] = field(default=None) |
|
|
|
is_multimodal: bool = False |
|
lazy_preprocess: bool = False |
|
num_frames: Optional[int] = field(default=None) |
|
|
|
image_aspect_ratio: str = 'square' |
|
|
|
|
|
@dataclass |
|
class TrainingArguments(transformers.TrainingArguments): |
|
optim: str = field(default="adamw_torch") |
|
mm_projector_lr: Optional[float] = None |
|
freeze_mm_mlp_adapter: bool = field(default=False) |
|
remove_unused_columns: bool = field(default=False) |
|
cache_dir: Optional[str] = field(default=None) |
|
|
|
group_by_modality_length: bool = field(default=False) |
|
model_max_length: int = field( |
|
default=512, |
|
metadata={ |
|
"help": |
|
"Maximum sequence length. Sequences will be right padded (and possibly truncated)." |
|
}, |
|
) |
|
|
|
double_quant: bool = field( |
|
default=True, |
|
metadata={"help": "Compress the quantization statistics through double quantization."} |
|
) |
|
quant_type: str = field( |
|
default="nf4", |
|
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} |
|
) |
|
bits: int = field( |
|
default=16, |
|
metadata={"help": "How many bits to use."} |
|
) |
|
lora_enable: bool = False |
|
lora_r: int = 64 |
|
lora_alpha: int = 16 |
|
lora_dropout: float = 0.05 |
|
lora_weight_path: str = "" |
|
lora_bias: str = "none" |
|
|
|
|
|
def preprocess_plain( |
|
sources: Sequence[str], |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
modal_token: str = None, |
|
) -> Dict: |
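    """Tokenize pretraining-style samples: the user turn is reduced to the modal
    token and only the assistant response is supervised (the instruction part of
    the labels is masked with IGNORE_INDEX)."""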
|
roles = {"human": "user", "gpt": "assistant"} |
|
conversations = [] |
|
input_ids = [] |
|
targets = [] |
|
for source in sources: |
|
|
|
assert len(source) == 2 |
|
assert modal_token in source[0]['value'] |
|
message = [ |
|
{'role': 'user', 'content': modal_token}, |
|
{'role': 'assistant', 'content': source[1]['value']} |
|
] |
|
conversation = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False) |
|
|
|
input_ids.append(tokenizer_multimodal_token(conversation, tokenizer, modal_token, return_tensors='pt')) |
|
|
|
targets.append(copy.deepcopy(input_ids[-1])) |
|
instruction = tokenizer.apply_chat_template(message[:1], tokenize=False, add_generation_prompt=True) |
|
instruction_len = len(tokenizer_multimodal_token(instruction, tokenizer, modal_token, return_tensors='pt')) |
|
        targets[-1][:instruction_len] = IGNORE_INDEX
|
|
return dict(input_ids=input_ids, labels=targets) |
|
|
|
|
|
def preprocess( |
|
sources: Sequence[str], |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
modal_token: str = None, |
|
) -> Dict: |
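    """Tokenize multi-turn conversations with the tokenizer's chat template and
    mask every instruction segment with IGNORE_INDEX so that only the assistant
    replies contribute to the loss."""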
|
roles = {"human": "user", "gpt": "assistant"} |
|
|
|
|
|
conversations = [] |
|
input_ids = [] |
|
targets = [] |
|
for i, source in enumerate(sources): |
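        # Skip the leading message if it does not come from the human/user side.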
|
if roles[source[0]["from"]] != "user": |
|
|
|
source = source[1:] |
|
|
|
message = [{'role': roles[sentence['from']], 'content': sentence['value']} for sentence in source] |
|
conversation = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False) |
|
input_ids.append(tokenizer_multimodal_token(conversation, tokenizer, modal_token, return_tensors='pt')) |
|
targets.append(copy.deepcopy(input_ids[-1])) |
|
|
|
assert len(source) % 2 == 0, f"Invalid conversation length {len(source)}." |
|
|
|
cur = 0 |
|
message = [] |
|
for idx, sentence in enumerate(source): |
|
if idx % 2 == 1: |
|
tmp_message = [ |
|
{'role': roles[source[idx-1]['from']], 'content': source[idx-1]['value']}, |
|
{'role': roles[sentence['from']], 'content': sentence['value']} |
|
] |
|
|
|
instruction = tokenizer.apply_chat_template(message + tmp_message[:1], tokenize=False, add_generation_prompt=True) |
|
conversation = tokenizer.apply_chat_template(message + tmp_message, tokenize=False, add_generation_prompt=False) |
|
|
|
instruction_len = len(tokenizer_multimodal_token(instruction, tokenizer, modal_token, return_tensors='pt')) |
|
conversation_len = len(tokenizer_multimodal_token(conversation, tokenizer, modal_token, return_tensors='pt')) |
|
|
|
targets[-1][cur:instruction_len] = IGNORE_INDEX |
|
|
|
cur = conversation_len |
|
message += tmp_message |
|
|
|
return dict(input_ids=input_ids, labels=targets) |
|
|
|
|
|
def preprocess_multimodal( |
|
sources: Sequence[str], |
|
data_args: DataArguments, |
|
modal_token: str = None, |
|
) -> Dict: |
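    """Normalize modal-token placement: strip the modal token (e.g. <image>,
    <video>) from wherever it appears and prepend it to the sentence."""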
|
is_multimodal = data_args.is_multimodal |
|
if not is_multimodal: |
|
return sources |
|
|
|
assert modal_token in MODAL_INDEX_MAP, f"Unsupported modal token {modal_token}." |
|
|
|
for source in sources: |
|
for sentence in source: |
|
if modal_token in sentence['value']: |
|
sentence['value'] = sentence['value'].replace(modal_token, '').strip() |
|
sentence['value'] = modal_token + '\n' + sentence['value'] |
|
sentence['value'] = sentence['value'].strip() |
|
replace_token = modal_token |
|
|
|
sentence["value"] = sentence["value"].replace(modal_token, replace_token) |
|
|
|
return sources |
|
|
|
|
|
class LazySupervisedDataset(Dataset): |
|
"""Dataset for supervised fine-tuning.""" |
|
|
|
def __init__(self, data_path: str, |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
data_args: DataArguments): |
|
super(LazySupervisedDataset, self).__init__() |
|
        with open(data_path, "r") as f:
            list_data_dict = json.load(f)
|
|
|
rank0_print("Formatting inputs...Skip in lazy mode") |
|
self.tokenizer = tokenizer |
|
self.list_data_dict = list_data_dict |
|
self.data_args = data_args |
|
|
|
def __len__(self): |
|
return len(self.list_data_dict) |
|
|
|
@property |
|
def lengths(self): |
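        """Approximate sample lengths (whitespace word count plus a fixed image-token budget), used for length grouping."""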
|
length_list = [] |
|
for sample in self.list_data_dict: |
|
img_tokens = 576 if 'image' in sample else 0 |
|
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) |
|
return length_list |
|
|
|
@property |
|
def modality_lengths(self): |
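        """Signed sample lengths for the modality-aware sampler: positive for samples with an image, negative otherwise."""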
|
length_list = [] |
|
for sample in self.list_data_dict: |
|
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) |
|
cur_len = cur_len if 'image' in sample else -cur_len |
|
length_list.append(cur_len) |
|
return length_list |
|
|
|
def __getitem__(self, i) -> Dict[str, torch.Tensor]: |
|
sources = self.list_data_dict[i] |
|
if isinstance(i, int): |
|
sources = [sources] |
|
        assert len(sources) == 1, "Expected exactly one sample after wrapping the index into a list."
|
|
|
image_processor = self.data_args.image_processor |
|
video_processor = self.data_args.video_processor |
|
|
|
num_frames = NUM_FRAMES if self.data_args.num_frames is None else self.data_args.num_frames |
|
|
|
if 'image' in sources[0]: |
|
image_file = self.list_data_dict[i]['image'] |
|
image_folder = self.data_args.data_folder |
|
image_file = os.path.join(image_folder, image_file) |
|
|
|
try: |
|
image = process_image(image_file, image_processor, aspect_ratio=self.data_args.image_aspect_ratio) |
|
            except Exception:
|
traceback.print_exc() |
|
backup_idx = random.randint(0, len(self.list_data_dict) - 1) |
|
                print(f"Encountered an error while reading image {image_file}; using the {backup_idx}-th example instead.")
|
return self.__getitem__(backup_idx) |
|
|
|
|
|
modal_token = "<image>" |
|
sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args, modal_token) |
|
elif 'video' in sources[0]: |
|
video_file = self.list_data_dict[i]['video'] |
|
video_folder = self.data_args.data_folder |
|
video_file = os.path.join(video_folder, video_file) |
|
|
|
try: |
|
video = process_video(video_file, video_processor, aspect_ratio=self.data_args.image_aspect_ratio, num_frames=num_frames) |
|
            except Exception:
|
traceback.print_exc() |
|
backup_idx = random.randint(0, len(self.list_data_dict) - 1) |
|
                print(f"Encountered an error while reading video {video_file}; using the {backup_idx}-th example instead.")
|
return self.__getitem__(backup_idx) |
|
|
|
|
|
modal_token = "<video>" |
|
sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args, modal_token) |
|
else: |
|
modal_token = None |
|
sources = copy.deepcopy([e["conversations"] for e in sources]) |
|
|
|
if self.data_args.is_pretraining: |
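            # Projector pretraining uses the plain single-turn template (prompt reduced to the modal token).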
|
data_dict = preprocess_plain(sources, self.tokenizer, modal_token=modal_token) |
|
else: |
|
data_dict = preprocess(sources, self.tokenizer, modal_token=modal_token) |
|
|
|
if isinstance(i, int): |
|
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) |
|
|
|
|
|
if 'image' in self.list_data_dict[i]: |
|
data_dict['image'] = image |
|
elif 'video' in self.list_data_dict[i]: |
|
data_dict['video'] = video |
|
elif self.data_args.is_multimodal: |
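            # No visual input in this sample, but the model is multimodal: provide a dummy all-zero image so batch shapes stay consistent.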
|
|
|
data_dict['image'] = torch.zeros(3, self.data_args.image_size, self.data_args.image_size) |
|
return data_dict |
|
|
|
|
|
@dataclass |
|
class DataCollatorForSupervisedDataset(object): |
|
"""Collate examples for supervised fine-tuning.""" |
|
|
|
tokenizer: transformers.PreTrainedTokenizer |
|
|
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: |
|
input_ids, labels = tuple([instance[key] for instance in instances] |
|
for key in ("input_ids", "labels")) |
|
input_ids = torch.nn.utils.rnn.pad_sequence( |
|
input_ids, |
|
batch_first=True, |
|
padding_value=self.tokenizer.pad_token_id) |
|
labels = torch.nn.utils.rnn.pad_sequence(labels, |
|
batch_first=True, |
|
padding_value=IGNORE_INDEX) |
|
input_ids = input_ids[:, :self.tokenizer.model_max_length] |
|
labels = labels[:, :self.tokenizer.model_max_length] |
|
batch = dict( |
|
input_ids=input_ids, |
|
labels=labels, |
|
attention_mask=input_ids.ne(self.tokenizer.pad_token_id), |
|
) |
|
|
|
|
|
batch['images'] = [] |
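        # Collect the visual inputs of each instance as (tensor, modality) pairs,
        # keyed by the modal tokens defined in MODAL_INDEX_MAP.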
|
for instance in instances: |
|
for modal_token in MODAL_INDEX_MAP.keys(): |
|
modal_token = modal_token.lower() |
|
|
|
                modal_name = re.findall(r'<(.*)>', modal_token)
|
assert len(modal_name) == 1 |
|
modal_name = modal_name[0] |
|
if modal_name in instance: |
|
batch['images'].append((instance[modal_name], modal_name)) |
|
|
|
return batch |
|
|
|
|
|
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, |
|
data_args) -> Dict: |
|
"""Make dataset and collator for supervised fine-tuning.""" |
|
train_dataset = LazySupervisedDataset( |
|
tokenizer=tokenizer, |
|
data_path=data_args.data_path, |
|
data_args=data_args |
|
) |
|
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) |
|
return dict(train_dataset=train_dataset, |
|
eval_dataset=None, |
|
data_collator=data_collator) |
|
|
|
|
|
def train(attn_implementation=None): |
|
global local_rank |
|
set_seed(42) |
|
|
|
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) |
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
|
|
|
local_rank = training_args.local_rank |
|
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) |
|
|
|
bnb_model_from_pretrained_args = {} |
|
if training_args.bits in [4, 8]: |
|
from transformers import BitsAndBytesConfig |
|
        bnb_model_from_pretrained_args.update(dict(
|
quantization_config=BitsAndBytesConfig( |
|
load_in_4bit=training_args.bits == 4, |
|
load_in_8bit=training_args.bits == 8, |
|
llm_int8_skip_modules=["mm_projector"], |
|
llm_int8_threshold=6.0, |
|
llm_int8_has_fp16_weight=False, |
|
bnb_4bit_compute_dtype=compute_dtype, |
|
bnb_4bit_use_double_quant=training_args.double_quant, |
|
bnb_4bit_quant_type=training_args.quant_type, |
|
bnb_4bit_quant_storage=compute_dtype, |
|
) |
|
)) |
|
|
|
config = VLLMConfigs[model_args.model_type].from_pretrained(model_args.model_path, trust_remote_code=True) |
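    # gemma2 models are pinned to eager attention; all other model types use the requested attention implementation.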
|
if 'gemma2' in model_args.model_type: |
|
config._attn_implementation = 'eager' |
|
else: |
|
config._attn_implementation = attn_implementation |
|
|
|
if model_args.vision_tower is not None: |
|
model = VLLMs[model_args.model_type].from_pretrained( |
|
model_args.model_path, |
|
config=config, |
|
cache_dir=training_args.cache_dir, |
|
torch_dtype=(torch.bfloat16 if training_args.bf16 else None), |
|
do_sample=True, |
|
**bnb_model_from_pretrained_args |
|
) |
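        # Under DeepSpeed ZeRO-3, treat Mixtral's MoE blocks as leaf modules so their
        # expert weights are gathered as a whole (not every expert runs each forward pass).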
|
if 'mixtral' in model_args.model_type: |
|
import deepspeed |
|
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) |
|
else: |
|
model = transformers.LlamaForCausalLM.from_pretrained( |
|
model_args.model_path, |
|
config=config, |
|
cache_dir=training_args.cache_dir, |
|
torch_dtype=(torch.bfloat16 if training_args.bf16 else None), |
|
do_sample=True, |
|
**bnb_model_from_pretrained_args |
|
) |
|
model.config.use_cache = False |
|
|
|
if model_args.freeze_backbone: |
|
model.model.requires_grad_(False) |
|
|
|
if training_args.bits in [4, 8]: |
|
from peft import prepare_model_for_kbit_training |
|
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) |
|
|
|
if training_args.gradient_checkpointing: |
|
if hasattr(model, "enable_input_require_grads"): |
|
model.enable_input_require_grads() |
|
else: |
|
def make_inputs_require_grad(module, input, output): |
|
output.requires_grad_(True) |
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
|
|
|
if training_args.lora_enable: |
|
from peft import LoraConfig, get_peft_model |
|
lora_config = LoraConfig( |
|
r=training_args.lora_r, |
|
lora_alpha=training_args.lora_alpha, |
|
target_modules=find_all_linear_names(model), |
|
lora_dropout=training_args.lora_dropout, |
|
bias=training_args.lora_bias, |
|
task_type="CAUSAL_LM", |
|
) |
|
if training_args.bits == 16: |
|
if training_args.bf16: |
|
model.to(torch.bfloat16) |
|
if training_args.fp16: |
|
model.to(torch.float16) |
|
rank0_print("Adding LoRA adapters...") |
|
model = get_peft_model(model, lora_config) |
|
|
|
|
|
tokenizer = transformers.AutoTokenizer.from_pretrained( |
|
model_args.model_path, |
|
cache_dir=training_args.cache_dir, |
|
model_max_length=training_args.model_max_length, |
|
padding_side="right", |
|
use_fast=True, |
|
) |
|
|
|
if tokenizer.pad_token is None: |
|
tokenizer.pad_token = tokenizer.unk_token |
|
|
|
if model_args.vision_tower is not None: |
|
|
|
model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp) |
|
|
|
vision_tower = model.get_vision_tower() |
|
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) |
|
|
|
data_args.image_size = vision_tower.image_size |
|
|
|
data_args.image_processor = vision_tower.image_processor |
|
data_args.video_processor = vision_tower.video_processor if hasattr(vision_tower, "video_processor") else vision_tower.image_processor |
|
|
|
data_args.is_multimodal = True |
|
|
|
model.config.image_aspect_ratio = data_args.image_aspect_ratio |
|
model.config.tokenizer_padding_side = tokenizer.padding_side |
|
model.config.tokenizer_model_max_length = tokenizer.model_max_length |
|
|
|
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter |
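        # When tuning only the multimodal projector, freeze the whole model and re-enable gradients just for the projector.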
|
if model_args.tune_mm_mlp_adapter: |
|
model.requires_grad_(False) |
|
for p in model.get_model().mm_projector.parameters(): |
|
p.requires_grad = True |
|
|
|
if model_args.tune_mm_mlp_adapter: |
|
data_args.is_pretraining = True |
|
else: |
|
data_args.is_pretraining = False |
|
|
|
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter |
|
if training_args.freeze_mm_mlp_adapter: |
|
for p in model.get_model().mm_projector.parameters(): |
|
p.requires_grad = False |
|
|
|
if training_args.bits in [4, 8]: |
|
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) |
|
|
|
model.config.mm_projector_lr = training_args.mm_projector_lr |
|
model.config.num_frames = NUM_FRAMES if data_args.num_frames is None else data_args.num_frames |
|
|
|
if training_args.bits in [4, 8]: |
|
from peft.tuners.lora import LoraLayer |
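        # Keep k-bit training numerically stable: LoRA layers in bf16, normalization layers in fp32, and fp32 lm_head/embeddings cast down to bf16.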
|
for name, module in model.named_modules(): |
|
if isinstance(module, LoraLayer): |
|
if training_args.bf16: |
|
module = module.to(torch.bfloat16) |
|
if 'norm' in name: |
|
module = module.to(torch.float32) |
|
if 'lm_head' in name or 'embed_tokens' in name: |
|
if hasattr(module, 'weight'): |
|
if training_args.bf16 and module.weight.dtype == torch.float32: |
|
module = module.to(torch.bfloat16) |
|
|
|
    rank0_print("Current model:", model)
|
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) |
|
|
|
trainer = VideoLLaMA2Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) |
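    # Resume from the latest checkpoint in output_dir if one exists; otherwise start training from scratch.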
|
|
|
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): |
|
trainer.train(resume_from_checkpoint=True) |
|
else: |
|
trainer.train() |
|
trainer.save_state() |
|
|
|
model.config.use_cache = True |
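    # With LoRA enabled, save the adapter weights and the remaining trainable (non-LoRA) parameters separately; otherwise save the full model via the trainer.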
|
|
|
if training_args.lora_enable: |
|
state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias) |
|
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters()) |
|
if training_args.local_rank == 0 or training_args.local_rank == -1: |
|
model.config.save_pretrained(training_args.output_dir) |
|
model.save_pretrained(training_args.output_dir, state_dict=state_dict) |
|
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) |
|
else: |
|
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) |
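
# Example launch (hypothetical paths and hyper-parameters; the flag names map to the
# dataclass fields above via HfArgumentParser):
#   torchrun --nproc_per_node=8 videollama2/train.py \
#       --model_type videollama2 \
#       --model_path mistralai/Mistral-7B-Instruct-v0.2 \
#       --vision_tower openai/clip-vit-large-patch14-336 \
#       --data_path data/train.json --data_folder data/ \
#       --bf16 True --output_dir checkpoints/videollama2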
|
|
|
|
|
if __name__ == "__main__": |
|
train() |
|
|