# VILA15_3b / llava_llama.py
import inspect
# from .builder import build_llm_and_tokenizer, build_mm_projector, build_vision_tower
import math
import os
import os.path as osp
import re
import shutil
import warnings
from typing import List, Optional, Tuple, Union
# from .llava_llama import LlavaLlamaModel
# from llava.model import *
# from llava.model.utils import is_mm_model
import torch
import torch.nn as nn
from huggingface_hub import file_exists, repo_exists, snapshot_download
from huggingface_hub.utils import HFValidationError, validate_repo_id
# from llava.model.multimodal_encoder.vision_encoder import (VisionTower,
# VisionTowerS2)
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoTokenizer, BitsAndBytesConfig, GenerationConfig,
LlamaConfig, LlamaForCausalLM, PretrainedConfig,
PreTrainedModel, SiglipImageProcessor,
SiglipVisionModel)
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_llava import LlavaConfig # , LlavaLlamaConfig
# from .llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
from .utils import get_model_config
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
def is_deepspeed_zero3_enabled():
    # Stub kept for this standalone HF export: ZeRO-3 parameter gathering is never
    # enabled here, so the deepspeed-specific branches below are skipped.
    return None
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": "identity"}
class SimpleResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
class DownSampleBlock(nn.Module):
def forward(self, x):
vit_embeds = x
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.flat_square(vit_embeds)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
return vit_embeds
def flat_square(self, x):
n, w, h, c = x.size()
if w % 2 == 1:
x = torch.concat(
[x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1
).contiguous()
n, w, h, c = x.size()
if h % 2 == 1:
x = torch.concat(
[x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2
).contiguous()
n, w, h, c = x.size()
x = x.view(n, w, int(h / 2), int(c * 2))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
return x
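# A minimal sketch (shapes are illustrative, not taken from this file) of what
# DownSampleBlock does: each 2x2 neighborhood of visual tokens is folded into the
# channel dimension, quartering the token count and quadrupling the hidden size,
# with odd grid sides zero-padded first. For a 27x27 = 729-token feature map with
# hidden size 1152 (e.g. a SigLIP encoder at 384px / patch 14), this would give:
#
#   feats = torch.randn(1, 729, 1152)            # (batch, tokens, channels)
#   out = DownSampleBlock()(feats)               # pad 27 -> 28, fold 2x2 blocks
#   assert out.shape == (1, 14 * 14, 1152 * 4)   # (1, 196, 4608)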
class MultimodalProjectorConfig(PretrainedConfig):
model_type = "v2l_projector"
def __init__(self, mm_projector_type: str = None, **kwargs):
super().__init__()
self.mm_projector_type = mm_projector_type
class MultimodalProjector(PreTrainedModel):
config_class = MultimodalProjectorConfig
def __init__(
self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig
):
super().__init__(mm_projector_cfg)
mm_projector_type = mm_projector_cfg.mm_projector_type
if mm_projector_type == "identity":
self.layers = IdentityMap()
elif mm_projector_type == "linear":
self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
elif mm_projector_type == "mlp_downsample":
self.layers = nn.Sequential(
DownSampleBlock(),
nn.LayerNorm(config.mm_hidden_size * 4),
nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
nn.GELU(),
nn.Linear(config.hidden_size, config.hidden_size),
)
else:
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
self.layers = nn.Sequential(*modules)
else:
raise ValueError(f"Unknown projector type: {mm_projector_type}")
def forward(self, x, *args, **kwargs):
return self.layers(x)
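# Hedged reference for how the mm_projector_type strings above map to layers; the
# concrete sizes (mm_hidden_size=1152, hidden_size=2560) are assumptions for
# illustration only:
#
#   "identity"       -> IdentityMap()
#   "linear"         -> nn.Linear(1152, 2560)
#   "mlp2x_gelu"     -> nn.Linear(1152, 2560), nn.GELU(), nn.Linear(2560, 2560)
#   "mlp_downsample" -> DownSampleBlock(), nn.LayerNorm(1152 * 4),
#                       nn.Linear(1152 * 4, 2560), nn.GELU(), nn.Linear(2560, 2560)
#
#   cfg = MultimodalProjectorConfig("mlp_downsample")
#   projector = MultimodalProjector(cfg, llm_config)  # llm_config supplies both sizes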
def build_mm_projector(
model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
if model_type_or_path is None:
return None
## load from pretrained model
if config.resume_path:
assert os.path.exists(
model_type_or_path
), f"Resume mm projector path {model_type_or_path} does not exist!"
return MultimodalProjector.from_pretrained(
model_type_or_path, config, torch_dtype=eval(config.model_dtype)
)
## build from scratch
else:
mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
mm_projector = MultimodalProjector(mm_projector_cfg, config).to(
eval(config.model_dtype)
)
return mm_projector
class VisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = getattr(args, "mm_vision_select_layer", -2)
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
self.cfg_only = None
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == "patch":
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
def _maybe_resize_pos_embeds(
self,
model: PreTrainedModel,
image_processor,
resolution: int = -1,
interpolate_mode: str = "linear",
):
if resolution in [model.config.image_size, -1]:
return
print(
f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
)
embeddings = model.vision_model.embeddings
patch_size = embeddings.patch_size
num_new_tokens = int((resolution // patch_size) ** 2)
old_embeddings = embeddings.position_embedding
match interpolate_mode:
case "linear":
                ## Step 1: map each target patch id (N patches total) to a fractional source patch id (M patches total): pid_src = pid_tgt / (N - 1) * (M - 1)
                ## Step 2: linearly interpolate between the two nearest source embeddings: new_embeds = (1 - frac) * embeds[floor(pid_src)] + frac * embeds[ceil(pid_src)], where frac = pid_src - floor(pid_src)
import torch
import torch.nn as nn
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
new_embeddings = nn.Embedding(
num_new_tokens,
old_embedding_dim,
dtype=old_embeddings.weight.dtype,
device=old_embeddings.weight.device,
)
mapped_indices = (
torch.arange(num_new_tokens).to(old_embeddings.weight.device)
/ (num_new_tokens - 1)
* (old_num_tokens - 1)
)
floor_indices = torch.clamp(
mapped_indices.floor().long(), min=0, max=old_num_tokens - 1
)
ceil_indices = torch.clamp(
mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1
)
                frac = (mapped_indices - floor_indices)[:, None]
                # frac == 0 keeps the floor embedding exactly, so integer-mapped
                # positions (e.g. the first and the last) are copied rather than
                # zeroed out.
                if is_deepspeed_zero3_enabled():
                    params = [old_embeddings.weight, new_embeddings.weight]
                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                        interpolated_embeds = (1 - frac) * old_embeddings.weight.data[
                            floor_indices, :
                        ] + frac * old_embeddings.weight.data[ceil_indices, :]
                else:
                    interpolated_embeds = (1 - frac) * old_embeddings.weight.data[
                        floor_indices, :
                    ] + frac * old_embeddings.weight.data[ceil_indices, :]
                new_embeddings.weight.data = interpolated_embeds
case _:
raise NotImplementedError
if hasattr(old_embeddings, "_hf_hook"):
hook = old_embeddings._hf_hook
            # disabled for inference
# add_hook_to_module(new_embeddings, hook)
new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
## update vision encoder's configurations
model.config.image_size = resolution
if hasattr(image_processor, "crop_size"):
# CLIP vision tower
image_processor.crop_size = resolution
else:
# SIGLIP vision tower
assert hasattr(image_processor, "size")
image_processor.size = {"height": resolution, "width": resolution}
## TODO define a '_reinitialize' method for VisionTower
embeddings.position_embedding = new_embeddings
embeddings.image_size = resolution
embeddings.num_patches = embeddings.num_positions = num_new_tokens
embeddings.position_ids = (
torch.arange(embeddings.num_positions)
.expand((1, -1))
.to(old_embeddings.weight.device)
)
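    # A small worked example (hypothetical sizes, not from this file) of the linear
    # position-embedding interpolation above: growing from M=4 to N=7 positions maps
    # each new index i to i / (N - 1) * (M - 1) = [0.0, 0.5, 1.0, ..., 3.0], and each
    # new embedding is a weighted average of its floor/ceil neighbours:
    #
    #   old = torch.arange(4, dtype=torch.float32)[:, None]  # stand-in embeddings 0..3
    #   idx = torch.arange(7) / 6 * 3
    #   lo, hi = idx.floor().long(), idx.ceil().long()
    #   frac = (idx - lo)[:, None]
    #   new = (1 - frac) * old[lo] + frac * old[hi]
    #   # new.squeeze() -> tensor([0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0])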
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype),
output_hidden_states=True,
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
class SiglipVisionTower(VisionTower):
def __init__(
self, model_name_or_path: str, config: PretrainedConfig, state_dict=None
):
super().__init__(model_name_or_path, config)
self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
self.vision_tower = SiglipVisionModel.from_pretrained(
            # TODO(ligeng): why does passing config here lead to errors?
model_name_or_path,
torch_dtype=eval(config.model_dtype),
state_dict=state_dict,
)
self.is_loaded = True
def build_vision_tower(
model_name_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
## skip vision tower instantiation
if model_name_or_path is None:
return None
vision_tower_arch = None
if config.resume_path and "radio" not in model_name_or_path:
assert os.path.exists(
model_name_or_path
), f"Resume vision tower path {model_name_or_path} does not exist!"
vision_tower_cfg = AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=True
)
vision_tower_arch = vision_tower_cfg.architectures[0].lower()
vision_tower_name = (
vision_tower_arch if vision_tower_arch is not None else model_name_or_path
)
use_s2 = getattr(config, "s2", False)
if "siglip" in vision_tower_name:
if use_s2:
vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
else:
vision_tower = SiglipVisionTower(model_name_or_path, config)
else:
raise ValueError(f"Unknown vision tower: {model_name_or_path}")
config.mm_hidden_size = (
vision_tower.config.hidden_size if not use_s2 else vision_tower.hidden_size
)
return vision_tower
def has_tokenizer(repo_id_or_path: str) -> bool:
# Check if the tokenizer is in a local directory
if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
return True
# Check if the tokenizer is in a Hugging Face Hub repo
try:
return repo_exists(repo_id_or_path) and file_exists(
repo_id_or_path, "tokenizer_config.json"
)
except HFValidationError:
return False
def context_length_extension(config):
orig_ctx_len = getattr(config, "max_position_embeddings", None)
model_max_length = getattr(config, "model_max_length", None)
if orig_ctx_len and model_max_length > orig_ctx_len:
print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
return config
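# Hedged illustration of the RoPE scaling above (the numbers are assumptions): if
# the base LLM was trained with max_position_embeddings=4096 and model_max_length
# is set to 8192, the config gains
#   config.rope_scaling = {"type": "linear", "factor": 2.0}
# i.e. positions are linearly compressed by 2x to cover the longer context.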
def build_llm_and_tokenizer(
model_name_or_path: str,
config: PretrainedConfig,
attn_implementation=None,
model_max_length=None,
*args,
**kwargs,
):
llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
llm_cfg._attn_implementation = attn_implementation
llm_cfg.model_max_length = model_max_length
if model_max_length is not None:
context_length_extension(llm_cfg)
llm = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
config=llm_cfg,
torch_dtype=eval(config.model_dtype),
*args,
**kwargs,
)
# Locate the tokenizer.
llm_path = model_name_or_path
if not has_tokenizer(llm_path):
llm_path = osp.join(llm_path, "llm")
if not has_tokenizer(llm_path):
raise ValueError(f"Cannot find tokenizer in {llm_path}.")
    # TODO(ligeng): use the LLM class to decide this, for better compatibility.
    try:
        llm_arch = getattr(llm_cfg, "architectures")[0].lower()
    except BaseException:
        warnings.warn(
            f'Cannot find LLM architecture, please check the "config.json" under "{llm_path}".'
        )
        llm_arch = ""
if "mpt" in llm_arch:
tokenizer = AutoTokenizer.from_pretrained(
llm_path,
model_max_length=llm_cfg.model_max_length,
padding_side="right",
)
elif "yi" in llm_path or (
getattr(llm_cfg, "num_hidden_layers", -1) == 60
and getattr(llm_cfg, "num_attention_heads", -1) == 56
):
tokenizer = AutoTokenizer.from_pretrained(
llm_path,
model_max_length=llm_cfg.model_max_length,
padding_side="right",
use_fast=False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
llm_path,
model_max_length=llm_cfg.model_max_length,
padding_side="right",
use_fast=False,
legacy=False,
)
# TODO(ligeng): is this necessary for llava?
config.hidden_size = llm.config.hidden_size
return llm, tokenizer
def is_mm_model(model_path):
"""
Check if the model at the given path is a visual language model.
Args:
model_path (str): The path to the model.
Returns:
bool: True if the model is an MM model, False otherwise.
"""
config = AutoConfig.from_pretrained(model_path)
architectures = config.architectures
for architecture in architectures:
if "llava" in architecture.lower():
return True
return False
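# Usage sketch (repo ids are illustrative, not pinned by this file): a checkpoint
# whose config lists a "Llava*" architecture is treated as multimodal, a plain LLM
# is not.
#
#   is_mm_model("Efficient-Large-Model/VILA1.5-3b")  # True  ("LlavaLlamaModel")
#   is_mm_model("meta-llama/Llama-2-7b-hf")          # False ("LlamaForCausalLM")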
def load_pretrained_model(
model_path,
model_name,
model_base=None,
load_8bit=False,
load_4bit=False,
device_map="auto",
device="cuda",
**kwargs,
):
kwargs = {"device_map": device_map, **kwargs}
if device != "cuda":
kwargs["device_map"] = {"": device}
if load_8bit:
kwargs["load_in_8bit"] = True
elif load_4bit:
kwargs["load_in_4bit"] = True
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
else:
kwargs["torch_dtype"] = torch.float16
# kwargs["torch_dtype"] = torch.bfloat16
if is_mm_model(model_path):
# Load LLaVA model
## TODO @yunhao: mind fixing lora
if "lora" in model_name.lower() and model_base is None:
warnings.warn(
"There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
)
if (
"lora" in model_name.lower() or "dora" in model_name.lower()
) and model_base is not None:
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
print(lora_cfg_pretrained)
print("Loading LLaVA from base model...")
config = AutoConfig.from_pretrained(model_base)
prepare_config_for_eval(config, kwargs)
model = LlavaLlamaModel.from_pretrained(
model_base, low_cpu_mem_usage=True, config=config, **kwargs
)
tokenizer = model.tokenizer
            token_num, token_dim = (
                model.llm.lm_head.out_features,
                model.llm.lm_head.in_features,
            )
            if model.llm.lm_head.weight.shape[0] != token_num:
                model.llm.lm_head.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, token_dim, device=model.device, dtype=model.dtype
                    )
                )
                model.llm.embed_tokens.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, token_dim, device=model.device, dtype=model.dtype
                    )
                )
print("Loading additional LLaVA weights...")
if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
non_lora_trainables = torch.load(
os.path.join(model_path, "non_lora_trainables.bin"),
map_location="cpu",
)
else:
# this is probably from HF Hub
from huggingface_hub import hf_hub_download
def load_from_hf(repo_id, filename, subfolder=None):
cache_file = hf_hub_download(
repo_id=repo_id, filename=filename, subfolder=subfolder
)
return torch.load(cache_file, map_location="cpu")
non_lora_trainables = load_from_hf(
model_path, "non_lora_trainables.bin"
)
non_lora_trainables = {
(k[11:] if k.startswith("base_model.") else k): v
for k, v in non_lora_trainables.items()
}
if any(k.startswith("model.model.") for k in non_lora_trainables):
non_lora_trainables = {
(k[6:] if k.startswith("model.") else k): v
for k, v in non_lora_trainables.items()
}
model.load_state_dict(non_lora_trainables, strict=False)
from peft import PeftModel
print("Loading LoRA weights...")
model = PeftModel.from_pretrained(model, model_path)
print("Merging LoRA weights...")
model = model.merge_and_unload()
print("Model is loaded...")
## TODO @yunhao: mind fixing this
elif model_base is not None:
# this may be mm projector only
print("Loading LLaVA from base model...")
cfg_pretrained = AutoConfig.from_pretrained(
model_path, trust_remote_code=True
)
mm_config_wrapper(config, kwargs)
if "mpt" in model_name.lower():
if not os.path.isfile(os.path.join(model_path, "configuration_mpt.py")):
shutil.copyfile(
os.path.join(model_base, "configuration_mpt.py"),
os.path.join(model_path, "configuration_mpt.py"),
)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
model = LlavaMPTForCausalLM.from_pretrained(
model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
)
else:
tokenizer = AutoTokenizer.from_pretrained(
model_base, use_fast=False, legacy=False
)
model = LlavaLlamaForCausalLM.from_pretrained(
model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
)
else:
config = AutoConfig.from_pretrained(model_path)
config.resume_path = model_path
prepare_config_for_eval(config, kwargs)
if "mpt" in model_name.lower():
model = LlavaMPTForCausalLM.from_pretrained(
model_path, config=config, low_cpu_mem_usage=True, **kwargs
)
elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
model = LlavaMistralForCausalLM.from_pretrained(
model_path, config=config, low_cpu_mem_usage=True, **kwargs
)
elif "gemma" in model_name.lower():
model = LlavaGemmaForCausalLM.from_pretrained(
model_path, config=config, low_cpu_mem_usage=True, **kwargs
)
else:
# kentang-mit@: llama-2 model
# config._attn_implementation = "flash_attention_2"
model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs)
tokenizer = model.tokenizer
else:
# Load language model
if model_base is not None:
# PEFT model
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
model_base, low_cpu_mem_usage=True, **kwargs
)
print(f"Loading LoRA weights from {model_path}")
model = PeftModel.from_pretrained(model, model_path)
print(f"Merging weights")
model = model.merge_and_unload()
print("Convert to FP16...")
model.to(torch.float16)
else:
if "mpt" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs
)
else:
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, legacy=False
)
model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
model.eval()
image_processor = None
if is_mm_model(model_path):
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
)
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
vision_tower.to(device=device, dtype=torch.float16)
# vision_tower.to(device=device, dtype=torch.bfloat16)
mm_projector = model.get_mm_projector()
mm_projector.to(device=device, dtype=torch.float16)
# mm_projector.to(device=device, dtype=torch.bfloat16)
image_processor = vision_tower.image_processor
    if hasattr(model.llm.config, "max_sequence_length"):
        context_len = model.llm.config.max_sequence_length
else:
context_len = 2048
return tokenizer, model, image_processor, context_len
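# A minimal usage sketch (the path and flags are assumptions, not part of this
# file): load a full multimodal checkpoint for inference and collect the pieces
# an evaluation script needs.
#
#   tokenizer, model, image_processor, context_len = load_pretrained_model(
#       model_path="/path/to/VILA1.5-3b",  # local snapshot of the checkpoint
#       model_name="VILA1.5-3b",
#       load_4bit=True,                    # optional NF4 quantization via bitsandbytes
#   )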
def parse_model_name_or_path(config: PretrainedConfig, model_name="llm", suffix="_cfg"):
target_model = f"{model_name}{suffix}"
target_cfg = getattr(config, target_model, None)
if isinstance(target_cfg, str):
return target_cfg
elif isinstance(target_cfg, dict):
return target_cfg["architectures"][0]
else:
raise ValueError(f"Invalid {target_model} configuration!")
def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
try:
# compatible with deprecated config convention
if getattr(config, "vision_tower_cfg", None) is None:
config.vision_tower_cfg = config.mm_vision_tower
except AttributeError:
raise ValueError(
f"Invalid configuration! Cannot find vision_tower in config:\n{config}"
)
config.model_dtype = kwargs.pop("torch_dtype").__str__()
# siglip does not support device_map = "auto"
vision_tower_name = parse_model_name_or_path(config, "vision_tower")
if "siglip" in vision_tower_name.lower():
kwargs["device_map"] = "cuda"
class LlavaLlamaConfig(LlavaConfig):
model_type = "llava_llama"
# class LlavaLlamaModel(PreTrainedModel):
# config_class = LlavaLlamaConfig
# main_input_name = "input_embeds"
# supports_gradient_checkpointing = True
# @classmethod
# def from_pretrained(
# cls,
# pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
# *model_args,
# config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
# cache_dir: Optional[Union[str, os.PathLike]] = None,
# ignore_mismatched_sizes: bool = False,
# force_download: bool = False,
# local_files_only: bool = False,
# token: Optional[Union[str, bool]] = None,
# revision: str = "main",
# use_safetensors: bool = None,
# **kwargs,
# ):
# if hasattr(cls, "load_pretrained"):
# return cls.load_pretrained(
# pretrained_model_name_or_path,
# *model_args,
# config=config,
# cache_dir=cache_dir,
# ignore_mismatched_sizes=ignore_mismatched_sizes,
# force_download=force_download,
# local_files_only=local_files_only,
# token=token,
# revision=revision,
# use_safetensors=use_safetensors,
# **kwargs,
# )
# return None
from abc import ABC, abstractmethod
from collections import OrderedDict
class LlavaMetaModel(ABC):
def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs):
        # TODO(ligeng): figure out how from_config and from_pretrained work in the HF implementation.
if (
hasattr(self, "llm")
or hasattr(self, "vision_tower")
or hasattr(self, "mm_projector")
):
# already initialized, skipped
return
model_dtype = getattr(config, "model_dtype", "torch.float16")
if not hasattr(config, "model_dtype"):
warnings.warn(
"model_dtype not found in config, defaulting to torch.float16."
)
config.model_dtype = model_dtype
cfgs = get_model_config(config)
if len(cfgs) == 3:
llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
else:
raise ValueError(
"`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config."
)
self.llm, self.tokenizer = build_llm_and_tokenizer(
llm_cfg, config, *args, **kwargs
)
self.vision_tower = build_vision_tower(vision_tower_cfg, config)
self.mm_projector = build_mm_projector(mm_projector_cfg, config)
self.post_config()
self.is_loaded = True
assert (
self.llm is not None
or self.vision_tower is not None
or self.mm_projector is not None
), "At least one of the components must be instantiated."
@classmethod
def load_from_config(cls, model_path_or_config, *args, **kwargs):
pass
## FIXME we will use this function to load model in the future
@classmethod
def load_pretrained(cls, model_path_or_config, *args, **kwargs):
kwargs.pop("config", None)
if isinstance(model_path_or_config, str):
config = AutoConfig.from_pretrained(model_path_or_config)
elif isinstance(model_path_or_config, LlavaConfig):
config = model_path_or_config
else:
            raise NotImplementedError(
                f"Unsupported type for model_path_or_config: {type(model_path_or_config)} "
                f"(LlavaConfig instance: {isinstance(model_path_or_config, LlavaConfig)})"
            )
model_dtype = getattr(config, "model_dtype", "torch.float16")
if not hasattr(config, "model_dtype"):
warnings.warn(
"model_dtype not found in config, defaulting to torch.float16."
)
config.model_dtype = model_dtype
cfgs = get_model_config(config)
if len(cfgs) == 3:
llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
else:
raise ValueError(
"`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config."
)
vlm = cls(config, *args, **kwargs)
# print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG load_pretrained finish")
if (
hasattr(vlm, "llm")
or hasattr(vlm, "vision_tower")
or hasattr(vlm, "mm_projector")
):
if vlm.is_loaded:
return vlm
vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(
llm_cfg, config, *args, **kwargs
)
vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)
        vlm.post_config()
        vlm.is_loaded = True
# FIXME(ligeng, yunhao): llm should never be none here.
assert (
vlm.llm is not None
or vlm.vision_tower is not None
or vlm.mm_projector is not None
), "At least one of the components must be instantiated."
return vlm
## FIXME we will use this function to save the model in the future
def save_pretrained(self, output_dir, state_dict=None):
if state_dict is None:
            # otherwise fetch from deepspeed
# state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
state_dict = self.state_dict()
if getattr(self, "tokenizer", None):
self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
if self.get_llm():
print(f"saving llm to {osp.join(output_dir, 'llm')}")
self.llm.config._name_or_path = osp.join(output_dir, "llm")
llm_state_dict = OrderedDict(
{k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k}
)
self.llm.save_pretrained(
os.path.join(output_dir, "llm"), state_dict=llm_state_dict
)
self.config.llm_cfg = self.llm.config
if self.get_vision_tower():
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
self.vision_tower.config._name_or_path = osp.join(
output_dir, "vision_tower"
)
vision_tower_state_dict = OrderedDict(
{
k.split("vision_tower.vision_tower.")[-1]: v
for k, v in state_dict.items()
if "vision_tower" in k
}
)
self.vision_tower.vision_tower.save_pretrained(
os.path.join(output_dir, "vision_tower"),
state_dict=vision_tower_state_dict,
)
self.vision_tower.image_processor.save_pretrained(
os.path.join(output_dir, "vision_tower")
)
self.config.vision_tower_cfg = self.vision_tower.config
if hasattr(self.config.vision_tower_cfg, "auto_map"):
if "radio" not in self.get_vision_tower().__class__.__name__.lower():
delattr(self.config.vision_tower_cfg, "auto_map")
if self.get_mm_projector():
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
self.mm_projector.config._name_or_path = osp.join(
output_dir, "mm_projector"
)
mm_projector_state_dict = OrderedDict(
{
k.split("mm_projector.")[-1]: v
for k, v in state_dict.items()
if "mm_projector" in k
}
)
self.mm_projector.save_pretrained(
os.path.join(output_dir, "mm_projector"),
state_dict=mm_projector_state_dict,
)
self.config.mm_projector_cfg = self.mm_projector.config
## update and save top-level config
self.config._name_or_path = output_dir
self.config.architectures = [self.__class__.__name__]
self.config.save_pretrained(output_dir)
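    # Resulting directory layout of save_pretrained() (shown for illustration):
    #
    #   output_dir/
    #     config.json     # top-level config with llm_cfg / vision_tower_cfg / mm_projector_cfg
    #     llm/            # tokenizer files + language-model weights
    #     vision_tower/   # vision-encoder weights + image processor
    #     mm_projector/   # projector weights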
def get_llm(self):
llm = getattr(self, "llm", None)
if type(llm) is list:
llm = llm[0]
return llm
def get_lm_head(self):
lm_head = getattr(self.get_llm(), "lm_head", None)
return lm_head
def get_vision_tower(self):
vision_tower = getattr(self, "vision_tower", None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def get_mm_projector(self):
mm_projector = getattr(self, "mm_projector", None)
if type(mm_projector) is list:
mm_projector = mm_projector[0]
return mm_projector
def post_config(self):
self.training = self.get_llm().training
## configuration
if getattr(self.config, "llm_cfg", None) is None:
self.config.llm_cfg = self.llm.config
if getattr(self.config, "vision_tower_cfg", None) is None:
self.config.vision_tower_cfg = self.vision_tower.config
if getattr(self.config, "mm_projector_cfg", None) is None:
self.config.mm_projector_cfg = self.mm_projector.config
def freezed_module_patch(self):
"""
Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
"""
if self.training:
if self.get_llm() and not getattr(
self.config, "tune_language_model", False
):
pass
# logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
if self.get_vision_tower() and not getattr(
self.config, "tune_vision_tower", False
):
self.get_vision_tower().eval()
if self.get_mm_projector() and not getattr(
self.config, "tune_mm_projector", False
):
self.get_mm_projector().eval()
def encode_images(self, images):
image_features = self.get_vision_tower()(images)
image_features = self.get_mm_projector()(image_features)
return image_features
    ## @yunhao: is there a better way to handle function calls and attributes for the llm?
## support beam search
def _temporary_reorder_cache(self, past_key_values, sorted_idx):
return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)
def get_input_embeddings(self):
return self.get_llm().get_input_embeddings()
def get_output_embeddings(self):
return self.get_llm().get_output_embeddings()
def resize_token_embeddings(self, embed_size):
self.get_llm().resize_token_embeddings(embed_size)
# ## FIXME we will follow the convention to add a new class for CausalLM in the future
class LlavaLlamaModel(LlavaMetaModel, PreTrainedModel):
config_class = LlavaLlamaConfig
main_input_name = "input_embeds"
supports_gradient_checkpointing = True
def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
super().__init__(config)
return self.init_vlm(config=config, *args, **kwargs)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: bool = None,
**kwargs,
):
if hasattr(cls, "load_pretrained"):
return cls.load_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)
        return super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)
def forward(
self,
input_ids: torch.LongTensor = None,
images: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
seqlens_in_batch: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
dpo_forward: bool = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
self.freezed_module_patch()
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
) = self.prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask, past_key_values, labels, images
)
support_packing = (
"seqlens_in_batch" in inspect.signature(self.llm.forward).parameters
)
if self.training and support_packing and not dpo_forward:
(
_,
new_position_ids,
new_attention_mask,
_,
new_inputs_embeds,
new_labels,
sorted_seqlens_in_batch,
) = self.repack_multimodal_data(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
)
if sorted_seqlens_in_batch is None:
sorted_seqlens_in_batch = seqlens_in_batch
new_input_ids = None
past_key_values = None
else:
new_attention_mask = attention_mask
new_position_ids = position_ids
new_inputs_embeds = inputs_embeds
new_labels = labels
sorted_seqlens_in_batch = attention_mask.sum(-1).int()
new_input_ids = input_ids
if support_packing:
outputs = self.llm.forward(
input_ids=new_input_ids,
attention_mask=new_attention_mask,
position_ids=new_position_ids,
past_key_values=past_key_values,
inputs_embeds=new_inputs_embeds,
labels=new_labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
seqlens_in_batch=sorted_seqlens_in_batch,
)
else:
outputs = self.llm.forward(
input_ids=new_input_ids,
attention_mask=new_attention_mask,
position_ids=new_position_ids,
past_key_values=past_key_values,
inputs_embeds=new_inputs_embeds,
labels=new_labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if dpo_forward:
return outputs.logits, new_labels
return outputs
@torch.no_grad()
def generate(
self,
input_ids: Optional[torch.FloatTensor] = None,
images: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**generation_kwargs,
):
if images is not None:
(
_,
_,
attention_mask,
_,
inputs_embeds,
_,
) = self.prepare_inputs_labels_for_multimodal(
input_ids, None, attention_mask, None, None, images
)
else:
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = inputs_embeds.to(self.dtype)
outputs = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generation_kwargs,
)
return outputs
# AutoConfig.register("llava_llama", LlavaLlamaConfig)
# AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
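# Hedged end-to-end sketch (paths, dtype, and prompt preparation are illustrative;
# input_ids is assumed to already contain the IMAGE_TOKEN_INDEX placeholder and
# image_tensor to come from the vision tower's image processor):
#
#   config = LlavaLlamaConfig.from_pretrained("/path/to/VILA1.5-3b")
#   model = LlavaLlamaModel.from_pretrained("/path/to/VILA1.5-3b", config=config).eval().cuda()
#   out_ids = model.generate(input_ids=input_ids, images=image_tensor, max_new_tokens=64)
#   print(model.tokenizer.batch_decode(out_ids, skip_special_tokens=True)[0])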