import inspect
# from .builder import build_llm_and_tokenizer, build_mm_projector, build_vision_tower
import math
import os
import os.path as osp
import re
import shutil
import warnings
from typing import List, Optional, Tuple, Union

# from .llava_llama import LlavaLlamaModel
# from llava.model import *
# from llava.model.utils import is_mm_model
import torch
import torch.nn as nn
from huggingface_hub import file_exists, repo_exists, snapshot_download
from huggingface_hub.utils import HFValidationError, validate_repo_id
# from llava.model.multimodal_encoder.vision_encoder import (VisionTower,
#                                                            VisionTowerS2)
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoTokenizer, BitsAndBytesConfig, GenerationConfig,
                          LlamaConfig, LlamaForCausalLM, PretrainedConfig,
                          PreTrainedModel, SiglipImageProcessor,
                          SiglipVisionModel)
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_llava import LlavaConfig  # , LlavaLlamaConfig
# from .llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
from .utils import get_model_config

CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"


def is_deepspeed_zero3_enabled():
    # Stub: DeepSpeed ZeRO-3 gathering is not used in this standalone module.
    return None


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


class DownSampleBlock(nn.Module):
    def forward(self, x):
        vit_embeds = x
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.flat_square(vit_embeds)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        return vit_embeds

    def flat_square(self, x):
        # Pad odd spatial dimensions to even size, then merge each 2x2 patch
        # neighborhood into the channel dimension (tokens / 4, channels * 4).
        n, w, h, c = x.size()
        if w % 2 == 1:
            x = torch.concat(
                [x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1
            ).contiguous()
            n, w, h, c = x.size()
        if h % 2 == 1:
            x = torch.concat(
                [x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2
            ).contiguous()
            n, w, h, c = x.size()
        x = x.view(n, w, int(h / 2), int(c * 2))
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
        return x


class MultimodalProjectorConfig(PretrainedConfig):
    model_type = "v2l_projector"

    def __init__(self, mm_projector_type: str = None, **kwargs):
        super().__init__()
        self.mm_projector_type = mm_projector_type


class MultimodalProjector(PreTrainedModel):
    config_class = MultimodalProjectorConfig

    def __init__(
        self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig
    ):
        super().__init__(mm_projector_cfg)
        mm_projector_type = mm_projector_cfg.mm_projector_type
        if mm_projector_type == "identity":
            self.layers = IdentityMap()
        elif mm_projector_type == "linear":
            self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
        elif mm_projector_type == "mlp_downsample":
            self.layers = nn.Sequential(
                DownSampleBlock(),
                nn.LayerNorm(config.mm_hidden_size * 4),
                nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
            if mlp_gelu_match:
                mlp_depth = int(mlp_gelu_match.group(1))
                modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
                for _ in range(1, mlp_depth):
                    modules.append(nn.GELU())
                    modules.append(nn.Linear(config.hidden_size, config.hidden_size))
                self.layers = nn.Sequential(*modules)
            else:
                raise ValueError(f"Unknown projector type: {mm_projector_type}")

    def forward(self, x, *args, **kwargs):
        return self.layers(x)


def build_mm_projector(
    model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    if model_type_or_path is None:
        return None

    ## load from pretrained model
    if config.resume_path:
        assert os.path.exists(
            model_type_or_path
        ), f"Resume mm projector path {model_type_or_path} does not exist!"
        return MultimodalProjector.from_pretrained(
            model_type_or_path, config, torch_dtype=eval(config.model_dtype)
        )
    ## build from scratch
    else:
        mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
        mm_projector = MultimodalProjector(mm_projector_cfg, config).to(
            eval(config.model_dtype)
        )
        return mm_projector
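
# Illustrative sketch (not part of the original checkpoint code): the
# "mlp_downsample" projector first merges 2x2 patch neighborhoods, so a
# SigLIP-style grid of 27x27 = 729 patch tokens with hidden size C is padded
# to an even 28x28 grid and becomes (28*28)/4 = 196 tokens of size 4*C before
# being projected to the LLM hidden size. The tensor sizes below are made up
# for demonstration only; the helper name is hypothetical.
def _example_mlp_downsample_shapes():
    downsample = DownSampleBlock()
    dummy = torch.randn(1, 729, 1152)  # (batch, patches, mm_hidden_size)
    merged = downsample(dummy)
    # 27x27 grid -> padded to 28x28 -> 2x2-merged: 196 tokens, 4x channels.
    assert merged.shape == (1, 196, 1152 * 4)
    return merged.shape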

class VisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        self.cfg_only = None

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def _maybe_resize_pos_embeds(
        self,
        model: PreTrainedModel,
        image_processor,
        resolution: int = -1,
        interpolate_mode: str = "linear",
    ):
        if resolution in [model.config.image_size, -1]:
            return
        print(
            f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
        )
        embeddings = model.vision_model.embeddings
        patch_size = embeddings.patch_size
        num_new_tokens = int((resolution // patch_size) ** 2)

        old_embeddings = embeddings.position_embedding
        match interpolate_mode:
            case "linear":
                ## Step 1: Calculate the corresponding patch ID (pid) in the current resolution
                ##         (M patches) based on the target resolution (N patches).
                ##         Formula: pid = pid / N * M
                ## Step 2: Obtain new embeddings by interpolating between the embeddings of the
                ##         two nearest calculated patch IDs.
                ##         Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)]
                ##                             + (ceil(pid) - pid) * embeds[floor(pid)]
                old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                new_embeddings = nn.Embedding(
                    num_new_tokens,
                    old_embedding_dim,
                    dtype=old_embeddings.weight.dtype,
                    device=old_embeddings.weight.device,
                )
                mapped_indices = (
                    torch.arange(num_new_tokens).to(old_embeddings.weight.device)
                    / (num_new_tokens - 1)
                    * (old_num_tokens - 1)
                )
                floor_indices = torch.clamp(
                    mapped_indices.floor().long(), min=0, max=old_num_tokens - 1
                )
                ceil_indices = torch.clamp(
                    mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1
                )
                if is_deepspeed_zero3_enabled():
                    params = [old_embeddings.weight, new_embeddings.weight]
                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                        interpolated_embeds = (mapped_indices - floor_indices)[
                            :, None
                        ] * old_embeddings.weight.data[ceil_indices, :] + (
                            ceil_indices - mapped_indices
                        )[:, None] * old_embeddings.weight.data[floor_indices, :]
                else:
                    interpolated_embeds = (mapped_indices - floor_indices)[
                        :, None
                    ] * old_embeddings.weight.data[ceil_indices, :] + (
                        ceil_indices - mapped_indices
                    )[:, None] * old_embeddings.weight.data[floor_indices, :]
                new_embeddings.weight.data = interpolated_embeds
            case _:
                raise NotImplementedError

        if hasattr(old_embeddings, "_hf_hook"):
            hook = old_embeddings._hf_hook
            # disable to inference
            # add_hook_to_module(new_embeddings, hook)
        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
        ## update vision encoder's configurations
        model.config.image_size = resolution
        if hasattr(image_processor, "crop_size"):
            # CLIP vision tower
            image_processor.crop_size = resolution
        else:
            # SIGLIP vision tower
            assert hasattr(image_processor, "size")
            image_processor.size = {"height": resolution, "width": resolution}
        ## TODO define a '_reinitialize' method for VisionTower
        embeddings.position_embedding = new_embeddings
        embeddings.image_size = resolution
        embeddings.num_patches = embeddings.num_positions = num_new_tokens
        embeddings.position_ids = (
            torch.arange(embeddings.num_positions)
            .expand((1, -1))
            .to(old_embeddings.weight.device)
        )

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
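
# Illustrative sketch (not part of the original model code): the "linear"
# branch above maps each new position id onto the old position grid and blends
# the two nearest old embeddings. The helper below reproduces that arithmetic
# on a plain weight tensor; the function name is hypothetical.
def _example_interpolate_pos_embeds(old_weight: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    old_num_tokens = old_weight.shape[0]
    mapped = torch.arange(num_new_tokens) / (num_new_tokens - 1) * (old_num_tokens - 1)
    floor_idx = mapped.floor().long().clamp(0, old_num_tokens - 1)
    ceil_idx = mapped.ceil().long().clamp(0, old_num_tokens - 1)
    # Same formula as above:
    #   new_embeds = (pid - floor(pid)) * embeds[ceil(pid)]
    #              + (ceil(pid) - pid) * embeds[floor(pid)]
    return (mapped - floor_idx)[:, None] * old_weight[ceil_idx] + (
        ceil_idx - mapped
    )[:, None] * old_weight[floor_idx]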

class SiglipVisionTower(VisionTower):
    def __init__(
        self, model_name_or_path: str, config: PretrainedConfig, state_dict=None
    ):
        super().__init__(model_name_or_path, config)
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            # TODO(ligeng): why does passing config here lead to errors?
            model_name_or_path,
            torch_dtype=eval(config.model_dtype),
            state_dict=state_dict,
        )
        self.is_loaded = True


def build_vision_tower(
    model_name_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
    ## skip vision tower instantiation
    if model_name_or_path is None:
        return None

    vision_tower_arch = None
    if config.resume_path and "radio" not in model_name_or_path:
        assert os.path.exists(
            model_name_or_path
        ), f"Resume vision tower path {model_name_or_path} does not exist!"
        vision_tower_cfg = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        vision_tower_arch = vision_tower_cfg.architectures[0].lower()
    vision_tower_name = (
        vision_tower_arch if vision_tower_arch is not None else model_name_or_path
    )

    use_s2 = getattr(config, "s2", False)

    if "siglip" in vision_tower_name:
        if use_s2:
            # NOTE: SiglipVisionTowerS2 is not defined in this standalone file.
            vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
        else:
            vision_tower = SiglipVisionTower(model_name_or_path, config)
    else:
        raise ValueError(f"Unknown vision tower: {model_name_or_path}")

    config.mm_hidden_size = (
        vision_tower.config.hidden_size if not use_s2 else vision_tower.hidden_size
    )
    return vision_tower


def has_tokenizer(repo_id_or_path: str) -> bool:
    # Check if the tokenizer is in a local directory
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True

    # Check if the tokenizer is in a Hugging Face Hub repo
    try:
        return repo_exists(repo_id_or_path) and file_exists(
            repo_id_or_path, "tokenizer_config.json"
        )
    except HFValidationError:
        return False


def context_length_extension(config):
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    if orig_ctx_len and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config


def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
):
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    llm = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=llm_cfg,
        torch_dtype=eval(config.model_dtype),
        *args,
        **kwargs,
    )

    # Locate the tokenizer.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    # TODO(ligeng): use the LLM class to judge, for better compatibility.
    llm_arch = ""
    try:
        llm_arch = getattr(llm_cfg, "architectures")[0].lower()
    except BaseException:
        warnings.warn(
            f'Cannot find LLM architecture, please check the "config.json" under "{llm_path}".'
        )

    if "mpt" in llm_arch:
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
        )
    elif "yi" in llm_path or (
        getattr(llm_cfg, "num_hidden_layers", -1) == 60
        and getattr(llm_cfg, "num_attention_heads", -1) == 56
    ):
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
            use_fast=False,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            llm_path,
            model_max_length=llm_cfg.model_max_length,
            padding_side="right",
            use_fast=False,
            legacy=False,
        )

    # TODO(ligeng): is this necessary for llava?
    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
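
# Illustrative sketch (not part of the original code path): context_length_extension
# applies HF "linear" RoPE scaling by rounding the length ratio up. For example,
# extending a 4096-token base model to a 16384-token max length yields a factor of
# ceil(16384 / 4096) = 4. The config values below are made up for illustration.
def _example_rope_scaling_factor():
    llm_cfg = LlamaConfig(max_position_embeddings=4096)
    llm_cfg.model_max_length = 16384
    context_length_extension(llm_cfg)
    assert llm_cfg.rope_scaling == {"type": "linear", "factor": 4.0}
    return llm_cfg.rope_scaling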

def is_mm_model(model_path):
    """
    Check if the model at the given path is a visual language model.

    Args:
        model_path (str): The path to the model.

    Returns:
        bool: True if the model is an MM model, False otherwise.
    """
    config = AutoConfig.from_pretrained(model_path)
    architectures = config.architectures
    for architecture in architectures:
        if "llava" in architecture.lower():
            return True
    return False


def load_pretrained_model(
    model_path,
    model_name,
    model_base=None,
    load_8bit=False,
    load_4bit=False,
    device_map="auto",
    device="cuda",
    **kwargs,
):
    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        kwargs["device_map"] = {"": device}

    if load_8bit:
        kwargs["load_in_8bit"] = True
    elif load_4bit:
        kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.float16
        # kwargs["torch_dtype"] = torch.bfloat16

    if is_mm_model(model_path):
        # Load LLaVA model
        ## TODO @yunhao: mind fixing lora
        if "lora" in model_name.lower() and model_base is None:
            warnings.warn(
                "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
            )
        if (
            "lora" in model_name.lower() or "dora" in model_name.lower()
        ) and model_base is not None:
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            print(lora_cfg_pretrained)
            print("Loading LLaVA from base model...")
            config = AutoConfig.from_pretrained(model_base)
            prepare_config_for_eval(config, kwargs)
            model = LlavaLlamaModel.from_pretrained(
                model_base, low_cpu_mem_usage=True, config=config, **kwargs
            )
            tokenizer = model.tokenizer
            token_num, token_dim = (
                model.llm.lm_head.out_features,
                model.llm.lm_head.in_features,
            )
            if model.llm.lm_head.weight.shape[0] != token_num:
                model.llm.lm_head.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, token_dim, device=model.device, dtype=model.dtype
                    )
                )
                model.llm.embed_tokens.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, token_dim, device=model.device, dtype=model.dtype
                    )
                )

            print("Loading additional LLaVA weights...")
            if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
                non_lora_trainables = torch.load(
                    os.path.join(model_path, "non_lora_trainables.bin"),
                    map_location="cpu",
                )
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id, filename=filename, subfolder=subfolder
                    )
                    return torch.load(cache_file, map_location="cpu")

                non_lora_trainables = load_from_hf(
                    model_path, "non_lora_trainables.bin"
                )
            non_lora_trainables = {
                (k[11:] if k.startswith("base_model.") else k): v
                for k, v in non_lora_trainables.items()
            }
            if any(k.startswith("model.model.") for k in non_lora_trainables):
                non_lora_trainables = {
                    (k[6:] if k.startswith("model.") else k): v
                    for k, v in non_lora_trainables.items()
                }
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel

            print("Loading LoRA weights...")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging LoRA weights...")
            model = model.merge_and_unload()
            print("Model is loaded...")
        ## TODO @yunhao: mind fixing this
        elif model_base is not None:
            # this may be mm projector only
            # NOTE: the branches below reference legacy classes (LlavaMPTForCausalLM,
            # LlavaLlamaForCausalLM, etc.) that are not defined in this standalone file.
            print("Loading LLaVA from base model...")
            cfg_pretrained = AutoConfig.from_pretrained(
                model_path, trust_remote_code=True
            )
            mm_config_wrapper(config, kwargs)
            if "mpt" in model_name.lower():
                if not os.path.isfile(
                    os.path.join(model_path, "configuration_mpt.py")
                ):
                    shutil.copyfile(
                        os.path.join(model_base, "configuration_mpt.py"),
                        os.path.join(model_path, "configuration_mpt.py"),
                    )
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                model = LlavaMPTForCausalLM.from_pretrained(
                    model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_base, use_fast=False, legacy=False
                )
                model = LlavaLlamaForCausalLM.from_pretrained(
                    model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
                )
        else:
            config = AutoConfig.from_pretrained(model_path)
            config.resume_path = model_path
            prepare_config_for_eval(config, kwargs)
            if "mpt" in model_name.lower():
                model = LlavaMPTForCausalLM.from_pretrained(
                    model_path, config=config, low_cpu_mem_usage=True, **kwargs
                )
            elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
                model = LlavaMistralForCausalLM.from_pretrained(
                    model_path, config=config, low_cpu_mem_usage=True, **kwargs
                )
            elif "gemma" in model_name.lower():
                model = LlavaGemmaForCausalLM.from_pretrained(
                    model_path, config=config, low_cpu_mem_usage=True, **kwargs
                )
            else:
                # kentang-mit@: llama-2 model
                # config._attn_implementation = "flash_attention_2"
                model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs)
            tokenizer = model.tokenizer
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_base, low_cpu_mem_usage=True, **kwargs
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print("Convert to FP16...")
            model.to(torch.float16)
        else:
            if "mpt" in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True,
                    **kwargs,
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=False, legacy=False
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path, low_cpu_mem_usage=True, **kwargs
                )
    model.eval()
    image_processor = None

    if is_mm_model(model_path):
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        model.resize_token_embeddings(len(tokenizer))
        vision_tower = model.get_vision_tower()
        vision_tower.to(device=device, dtype=torch.float16)
        # vision_tower.to(device=device, dtype=torch.bfloat16)
        mm_projector = model.get_mm_projector()
        mm_projector.to(device=device, dtype=torch.float16)
        # mm_projector.to(device=device, dtype=torch.bfloat16)
        image_processor = vision_tower.image_processor

    if hasattr(model.llm.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len

def parse_model_name_or_path(config: PretrainedConfig, model_name="llm", suffix="_cfg"):
    target_model = f"{model_name}{suffix}"
    target_cfg = getattr(config, target_model, None)

    if isinstance(target_cfg, str):
        return target_cfg
    elif isinstance(target_cfg, dict):
        return target_cfg["architectures"][0]
    else:
        raise ValueError(f"Invalid {target_model} configuration!")


def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
    try:
        # compatible with deprecated config convention
        if getattr(config, "vision_tower_cfg", None) is None:
            config.vision_tower_cfg = config.mm_vision_tower
    except AttributeError:
        raise ValueError(
            f"Invalid configuration! Cannot find vision_tower in config:\n{config}"
        )

    config.model_dtype = kwargs.pop("torch_dtype").__str__()
    # siglip does not support device_map = "auto"
    vision_tower_name = parse_model_name_or_path(config, "vision_tower")
    if "siglip" in vision_tower_name.lower():
        kwargs["device_map"] = "cuda"


class LlavaLlamaConfig(LlavaConfig):
    model_type = "llava_llama"


# class LlavaLlamaModel(PreTrainedModel):
#     config_class = LlavaLlamaConfig
#     main_input_name = "input_embeds"
#     supports_gradient_checkpointing = True

#     @classmethod
#     def from_pretrained(
#         cls,
#         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
#         *model_args,
#         config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
#         cache_dir: Optional[Union[str, os.PathLike]] = None,
#         ignore_mismatched_sizes: bool = False,
#         force_download: bool = False,
#         local_files_only: bool = False,
#         token: Optional[Union[str, bool]] = None,
#         revision: str = "main",
#         use_safetensors: bool = None,
#         **kwargs,
#     ):
#         if hasattr(cls, "load_pretrained"):
#             return cls.load_pretrained(
#                 pretrained_model_name_or_path,
#                 *model_args,
#                 config=config,
#                 cache_dir=cache_dir,
#                 ignore_mismatched_sizes=ignore_mismatched_sizes,
#                 force_download=force_download,
#                 local_files_only=local_files_only,
#                 token=token,
#                 revision=revision,
#                 use_safetensors=use_safetensors,
#                 **kwargs,
#             )
#         return None


from abc import ABC, abstractmethod
from collections import OrderedDict


class LlavaMetaModel(ABC):
    def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs):
        # TODO(ligeng): figure out how from_config and from_pretrained work in the HF implementation.
        if (
            hasattr(self, "llm")
            or hasattr(self, "vision_tower")
            or hasattr(self, "mm_projector")
        ):
            # already initialized, skip
            return

        model_dtype = getattr(config, "model_dtype", "torch.float16")
        if not hasattr(config, "model_dtype"):
            warnings.warn(
                "model_dtype not found in config, defaulting to torch.float16."
            )
            config.model_dtype = model_dtype

        cfgs = get_model_config(config)
        if len(cfgs) == 3:
            llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
        else:
            raise ValueError(
                "`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config."
            )

        self.llm, self.tokenizer = build_llm_and_tokenizer(
            llm_cfg, config, *args, **kwargs
        )
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.post_config()
        self.is_loaded = True

        assert (
            self.llm is not None
            or self.vision_tower is not None
            or self.mm_projector is not None
        ), "At least one of the components must be instantiated."

    @classmethod
    def load_from_config(cls, model_path_or_config, *args, **kwargs):
        pass

    ## FIXME we will use this function to load the model in the future
    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
        kwargs.pop("config", None)

        if isinstance(model_path_or_config, str):
            config = AutoConfig.from_pretrained(model_path_or_config)
        elif isinstance(model_path_or_config, LlavaConfig):
            config = model_path_or_config
        else:
            raise NotImplementedError(
                f"wrong type, {type(model_path_or_config)} \
                {isinstance(model_path_or_config, LlavaConfig)}"
            )

        model_dtype = getattr(config, "model_dtype", "torch.float16")
        if not hasattr(config, "model_dtype"):
            warnings.warn(
                "model_dtype not found in config, defaulting to torch.float16."
            )
            config.model_dtype = model_dtype

        cfgs = get_model_config(config)
        if len(cfgs) == 3:
            llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
        else:
            raise ValueError(
                "`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config."
            )

        vlm = cls(config, *args, **kwargs)
        # print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG load_pretrained finish")

        if (
            hasattr(vlm, "llm")
            or hasattr(vlm, "vision_tower")
            or hasattr(vlm, "mm_projector")
        ):
            if vlm.is_loaded:
                return vlm

        vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(
            llm_cfg, config, *args, **kwargs
        )
        vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
        vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)

        vlm.post_config()
        vlm.is_loaded = True

        # FIXME(ligeng, yunhao): llm should never be None here.
        assert (
            vlm.llm is not None
            or vlm.vision_tower is not None
            or vlm.mm_projector is not None
        ), "At least one of the components must be instantiated."
        return vlm

    ## FIXME we will use this function to save the model in the future
    def save_pretrained(self, output_dir, state_dict=None):
        if state_dict is None:
            # otherwise fetch from deepspeed
            # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm():
            print(f"saving llm to {osp.join(output_dir, 'llm')}")
            self.llm.config._name_or_path = osp.join(output_dir, "llm")
            llm_state_dict = OrderedDict(
                {k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k}
            )
            self.llm.save_pretrained(
                os.path.join(output_dir, "llm"), state_dict=llm_state_dict
            )
            self.config.llm_cfg = self.llm.config

        if self.get_vision_tower():
            print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
            self.vision_tower.config._name_or_path = osp.join(
                output_dir, "vision_tower"
            )
            vision_tower_state_dict = OrderedDict(
                {
                    k.split("vision_tower.vision_tower.")[-1]: v
                    for k, v in state_dict.items()
                    if "vision_tower" in k
                }
            )
            self.vision_tower.vision_tower.save_pretrained(
                os.path.join(output_dir, "vision_tower"),
                state_dict=vision_tower_state_dict,
            )
            self.vision_tower.image_processor.save_pretrained(
                os.path.join(output_dir, "vision_tower")
            )
            self.config.vision_tower_cfg = self.vision_tower.config
            if hasattr(self.config.vision_tower_cfg, "auto_map"):
                if "radio" not in self.get_vision_tower().__class__.__name__.lower():
                    delattr(self.config.vision_tower_cfg, "auto_map")

        if self.get_mm_projector():
            print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
            self.mm_projector.config._name_or_path = osp.join(
                output_dir, "mm_projector"
            )
            mm_projector_state_dict = OrderedDict(
                {
                    k.split("mm_projector.")[-1]: v
                    for k, v in state_dict.items()
                    if "mm_projector" in k
                }
            )
            self.mm_projector.save_pretrained(
                os.path.join(output_dir, "mm_projector"),
                state_dict=mm_projector_state_dict,
            )
            self.config.mm_projector_cfg = self.mm_projector.config

        ## update and save top-level config
        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

    def get_llm(self):
        llm = getattr(self, "llm", None)
        if type(llm) is list:
            llm = llm[0]
        return llm

    def get_lm_head(self):
        lm_head = getattr(self.get_llm(), "lm_head", None)
        return lm_head

    def get_vision_tower(self):
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_mm_projector(self):
        mm_projector = getattr(self, "mm_projector", None)
        if type(mm_projector) is list:
            mm_projector = mm_projector[0]
        return mm_projector

    def post_config(self):
        self.training = self.get_llm().training
        ## configuration
        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.llm.config
        if getattr(self.config, "vision_tower_cfg", None) is None:
            self.config.vision_tower_cfg = self.vision_tower.config
        if getattr(self.config, "mm_projector_cfg", None) is None:
            self.config.mm_projector_cfg = self.mm_projector.config

    def freezed_module_patch(self):
        """
        Huggingface will call model.train() at each training_step. To ensure the
        expected behaviors for modules like dropout, batchnorm, etc., we need to
        call model.eval() for the frozen modules.
        """
        if self.training:
            if self.get_llm() and not getattr(
                self.config, "tune_language_model", False
            ):
                pass
                # logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
            if self.get_vision_tower() and not getattr(
                self.config, "tune_vision_tower", False
            ):
                self.get_vision_tower().eval()
            if self.get_mm_projector() and not getattr(
                self.config, "tune_mm_projector", False
            ):
                self.get_mm_projector().eval()

    def encode_images(self, images):
        image_features = self.get_vision_tower()(images)
        image_features = self.get_mm_projector()(image_features)
        return image_features

    ## @yunhao: is there a better way to handle function calls and attributes for llm?
    ## support beam search
    def _temporary_reorder_cache(self, past_key_values, sorted_idx):
        return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)

    def get_input_embeddings(self):
        return self.get_llm().get_input_embeddings()

    def get_output_embeddings(self):
        return self.get_llm().get_output_embeddings()

    def resize_token_embeddings(self, embed_size):
        self.get_llm().resize_token_embeddings(embed_size)
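
# Illustrative sketch: LlavaMetaModel.save_pretrained above shards a checkpoint
# into per-component subfolders, each loadable on its own; the top-level
# config.json points at them via llm_cfg / vision_tower_cfg / mm_projector_cfg.
# The helper only documents that layout; the directory name is hypothetical.
def _example_saved_layout(output_dir="path/to/saved-vila"):
    expected = {
        "llm": "language model weights + tokenizer files",
        "vision_tower": "vision encoder weights + image processor config",
        "mm_projector": "multimodal projector weights",
    }
    present = {name: osp.isdir(osp.join(output_dir, name)) for name in expected}
    return present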


# ## FIXME we will follow the convention to add a new class for CausalLM in the future
class LlavaLlamaModel(LlavaMetaModel, PreTrainedModel):
    config_class = LlavaLlamaConfig
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True

    def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
        super().__init__(config)
        return self.init_vlm(config=config, *args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        if hasattr(cls, "load_pretrained"):
            return cls.load_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                cache_dir=cache_dir,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        return super(LlavaLlamaModel, cls).from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        seqlens_in_batch: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        dpo_forward: bool = False,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        self.freezed_module_patch()
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )

        support_packing = (
            "seqlens_in_batch" in inspect.signature(self.llm.forward).parameters
        )

        if self.training and support_packing and not dpo_forward:
            (
                _,
                new_position_ids,
                new_attention_mask,
                _,
                new_inputs_embeds,
                new_labels,
                sorted_seqlens_in_batch,
            ) = self.repack_multimodal_data(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            )
            if sorted_seqlens_in_batch is None:
                sorted_seqlens_in_batch = seqlens_in_batch
            new_input_ids = None
            past_key_values = None
        else:
            new_attention_mask = attention_mask
            new_position_ids = position_ids
            new_inputs_embeds = inputs_embeds
            new_labels = labels
            sorted_seqlens_in_batch = attention_mask.sum(-1).int()
            new_input_ids = input_ids

        if support_packing:
            outputs = self.llm.forward(
                input_ids=new_input_ids,
                attention_mask=new_attention_mask,
                position_ids=new_position_ids,
                past_key_values=past_key_values,
                inputs_embeds=new_inputs_embeds,
                labels=new_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                seqlens_in_batch=sorted_seqlens_in_batch,
            )
        else:
            outputs = self.llm.forward(
                input_ids=new_input_ids,
                attention_mask=new_attention_mask,
                position_ids=new_position_ids,
                past_key_values=past_key_values,
                inputs_embeds=new_inputs_embeds,
                labels=new_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if dpo_forward:
            return outputs.logits, new_labels
        return outputs

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.FloatTensor] = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generation_kwargs,
    ):
        if images is not None:
            (
                _,
                _,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, None, attention_mask, None, None, images
            )
        else:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        inputs_embeds = inputs_embeds.to(self.dtype)

        outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
        return outputs


# AutoConfig.register("llava_llama", LlavaLlamaConfig)
# AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
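
# Illustrative usage sketch: if the commented-out registrations above are
# enabled, the model resolves through the Auto classes; when this file is
# shipped as Hub remote code, loading typically looks like the call below.
# The repository id is hypothetical.
def _example_autoclass_loading(repo_id="org/vila-style-checkpoint"):
    AutoConfig.register("llava_llama", LlavaLlamaConfig)
    AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    return model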