|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import math |
|
import os |
|
import os.path as osp
import re
|
import sys |
|
import warnings |
|
from abc import ABC, abstractmethod |
|
from collections import OrderedDict |
|
from typing import Tuple |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
from huggingface_hub import file_exists, repo_exists, snapshot_download |
|
from huggingface_hub.utils import HFValidationError, validate_repo_id |
|
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, |
|
AutoTokenizer, BitsAndBytesConfig, PretrainedConfig, |
|
PreTrainedModel, PreTrainedTokenizer) |
|
from transformers.modeling_utils import ContextManagers, no_init_weights

# Helpers used further below: accelerate's hook utility (position-embedding resize),
# the DeepSpeed ZeRO-3 guard, and the S2 multi-scale wrapper used by the *S2 towers.
# (On older transformers the ZeRO-3 guard lives in transformers.deepspeed.)
from accelerate.hooks import add_hook_to_module
from s2wrapper import forward as multiscale_forward
from transformers.integrations import is_deepspeed_zero3_enabled
|
|
|
from .configuration_llava import LlavaConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Special-token constants following the LLaVA conventions. IGNORE_INDEX matches the
# default ignore_index of nn.CrossEntropyLoss; IMAGE_TOKEN_INDEX is the sentinel id
# standing in for "<image>" in tokenized inputs.
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
|
|
|
|
|
|
|
|
|
from transformers import CLIPImageProcessor, CLIPVisionModel, PretrainedConfig |
|
|
|
|
|
class VisionTower(nn.Module): |
|
def __init__(self, vision_tower, args, delay_load=False): |
|
super().__init__() |
|
|
|
self.is_loaded = False |
|
|
|
self.vision_tower_name = vision_tower |
|
self.select_layer = getattr(args, "mm_vision_select_layer", -2) |
|
self.select_feature = getattr(args, "mm_vision_select_feature", "patch") |
|
|
|
self.cfg_only = None |
|
|
|
def feature_select(self, image_forward_outs): |
|
image_features = image_forward_outs.hidden_states[self.select_layer] |
|
        if self.select_feature == "patch":
            # drop the CLS token and keep only the patch tokens
            image_features = image_features[:, 1:]
|
elif self.select_feature == "cls_patch": |
|
image_features = image_features |
|
else: |
|
raise ValueError(f"Unexpected select feature: {self.select_feature}") |
|
return image_features |
|
|
|
def _maybe_resize_pos_embeds( |
|
self, |
|
model: PreTrainedModel, |
|
image_processor, |
|
resolution: int = -1, |
|
interpolate_mode: str = "linear", |
|
): |
|
if resolution in [model.config.image_size, -1]: |
|
return |
|
print( |
|
f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..." |
|
) |
|
embeddings = model.vision_model.embeddings |
|
patch_size = embeddings.patch_size |
|
num_new_tokens = int((resolution // patch_size) ** 2) |
|
|
|
old_embeddings = embeddings.position_embedding |
|
match interpolate_mode: |
|
case "linear": |
|
|
|
|
|
|
|
|
if is_deepspeed_zero3_enabled(): |
|
import deepspeed |
|
|
|
with deepspeed.zero.GatheredParameters( |
|
[old_embeddings.weight], modifier_rank=None |
|
): |
|
old_num_tokens, old_embedding_dim = old_embeddings.weight.size() |
|
else: |
|
old_num_tokens, old_embedding_dim = old_embeddings.weight.size() |
|
new_embeddings = nn.Embedding( |
|
num_new_tokens, |
|
old_embedding_dim, |
|
dtype=old_embeddings.weight.dtype, |
|
device=old_embeddings.weight.device, |
|
) |
|
mapped_indices = ( |
|
torch.arange(num_new_tokens).to(old_embeddings.weight.device) |
|
/ (num_new_tokens - 1) |
|
* (old_num_tokens - 1) |
|
) |
|
floor_indices = torch.clamp( |
|
mapped_indices.floor().long(), min=0, max=old_num_tokens - 1 |
|
) |
|
ceil_indices = torch.clamp( |
|
mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1 |
|
) |
|
if is_deepspeed_zero3_enabled(): |
|
params = [old_embeddings.weight, new_embeddings.weight] |
|
with deepspeed.zero.GatheredParameters(params, modifier_rank=0): |
|
interpolated_embeds = (mapped_indices - floor_indices)[ |
|
:, None |
|
] * old_embeddings.weight.data[ceil_indices, :] + ( |
|
ceil_indices - mapped_indices |
|
)[ |
|
:, None |
|
] * old_embeddings.weight.data[ |
|
floor_indices, : |
|
] |
|
else: |
|
interpolated_embeds = (mapped_indices - floor_indices)[ |
|
:, None |
|
] * old_embeddings.weight.data[ceil_indices, :] + ( |
|
ceil_indices - mapped_indices |
|
)[ |
|
:, None |
|
] * old_embeddings.weight.data[ |
|
floor_indices, : |
|
] |
|
new_embeddings.weight.data = interpolated_embeds |
|
case _: |
|
raise NotImplementedError |
|
|
|
if hasattr(old_embeddings, "_hf_hook"): |
|
hook = old_embeddings._hf_hook |
|
add_hook_to_module(new_embeddings, hook) |
|
new_embeddings.requires_grad_(old_embeddings.weight.requires_grad) |
|
|
|
model.config.image_size = resolution |
|
if hasattr(image_processor, "crop_size"): |
|
|
|
image_processor.crop_size = resolution |
|
else: |
|
|
|
assert hasattr(image_processor, "size") |
|
image_processor.size = {"height": resolution, "width": resolution} |
|
|
|
embeddings.position_embedding = new_embeddings |
|
embeddings.image_size = resolution |
|
embeddings.num_patches = embeddings.num_positions = num_new_tokens |
|
embeddings.position_ids = ( |
|
torch.arange(embeddings.num_positions) |
|
.expand((1, -1)) |
|
.to(old_embeddings.weight.device) |
|
) |
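    # A rough worked example of the resize above, assuming a CLIP ViT-L/14 tower:
    # going from 336 to 448 px gives (448 // 14) ** 2 = 1024 position slots, filled by
    # linearly interpolating the existing rows onto the new index grid (floor/ceil
    # neighbours weighted by the fractional offset; indices that land exactly on an
    # integer, such as the two endpoints, receive zero weight from both neighbours).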
|
|
|
def forward(self, images): |
|
if type(images) is list: |
|
image_features = [] |
|
for image in images: |
|
image_forward_out = self.vision_tower( |
|
image.to(device=self.device, dtype=self.dtype).unsqueeze(0), |
|
output_hidden_states=True, |
|
) |
|
image_feature = self.feature_select(image_forward_out).to(image.dtype) |
|
image_features.append(image_feature) |
|
else: |
|
image_forward_outs = self.vision_tower( |
|
images.to(device=self.device, dtype=self.dtype), |
|
output_hidden_states=True, |
|
) |
|
image_features = self.feature_select(image_forward_outs).to(images.dtype) |
|
|
|
return image_features |
|
|
|
@property |
|
def dummy_feature(self): |
|
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) |
|
|
|
@property |
|
def dtype(self): |
|
return self.vision_tower.dtype |
|
|
|
@property |
|
def device(self): |
|
return self.vision_tower.device |
|
|
|
@property |
|
def config(self): |
|
if self.is_loaded: |
|
return self.vision_tower.config |
|
else: |
|
return self.cfg_only |
|
|
|
@property |
|
def hidden_size(self): |
|
return self.config.hidden_size |
|
|
|
@property |
|
def num_patches(self): |
|
return (self.config.image_size // self.config.patch_size) ** 2 |
|
|
|
|
|
class VisionTowerS2(VisionTower): |
|
def __init__(self, vision_tower, args, delay_load=False): |
|
super().__init__(vision_tower, args, delay_load) |
|
|
|
self.scales = list(map(int, args.s2_scales.split(","))) |
|
self.scales.sort() |
|
self.max_split_size = args.s2_max_split_size |
|
|
|
@torch.no_grad() |
|
def forward_feature(self, images): |
|
image_forward_outs = self.vision_tower( |
|
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True |
|
) |
|
image_features = self.feature_select(image_forward_outs).to(images.dtype) |
|
return image_features |
|
|
|
@torch.no_grad() |
|
def forward(self, images): |
|
if type(images) is list: |
|
image_features = [] |
|
for image in images: |
|
image_feature = multiscale_forward( |
|
self.forward_feature, |
|
image.unsqueeze(0), |
|
img_sizes=self.scales, |
|
max_split_size=self.max_split_size, |
|
) |
|
image_features.append(image_feature) |
|
else: |
|
image_features = multiscale_forward( |
|
self.forward_feature, |
|
images, |
|
img_sizes=self.scales, |
|
max_split_size=self.max_split_size, |
|
) |
|
|
|
return image_features |
|
|
|
@property |
|
def hidden_size(self): |
|
return self.config.hidden_size * len(self.scales) |
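    # Hedged sketch of the S2 path: with args.s2_scales such as "336,672,1008" the
    # wrapper runs forward_feature at each scale and concatenates the per-scale
    # features channel-wise, which is why hidden_size above is the base tower's
    # hidden size multiplied by the number of scales.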
|
|
|
|
|
class CLIPVisionTower(VisionTower): |
|
def __init__(self, model_name_or_path: str, config: PretrainedConfig): |
|
super().__init__(model_name_or_path, config) |
|
self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path) |
|
self.vision_tower = CLIPVisionModel.from_pretrained( |
|
model_name_or_path, torch_dtype=eval(config.model_dtype) |
|
) |
|
self.is_loaded = True |
|
|
|
|
|
class CLIPVisionTowerS2(VisionTowerS2): |
|
def __init__(self, model_name_or_path: str, config: PretrainedConfig): |
|
super().__init__(model_name_or_path, config) |
|
self.image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path) |
|
self.vision_tower = CLIPVisionModel.from_pretrained( |
|
model_name_or_path, torch_dtype=eval(config.model_dtype) |
|
) |
|
|
|
|
|
self.image_processor.size["shortest_edge"] = self.scales[-1] |
|
self.image_processor.crop_size["height"] = self.image_processor.crop_size[ |
|
"width" |
|
] = self.scales[-1] |
|
|
|
self.is_loaded = True |
|
|
|
|
|
class IdentityMap(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, x, *args, **kwargs): |
|
return x |
|
|
|
@property |
|
def config(self): |
|
return {"mm_projector_type": "identity"} |
|
|
|
|
|
class SimpleResBlock(nn.Module): |
|
def __init__(self, channels): |
|
super().__init__() |
|
self.pre_norm = nn.LayerNorm(channels) |
|
|
|
self.proj = nn.Sequential( |
|
nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels) |
|
) |
|
|
|
def forward(self, x): |
|
x = self.pre_norm(x) |
|
return x + self.proj(x) |
|
|
|
|
|
class DownSampleBlock(nn.Module): |
|
def forward(self, x): |
|
vit_embeds = x |
|
h = w = int(vit_embeds.shape[1] ** 0.5) |
|
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) |
|
vit_embeds = self.flat_square(vit_embeds) |
|
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) |
|
return vit_embeds |
|
|
|
def flat_square(self, x): |
|
n, w, h, c = x.size() |
|
if w % 2 == 1: |
|
x = torch.concat( |
|
[x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1 |
|
).contiguous() |
|
n, w, h, c = x.size() |
|
if h % 2 == 1: |
|
x = torch.concat( |
|
[x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2 |
|
).contiguous() |
|
n, w, h, c = x.size() |
|
x = x.view(n, w, int(h / 2), int(c * 2)) |
|
x = x.permute(0, 2, 1, 3).contiguous() |
|
x = x.view(n, int(h / 2), int(w / 2), int(c * 4)) |
|
return x |
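    # Shape sketch: a (B, 576, C) sequence of ViT patch embeddings becomes a
    # (B, 24, 24, C) grid, flat_square folds each 2x2 neighbourhood into the channel
    # dimension to give (B, 12, 12, 4C), and DownSampleBlock returns (B, 144, 4C),
    # i.e. 4x fewer tokens with 4x wider features (odd grids are zero-padded first).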
|
|
|
|
|
class MultimodalProjectorConfig(PretrainedConfig): |
|
model_type = "v2l_projector" |
|
|
|
def __init__(self, mm_projector_type: str = None, **kwargs): |
|
super().__init__() |
|
self.mm_projector_type = mm_projector_type |
|
|
|
|
|
class MultimodalProjector(PreTrainedModel): |
|
config_class = MultimodalProjectorConfig |
|
|
|
def __init__( |
|
self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig |
|
): |
|
super().__init__(mm_projector_cfg) |
|
mm_projector_type = mm_projector_cfg.mm_projector_type |
|
if mm_projector_type == "identity": |
|
self.layers = IdentityMap() |
|
elif mm_projector_type == "linear": |
|
self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size) |
|
elif mm_projector_type == "mlp_downsample": |
|
self.layers = nn.Sequential( |
|
DownSampleBlock(), |
|
nn.LayerNorm(config.mm_hidden_size * 4), |
|
nn.Linear(config.mm_hidden_size * 4, config.hidden_size), |
|
nn.GELU(), |
|
nn.Linear(config.hidden_size, config.hidden_size), |
|
) |
|
else: |
|
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type) |
|
if mlp_gelu_match: |
|
mlp_depth = int(mlp_gelu_match.group(1)) |
|
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] |
|
for _ in range(1, mlp_depth): |
|
modules.append(nn.GELU()) |
|
modules.append(nn.Linear(config.hidden_size, config.hidden_size)) |
|
self.layers = nn.Sequential(*modules) |
|
else: |
|
raise ValueError(f"Unknown projector type: {mm_projector_type}") |
|
|
|
def forward(self, x, *args, **kwargs): |
|
return self.layers(x) |
|
|
|
|
|
def build_mm_projector( |
|
model_type_or_path: str, config: PretrainedConfig |
|
) -> PreTrainedModel: |
|
if model_type_or_path is None: |
|
return None |
|
|
|
|
|
if config.resume_path: |
|
assert os.path.exists( |
|
model_type_or_path |
|
), f"Resume mm projector path {model_type_or_path} does not exist!" |
|
return MultimodalProjector.from_pretrained( |
|
model_type_or_path, config, torch_dtype=eval(config.model_dtype) |
|
) |
|
|
|
else: |
|
mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path) |
|
mm_projector = MultimodalProjector(mm_projector_cfg, config).to( |
|
eval(config.model_dtype) |
|
) |
|
return mm_projector |
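# Minimal usage sketch (hypothetical values, not part of this module): given a config
# whose mm_hidden_size, hidden_size, model_dtype and resume_path fields are already
# populated,
#
#   mm_projector = build_mm_projector("mlp_downsample", config)
#
# returns a MultimodalProjector that folds 2x2 patch neighbourhoods and projects them
# into the LLM hidden size; "identity", "linear" and "mlpNx_gelu" select the other
# branches above.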
|
|
|
|
|
def build_vision_tower( |
|
model_name_or_path: str, config: PretrainedConfig |
|
) -> PreTrainedModel: |
|
|
|
if model_name_or_path is None: |
|
return None |
|
|
|
vision_tower_arch = None |
|
if config.resume_path and "radio" not in model_name_or_path: |
|
assert os.path.exists( |
|
model_name_or_path |
|
), f"Resume vision tower path {model_name_or_path} does not exist!" |
|
vision_tower_cfg = AutoConfig.from_pretrained( |
|
model_name_or_path, trust_remote_code=True |
|
) |
|
vision_tower_arch = vision_tower_cfg.architectures[0].lower() |
|
vision_tower_name = ( |
|
vision_tower_arch if vision_tower_arch is not None else model_name_or_path |
|
) |
|
|
|
use_s2 = getattr(config, "s2", False) |
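    # InternVisionTower and the Siglip tower variants referenced below are expected to
    # be provided by sibling encoder modules of this package; only the CLIP towers are
    # defined in this file.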
|
|
|
if "intern" in vision_tower_name.lower(): |
|
if hasattr(config, "drop_path_rate"): |
|
vision_tower = InternVisionTower( |
|
model_name_or_path, config=config, drop_path_rate=config.drop_path_rate |
|
) |
|
else: |
|
vision_tower = InternVisionTower( |
|
model_name_or_path, config=config, drop_path_rate=0.0 |
|
) |
|
elif "clip" in vision_tower_name: |
|
if use_s2: |
|
vision_tower = CLIPVisionTowerS2(model_name_or_path, config) |
|
else: |
|
vision_tower = CLIPVisionTower(model_name_or_path, config) |
|
elif "siglip" in vision_tower_name: |
|
if use_s2: |
|
vision_tower = SiglipVisionTowerS2(model_name_or_path, config) |
|
else: |
|
vision_tower = SiglipVisionTower(model_name_or_path, config) |
|
else: |
|
raise ValueError(f"Unknown vision tower: {model_name_or_path}") |
|
|
|
config.mm_hidden_size = ( |
|
vision_tower.config.hidden_size if not use_s2 else vision_tower.hidden_size |
|
) |
|
return vision_tower |
|
|
|
|
|
def has_tokenizer(repo_id_or_path: str) -> bool: |
|
|
|
if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")): |
|
return True |
|
|
|
|
|
try: |
|
return repo_exists(repo_id_or_path) and file_exists( |
|
repo_id_or_path, "tokenizer_config.json" |
|
) |
|
except HFValidationError: |
|
return False |
|
|
|
|
|
def context_length_extension(config): |
|
orig_ctx_len = getattr(config, "max_position_embeddings", None) |
|
model_max_length = getattr(config, "model_max_length", None) |
|
if orig_ctx_len and model_max_length > orig_ctx_len: |
|
print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}") |
|
scaling_factor = float(math.ceil(model_max_length / orig_ctx_len)) |
|
config.rope_scaling = {"type": "linear", "factor": scaling_factor} |
|
return config |
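# Worked example: with max_position_embeddings=4096 and model_max_length=10000 the
# scaling factor is ceil(10000 / 4096) = 3, so rope_scaling becomes
# {"type": "linear", "factor": 3.0}, stretching RoPE to cover the longer context.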
|
|
|
|
|
def build_llm_and_tokenizer( |
|
model_name_or_path: str, |
|
config: PretrainedConfig, |
|
attn_implementation=None, |
|
model_max_length=None, |
|
*args, |
|
**kwargs, |
|
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: |
|
llm_cfg = AutoConfig.from_pretrained(model_name_or_path) |
|
llm_cfg._attn_implementation = attn_implementation |
|
llm_cfg.model_max_length = model_max_length |
|
if model_max_length is not None: |
|
context_length_extension(llm_cfg) |
|
|
|
llm = AutoModelForCausalLM.from_pretrained( |
|
model_name_or_path, |
|
config=llm_cfg, |
|
torch_dtype=eval(config.model_dtype), |
|
*args, |
|
**kwargs, |
|
) |
|
|
|
|
|
llm_path = model_name_or_path |
|
if not has_tokenizer(llm_path): |
|
llm_path = osp.join(llm_path, "llm") |
|
if not has_tokenizer(llm_path): |
|
raise ValueError(f"Cannot find tokenizer in {llm_path}.") |
|
|
|
|
|
    try:
        llm_arch = getattr(llm_cfg, "architectures")[0].lower()
    except BaseException:
        warnings.warn(
            f'Cannot find LLM architecture, please check the "config.json" under "{llm_path}".'
        )
        llm_arch = ""  # fall back to the default tokenizer branch below
|
|
|
if "mpt" in llm_arch: |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
llm_path, |
|
model_max_length=llm_cfg.model_max_length, |
|
padding_side="right", |
|
) |
|
elif "yi" in llm_path or ( |
|
getattr(llm_cfg, "num_hidden_layers", -1) == 60 |
|
and getattr(llm_cfg, "num_attention_heads", -1) == 56 |
|
): |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
llm_path, |
|
model_max_length=llm_cfg.model_max_length, |
|
padding_side="right", |
|
use_fast=False, |
|
) |
|
else: |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
llm_path, |
|
model_max_length=llm_cfg.model_max_length, |
|
padding_side="right", |
|
use_fast=False, |
|
legacy=False, |
|
) |
|
|
|
|
|
config.hidden_size = llm.config.hidden_size |
|
return llm, tokenizer |
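# Tokenizer lookup order used above: model_name_or_path itself first, then its "llm"
# subfolder, which matches the layout written by LlavaMetaModel.save_pretrained below.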
|
|
|
|
|
def get_model_config(config): |
|
default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"] |
|
|
|
if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2: |
|
root_path = config._name_or_path |
|
else: |
|
root_path = config.resume_path |
|
|
|
|
|
if root_path is not None and not osp.exists(root_path): |
|
try: |
|
valid_hf_repo = repo_exists(root_path) |
|
except HFValidationError as e: |
|
valid_hf_repo = False |
|
if valid_hf_repo: |
|
root_path = snapshot_download(root_path) |
|
|
|
return_list = [] |
|
for key in default_keys: |
|
cfg = getattr(config, key, None) |
|
if isinstance(cfg, dict): |
|
            try:
                return_list.append(os.path.join(root_path, key[:-4]))
            except Exception:
                raise ValueError(f"Cannot find resume path in config for {key}!")
|
elif isinstance(cfg, PretrainedConfig): |
|
return_list.append(os.path.join(root_path, key[:-4])) |
|
elif isinstance(cfg, str): |
|
return_list.append(cfg) |
|
|
|
return return_list |
|
|
|
|
|
def is_mm_model(model_path): |
|
""" |
|
Check if the model at the given path is a visual language model. |
|
|
|
Args: |
|
model_path (str): The path to the model. |
|
|
|
Returns: |
|
bool: True if the model is an MM model, False otherwise. |
|
""" |
|
config = AutoConfig.from_pretrained(model_path) |
|
architectures = config.architectures |
|
for architecture in architectures: |
|
if "llava" in architecture.lower(): |
|
return True |
|
return False |
|
|
|
|
|
def auto_upgrade(config): |
|
cfg = AutoConfig.from_pretrained(config) |
|
if "llava" in config and "llava" not in cfg.model_type: |
|
assert cfg.model_type == "llama" |
|
print( |
|
"You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." |
|
) |
|
print( |
|
"You must upgrade the checkpoint to the new code base (this can be done automatically)." |
|
) |
|
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") |
|
if confirm.lower() in ["y", "yes"]: |
|
print("Upgrading checkpoint...") |
|
assert len(cfg.architectures) == 1 |
|
setattr(cfg.__class__, "model_type", "llava") |
|
cfg.architectures[0] = "LlavaLlamaForCausalLM" |
|
cfg.save_pretrained(config) |
|
print("Checkpoint upgraded.") |
|
else: |
|
print("Checkpoint upgrade aborted.") |
|
exit(1) |
|
|
|
|
|
def get_pg_manager(): |
|
return None |
|
|
|
|
|
|
|
class LlavaMetaModel(ABC): |
|
    def init_vlm(self, config: PretrainedConfig = None, *args, **kwargs):
|
|
|
if ( |
|
hasattr(self, "llm") |
|
or hasattr(self, "vision_tower") |
|
or hasattr(self, "mm_projector") |
|
): |
|
|
|
return |
|
|
|
model_dtype = getattr(config, "model_dtype", "torch.float16") |
|
if not hasattr(config, "model_dtype"): |
|
warnings.warn( |
|
"model_dtype not found in config, defaulting to torch.float16." |
|
) |
|
config.model_dtype = model_dtype |
|
|
|
cfgs = get_model_config(config) |
|
if len(cfgs) == 3: |
|
llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs |
|
else: |
|
raise ValueError( |
|
"`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config." |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.llm, self.tokenizer = build_llm_and_tokenizer( |
|
llm_cfg, config, *args, **kwargs |
|
) |
|
self.vision_tower = build_vision_tower(vision_tower_cfg, config) |
|
self.mm_projector = build_mm_projector(mm_projector_cfg, config) |
|
|
|
self.post_config() |
|
self.is_loaded = True |
|
|
|
assert ( |
|
self.llm is not None |
|
or self.vision_tower is not None |
|
or self.mm_projector is not None |
|
), "At least one of the components must be instantiated." |
|
|
|
@classmethod |
|
def load_from_config(cls, model_path_or_config, *args, **kwargs): |
|
pass |
|
|
|
|
|
@classmethod |
|
def load_pretrained(cls, model_path_or_config, *args, **kwargs): |
|
kwargs.pop("config", None) |
|
|
|
if isinstance(model_path_or_config, str): |
|
config = AutoConfig.from_pretrained(model_path_or_config) |
|
elif isinstance(model_path_or_config, LlavaConfig): |
|
config = model_path_or_config |
|
else: |
|
            raise NotImplementedError(
                f"Unsupported type for model_path_or_config: {type(model_path_or_config)} "
                f"(isinstance of LlavaConfig: {isinstance(model_path_or_config, LlavaConfig)})"
            )
|
|
|
model_dtype = getattr(config, "model_dtype", "torch.float16") |
|
if not hasattr(config, "model_dtype"): |
|
warnings.warn( |
|
"model_dtype not found in config, defaulting to torch.float16." |
|
) |
|
config.model_dtype = model_dtype |
|
|
|
cfgs = get_model_config(config) |
|
if len(cfgs) == 3: |
|
llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs |
|
else: |
|
raise ValueError( |
|
"`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config." |
|
) |
|
|
|
|
|
init_context = [ |
|
no_init_weights(_enable=True), |
|
] |
|
|
|
|
|
|
|
|
|
|
|
with ContextManagers(init_context): |
|
vlm = cls(config, *args, **kwargs) |
|
|
|
|
|
if ( |
|
hasattr(vlm, "llm") |
|
or hasattr(vlm, "vision_tower") |
|
or hasattr(vlm, "mm_projector") |
|
): |
|
if vlm.is_loaded: |
|
return vlm |
|
|
|
vlm.llm, vlm.tokenizer = build_llm_and_tokenizer( |
|
llm_cfg, config, *args, **kwargs |
|
) |
|
vlm.vision_tower = build_vision_tower(vision_tower_cfg, config) |
|
vlm.mm_projector = build_mm_projector(mm_projector_cfg, config) |
|
|
|
        vlm.post_config()
        vlm.is_loaded = True
|
|
|
|
|
assert ( |
|
vlm.llm is not None |
|
or vlm.vision_tower is not None |
|
or vlm.mm_projector is not None |
|
), "At least one of the components must be instantiated." |
|
return vlm |
|
|
|
|
|
def save_pretrained(self, output_dir, state_dict=None): |
|
if state_dict is None: |
|
|
|
|
|
state_dict = self.state_dict() |
|
|
|
if getattr(self, "tokenizer", None): |
|
self.tokenizer.save_pretrained(osp.join(output_dir, "llm")) |
|
|
|
if self.get_llm(): |
|
print(f"saving llm to {osp.join(output_dir, 'llm')}") |
|
self.llm.config._name_or_path = osp.join(output_dir, "llm") |
|
llm_state_dict = OrderedDict( |
|
{k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k} |
|
) |
|
self.llm.save_pretrained( |
|
os.path.join(output_dir, "llm"), state_dict=llm_state_dict |
|
) |
|
self.config.llm_cfg = self.llm.config |
|
|
|
if self.get_vision_tower(): |
|
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}") |
|
self.vision_tower.config._name_or_path = osp.join( |
|
output_dir, "vision_tower" |
|
) |
|
vision_tower_state_dict = OrderedDict( |
|
{ |
|
k.split("vision_tower.vision_tower.")[-1]: v |
|
for k, v in state_dict.items() |
|
if "vision_tower" in k |
|
} |
|
) |
|
self.vision_tower.vision_tower.save_pretrained( |
|
os.path.join(output_dir, "vision_tower"), |
|
state_dict=vision_tower_state_dict, |
|
) |
|
self.vision_tower.image_processor.save_pretrained( |
|
os.path.join(output_dir, "vision_tower") |
|
) |
|
self.config.vision_tower_cfg = self.vision_tower.config |
|
if hasattr(self.config.vision_tower_cfg, "auto_map"): |
|
if "radio" not in self.get_vision_tower().__class__.__name__.lower(): |
|
delattr(self.config.vision_tower_cfg, "auto_map") |
|
|
|
if self.get_mm_projector(): |
|
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}") |
|
self.mm_projector.config._name_or_path = osp.join( |
|
output_dir, "mm_projector" |
|
) |
|
mm_projector_state_dict = OrderedDict( |
|
{ |
|
k.split("mm_projector.")[-1]: v |
|
for k, v in state_dict.items() |
|
if "mm_projector" in k |
|
} |
|
) |
|
self.mm_projector.save_pretrained( |
|
os.path.join(output_dir, "mm_projector"), |
|
state_dict=mm_projector_state_dict, |
|
) |
|
self.config.mm_projector_cfg = self.mm_projector.config |
|
|
|
self.config._name_or_path = output_dir |
|
self.config.architectures = [self.__class__.__name__] |
|
self.config.save_pretrained(output_dir) |
|
|
|
def get_llm(self): |
|
llm = getattr(self, "llm", None) |
|
if type(llm) is list: |
|
llm = llm[0] |
|
return llm |
|
|
|
def get_lm_head(self): |
|
lm_head = getattr(self.get_llm(), "lm_head", None) |
|
return lm_head |
|
|
|
def get_vision_tower(self): |
|
vision_tower = getattr(self, "vision_tower", None) |
|
if type(vision_tower) is list: |
|
vision_tower = vision_tower[0] |
|
return vision_tower |
|
|
|
def get_mm_projector(self): |
|
mm_projector = getattr(self, "mm_projector", None) |
|
if type(mm_projector) is list: |
|
mm_projector = mm_projector[0] |
|
return mm_projector |
|
|
|
def post_config(self): |
|
self.training = self.get_llm().training |
|
|
|
if getattr(self.config, "llm_cfg", None) is None: |
|
self.config.llm_cfg = self.llm.config |
|
if getattr(self.config, "vision_tower_cfg", None) is None: |
|
self.config.vision_tower_cfg = self.vision_tower.config |
|
if getattr(self.config, "mm_projector_cfg", None) is None: |
|
self.config.mm_projector_cfg = self.mm_projector.config |
|
|
|
def freezed_module_patch(self): |
|
""" |
|
        Hugging Face calls model.train() at every training step. To keep the expected
        behavior of modules such as dropout and batchnorm, we call model.eval() on the
        frozen modules here.
|
""" |
|
if self.training: |
|
if self.get_llm() and not getattr( |
|
self.config, "tune_language_model", False |
|
): |
|
pass |
|
|
|
if self.get_vision_tower() and not getattr( |
|
self.config, "tune_vision_tower", False |
|
): |
|
self.get_vision_tower().eval() |
|
if self.get_mm_projector() and not getattr( |
|
self.config, "tune_mm_projector", False |
|
): |
|
self.get_mm_projector().eval() |
|
|
|
def encode_images(self, images): |
|
image_features = self.get_vision_tower()(images) |
|
image_features = self.get_mm_projector()(image_features) |
|
return image_features |
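    # encode_images is the full vision pathway: images -> vision tower features
    # (optionally multi-scale via S2) -> mm_projector -> embeddings in the LLM
    # hidden size, ready to be spliced into the token embedding sequence.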
|
|
|
|
|
|
|
def _temporary_reorder_cache(self, past_key_values, sorted_idx): |
|
return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx) |
|
|
|
def get_input_embeddings(self): |
|
return self.get_llm().get_input_embeddings() |
|
|
|
def get_output_embeddings(self): |
|
return self.get_llm().get_output_embeddings() |
|
|
|
def resize_token_embeddings(self, embed_size): |
|
self.get_llm().resize_token_embeddings(embed_size) |
|
|
|
|
|
class LlavaMetaForCausalLM(ABC): |
|
"""This class is originally implemented by the LLaVA team and |
|
modified by Haotian Tang and Jason Lu based on Ji Lin's implementation |
|
to support multiple images and input packing.""" |
|
|
|
|
|
def prepare_inputs_labels_for_multimodal( |
|
self, input_ids, position_ids, attention_mask, past_key_values, labels, images |
|
): |
|
|
|
PROCESS_GROUP_MANAGER = get_pg_manager() |
|
if PROCESS_GROUP_MANAGER is None: |
|
sp_degree = -1 |
|
sp_rank = -1 |
|
else: |
|
sp_degree = PROCESS_GROUP_MANAGER.sp_degree |
|
sp_rank = PROCESS_GROUP_MANAGER.sp_rank |
|
|
|
vision_tower = self.get_vision_tower() |
|
if ( |
|
vision_tower is None |
|
or images is None |
|
or (input_ids.shape[1] == 1 and PROCESS_GROUP_MANAGER is None) |
|
): |
|
if ( |
|
past_key_values is not None |
|
and vision_tower is not None |
|
and images is not None |
|
and input_ids.shape[1] == 1 |
|
): |
|
target_shape = past_key_values[-1][-1].shape[-2] + 1 |
|
attention_mask = torch.cat( |
|
( |
|
attention_mask, |
|
torch.ones( |
|
( |
|
attention_mask.shape[0], |
|
target_shape - attention_mask.shape[1], |
|
), |
|
dtype=attention_mask.dtype, |
|
device=attention_mask.device, |
|
), |
|
), |
|
dim=1, |
|
) |
|
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 |
|
return ( |
|
input_ids, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
None, |
|
labels, |
|
) |
|
|
|
if type(images) is list: |
|
images = torch.cat(images, dim=0) |
|
elif images.ndim == 5: |
|
images = images.flatten(0, 1) |
|
image_features = self.encode_images(images).to(self.device) |
|
|
|
if getattr(self.config, "turn_mm_projector", False) and getattr( |
|
self.config, "mm_use_im_start_end", False |
|
): |
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
_labels = labels |
|
_position_ids = position_ids |
|
_attention_mask = attention_mask |
|
if attention_mask is None: |
|
attention_mask = torch.ones_like(input_ids, dtype=torch.bool) |
|
else: |
|
attention_mask = attention_mask.bool() |
|
if position_ids is None: |
|
position_ids = torch.arange( |
|
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device |
|
) |
|
if labels is None: |
|
labels = torch.full_like(input_ids, IGNORE_INDEX) |
|
|
|
|
|
input_ids_copy = input_ids.clone() |
|
|
|
input_ids_copy[input_ids_copy == IMAGE_TOKEN_INDEX] = 0 |
|
input_embeds = self.llm.model.embed_tokens(input_ids_copy) |
|
|
|
input_ids = [ |
|
cur_input_ids[cur_attention_mask] |
|
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) |
|
] |
|
input_embeds_1 = [ |
|
cur_input_embeds[cur_attention_mask] |
|
for cur_input_embeds, cur_attention_mask in zip( |
|
input_embeds, attention_mask |
|
) |
|
] |
|
labels = [ |
|
cur_labels[cur_attention_mask] |
|
for cur_labels, cur_attention_mask in zip(labels, attention_mask) |
|
] |
|
|
|
new_input_embeds = [] |
|
new_labels = [] |
|
cur_image_idx = 0 |
|
|
|
|
|
for batch_idx, cur_input_ids in enumerate(input_ids): |
|
cur_input_ids = input_ids[batch_idx] |
|
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() |
|
if num_images == 0: |
|
cur_image_features = image_features[0] |
|
cur_input_embeds_1 = input_embeds_1[batch_idx] |
|
cur_input_embeds = torch.cat( |
|
[cur_input_embeds_1, cur_image_features[0:0]], dim=0 |
|
) |
|
new_input_embeds.append(cur_input_embeds) |
|
new_labels.append(labels[batch_idx]) |
|
|
|
continue |
|
|
|
cur_input_embeds = input_embeds_1[batch_idx] |
|
image_token_indices = ( |
|
[-1] |
|
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() |
|
+ [cur_input_ids.shape[0]] |
|
) |
|
cur_input_ids_noim = [] |
|
cur_labels = labels[batch_idx] |
|
cur_labels_noim = [] |
|
cur_input_embeds_no_im = [] |
|
for i in range(len(image_token_indices) - 1): |
|
if ( |
|
sp_degree > 1 and i == 0 and sp_rank != 0 |
|
): |
|
cur_input_ids_noim.append(cur_input_ids[0:0]) |
|
cur_labels_noim.append(cur_labels[0:0]) |
|
cur_input_embeds_no_im.append(cur_input_embeds[0:0]) |
|
continue |
|
cur_input_ids_noim.append( |
|
cur_input_ids[ |
|
image_token_indices[i] + 1 : image_token_indices[i + 1] |
|
] |
|
) |
|
cur_labels_noim.append( |
|
cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]] |
|
) |
|
cur_input_embeds_no_im.append( |
|
cur_input_embeds[ |
|
image_token_indices[i] + 1 : image_token_indices[i + 1] |
|
] |
|
) |
|
|
|
cur_new_input_embeds = [] |
|
cur_new_labels = [] |
|
for i in range(num_images + 1): |
|
cur_new_input_embeds.append(cur_input_embeds_no_im[i]) |
|
cur_new_labels.append(cur_labels_noim[i]) |
|
if i < num_images: |
|
cur_image_features = image_features[cur_image_idx] |
|
cur_image_idx += 1 |
|
cur_new_input_embeds.append(cur_image_features) |
|
cur_new_labels.append( |
|
torch.full( |
|
(cur_image_features.shape[0],), |
|
IGNORE_INDEX, |
|
device=cur_labels.device, |
|
dtype=cur_labels.dtype, |
|
) |
|
) |
|
|
|
cur_new_input_embeds = torch.cat(cur_new_input_embeds) |
|
cur_new_labels = torch.cat(cur_new_labels) |
|
|
|
new_input_embeds.append(cur_new_input_embeds) |
|
new_labels.append(cur_new_labels) |
|
|
|
|
|
tokenizer_model_max_length = getattr( |
|
self.llm.config, "tokenizer_model_max_length", None |
|
) |
|
if tokenizer_model_max_length is not None: |
|
if any(len(x) > tokenizer_model_max_length for x in new_input_embeds): |
|
warnings.warn("Inputs truncated!") |
|
new_input_embeds = [ |
|
x[:tokenizer_model_max_length] for x in new_input_embeds |
|
] |
|
new_labels = [x[:tokenizer_model_max_length] for x in new_labels] |
|
|
|
max_len = max(x.shape[0] for x in new_input_embeds) |
|
|
|
|
|
batch_size = len(new_input_embeds) |
|
|
|
new_input_embeds_padded = [] |
|
new_labels_padded = torch.full( |
|
(batch_size, max_len), |
|
IGNORE_INDEX, |
|
dtype=new_labels[0].dtype, |
|
device=new_labels[0].device, |
|
) |
|
attention_mask = torch.zeros( |
|
(batch_size, max_len), |
|
dtype=attention_mask.dtype, |
|
device=attention_mask.device, |
|
) |
|
position_ids = torch.zeros( |
|
(batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device |
|
) |
|
|
|
for i, (cur_new_embed, cur_new_labels) in enumerate( |
|
zip(new_input_embeds, new_labels) |
|
): |
|
cur_len = cur_new_embed.shape[0] |
|
if getattr(self.llm.config, "tokenizer_padding_side", "right") == "left": |
|
new_input_embeds_padded.append( |
|
torch.cat( |
|
( |
|
torch.zeros( |
|
(max_len - cur_len, cur_new_embed.shape[1]), |
|
dtype=cur_new_embed.dtype, |
|
device=cur_new_embed.device, |
|
), |
|
cur_new_embed, |
|
), |
|
dim=0, |
|
) |
|
) |
|
if cur_len > 0: |
|
new_labels_padded[i, -cur_len:] = cur_new_labels |
|
attention_mask[i, -cur_len:] = True |
|
position_ids[i, -cur_len:] = torch.arange( |
|
0, cur_len, dtype=position_ids.dtype, device=position_ids.device |
|
) |
|
else: |
|
new_input_embeds_padded.append( |
|
torch.cat( |
|
( |
|
cur_new_embed, |
|
torch.zeros( |
|
(max_len - cur_len, cur_new_embed.shape[1]), |
|
dtype=cur_new_embed.dtype, |
|
device=cur_new_embed.device, |
|
), |
|
), |
|
dim=0, |
|
) |
|
) |
|
if cur_len > 0: |
|
new_labels_padded[i, :cur_len] = cur_new_labels |
|
attention_mask[i, :cur_len] = True |
|
position_ids[i, :cur_len] = torch.arange( |
|
0, cur_len, dtype=position_ids.dtype, device=position_ids.device |
|
) |
|
|
|
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if _labels is None: |
|
new_labels = None |
|
else: |
|
new_labels = new_labels_padded |
|
|
|
if _attention_mask is None: |
|
attention_mask = None |
|
else: |
|
attention_mask = attention_mask.to(dtype=_attention_mask.dtype) |
|
|
|
if _position_ids is None: |
|
position_ids = None |
|
|
|
|
|
if PROCESS_GROUP_MANAGER is not None: |
|
return ( |
|
None, |
|
_position_ids, |
|
attention_mask, |
|
past_key_values, |
|
new_input_embeds, |
|
new_labels, |
|
) |
|
|
|
return ( |
|
None, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
new_input_embeds, |
|
new_labels, |
|
) |
|
|
|
def repack_multimodal_data( |
|
self, |
|
input_ids, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
inputs_embeds, |
|
labels, |
|
): |
|
|
|
PROCESS_GROUP_MANAGER = get_pg_manager() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if PROCESS_GROUP_MANAGER is not None: |
|
sp_degree = PROCESS_GROUP_MANAGER.sp_degree |
|
sp_rank = PROCESS_GROUP_MANAGER.sp_rank |
|
sp_group = PROCESS_GROUP_MANAGER.ulysses_pg |
|
bs, shard_seqlen = position_ids.shape |
|
ulysess_seq_len = [ |
|
torch.zeros(1, dtype=torch.int64, device=position_ids.device) |
|
for _ in range(sp_degree) |
|
] |
|
dist.all_gather( |
|
ulysess_seq_len, |
|
torch.tensor(shard_seqlen, device=position_ids.device), |
|
group=sp_group, |
|
) |
|
|
|
|
|
|
|
attention_mask_list = [ |
|
torch.zeros( |
|
(bs, ulysess_seq_len[i]), |
|
dtype=attention_mask.dtype, |
|
device=attention_mask.device, |
|
) |
|
for i in range(sp_degree) |
|
] |
|
dist.all_gather(attention_mask_list, attention_mask, group=sp_group) |
|
effective_seqlen_list = [ |
|
attention_mask_list[i].sum(dim=-1) for i in range(sp_degree) |
|
] |
|
effective_seqlen = torch.stack(effective_seqlen_list, dim=-1) |
|
effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0) |
|
|
|
global_attention_mask_list = [] |
|
for i in range(bs): |
|
global_attention_mask_batch_list = [] |
|
for j in range(sp_degree): |
|
global_attention_mask_batch_list.append( |
|
attention_mask_list[j][i, : effective_seqlen_batch_list[i][j]] |
|
) |
|
global_attention_mask_list.append( |
|
torch.cat(global_attention_mask_batch_list, dim=0) |
|
) |
|
global_attention_mask = torch.nn.utils.rnn.pad_sequence( |
|
global_attention_mask_list, batch_first=True, padding_value=False |
|
) |
|
|
|
|
|
global_seq_len = global_attention_mask.shape[-1] |
|
seq_len_sharded = global_seq_len // sp_degree |
|
start_idx_reshard = seq_len_sharded * sp_rank |
|
end_idx_reshard = ( |
|
start_idx_reshard + seq_len_sharded |
|
if sp_rank < sp_degree - 1 |
|
else global_seq_len |
|
) |
|
|
|
|
|
|
|
|
|
|
|
new_attention_mask = torch.narrow( |
|
global_attention_mask, |
|
1, |
|
start_idx_reshard, |
|
end_idx_reshard - start_idx_reshard, |
|
) |
|
|
|
|
|
position_ids_list = [ |
|
torch.zeros( |
|
(bs, ulysess_seq_len[i]), |
|
dtype=position_ids.dtype, |
|
device=position_ids.device, |
|
) |
|
for i in range(sp_degree) |
|
] |
|
dist.all_gather(position_ids_list, position_ids, group=sp_group) |
|
global_position_ids_list = [] |
|
for i in range(bs): |
|
global_position_ids_batch_list = [] |
|
for j in range(sp_degree): |
|
global_position_ids_batch_list.append( |
|
position_ids_list[j][i, : effective_seqlen_batch_list[i][j]] |
|
) |
|
global_position_ids_list.append( |
|
torch.cat(global_position_ids_batch_list, dim=0) |
|
) |
|
global_position_ids = torch.nn.utils.rnn.pad_sequence( |
|
global_position_ids_list, batch_first=True, padding_value=-1 |
|
) |
|
new_position_ids = torch.narrow( |
|
global_position_ids, |
|
1, |
|
start_idx_reshard, |
|
end_idx_reshard - start_idx_reshard, |
|
) |
|
|
|
|
|
labels_list = [ |
|
torch.zeros( |
|
(bs, ulysess_seq_len[i]), dtype=labels.dtype, device=labels.device |
|
) |
|
for i in range(sp_degree) |
|
] |
|
dist.all_gather(labels_list, labels, group=sp_group) |
|
global_labels_list = [] |
|
for i in range(bs): |
|
global_labels_batch_list = [] |
|
for j in range(sp_degree): |
|
global_labels_batch_list.append( |
|
labels_list[j][i, : effective_seqlen_batch_list[i][j]] |
|
) |
|
global_labels_list.append(torch.cat(global_labels_batch_list, dim=0)) |
|
global_labels = torch.nn.utils.rnn.pad_sequence( |
|
global_labels_list, batch_first=True, padding_value=IGNORE_INDEX |
|
) |
|
new_labels = torch.narrow( |
|
global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ulysess_seq_len_cat = torch.cat(ulysess_seq_len, dim=0) |
|
global_inputs_embeds_list = [] |
|
if sp_rank == 0: |
|
original_start_id = 0 |
|
original_end_id = torch.sum(ulysess_seq_len_cat[: sp_rank + 1]).item() |
|
elif sp_rank == sp_degree - 1: |
|
original_start_id = torch.sum(ulysess_seq_len_cat[:sp_rank]).item() |
|
original_end_id = torch.sum(ulysess_seq_len_cat[: sp_rank + 1]).item() |
|
else: |
|
original_start_id = torch.sum(ulysess_seq_len_cat[:sp_rank]).item() |
|
original_end_id = torch.sum(ulysess_seq_len_cat[: sp_rank + 1]).item() |
|
all_inputs_embeds = torch.zeros( |
|
bs, |
|
torch.sum(ulysess_seq_len_cat), |
|
inputs_embeds.shape[-1], |
|
dtype=inputs_embeds.dtype, |
|
device=inputs_embeds.device, |
|
).contiguous() |
|
all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds |
|
dist.barrier(group=sp_group) |
|
dist.all_reduce(all_inputs_embeds, group=sp_group) |
|
dist.barrier(group=sp_group) |
|
for i in range(bs): |
|
global_inputs_embeds_batch_list = [] |
|
for j in range(sp_degree): |
|
prev_len = torch.sum(ulysess_seq_len_cat[:j]).item() if j > 0 else 0 |
|
start_id = prev_len |
|
end_id = prev_len + effective_seqlen_batch_list[i][j] |
|
global_inputs_embeds_batch_list.append( |
|
all_inputs_embeds[i, start_id:end_id] |
|
) |
|
global_inputs_embeds_list.append( |
|
torch.cat(global_inputs_embeds_batch_list, dim=0) |
|
) |
|
global_inputs_embeds = torch.nn.utils.rnn.pad_sequence( |
|
global_inputs_embeds_list, batch_first=True, padding_value=0 |
|
) |
|
new_inputs_embeds = torch.narrow( |
|
global_inputs_embeds, |
|
1, |
|
start_idx_reshard, |
|
end_idx_reshard - start_idx_reshard, |
|
) |
|
|
|
return ( |
|
None, |
|
new_position_ids, |
|
new_attention_mask, |
|
past_key_values, |
|
new_inputs_embeds, |
|
new_labels, |
|
None, |
|
) |
|
|
|
|
|
|
|
new_inputs_embeds = [] |
|
new_position_ids = [] |
|
new_labels = [] |
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
|
sorted_seqlens_in_batch, sorted_idx = torch.sort( |
|
seqlens_in_batch, descending=True |
|
) |
|
max_seqlen = inputs_embeds.shape[1] |
|
|
|
cur_inputs_embeds = [] |
|
cur_position_ids = [] |
|
cur_labels = [] |
|
cur_batch_len = 0 |
|
for i in range(len(sorted_seqlens_in_batch)): |
|
cur_seqlen = sorted_seqlens_in_batch[i].item() |
|
if cur_seqlen + cur_batch_len <= max_seqlen: |
|
cur_batch_len += cur_seqlen |
|
|
|
|
|
cur_inputs_embeds.append( |
|
inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]] |
|
) |
|
cur_position_ids.append( |
|
torch.arange( |
|
cur_inputs_embeds[-1].shape[0], |
|
device=cur_inputs_embeds[-1].device, |
|
) |
|
) |
|
|
|
|
|
cur_labels.append(labels[sorted_idx[i]][attention_mask[sorted_idx[i]]]) |
|
else: |
|
new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0)) |
|
new_position_ids.append(torch.cat(cur_position_ids, 0)) |
|
new_labels.append(torch.cat(cur_labels, 0)) |
|
|
|
cur_batch_len = cur_seqlen |
|
cur_inputs_embeds = [ |
|
inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]] |
|
] |
|
cur_position_ids = [ |
|
torch.arange( |
|
cur_inputs_embeds[-1].shape[0], |
|
device=cur_inputs_embeds[-1].device, |
|
) |
|
] |
|
cur_labels = [labels[sorted_idx[i]][attention_mask[sorted_idx[i]]]] |
|
|
|
|
|
|
|
if len(cur_inputs_embeds): |
|
new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0)) |
|
new_position_ids.append(torch.cat(cur_position_ids, 0)) |
|
new_labels.append(torch.cat(cur_labels, 0)) |
|
|
|
new_inputs_embeds = torch.nn.utils.rnn.pad_sequence( |
|
new_inputs_embeds, batch_first=True, padding_value=self.llm.pad_token_id |
|
) |
|
|
|
new_position_ids = torch.nn.utils.rnn.pad_sequence( |
|
new_position_ids, batch_first=True, padding_value=-1 |
|
) |
|
|
|
new_labels = torch.nn.utils.rnn.pad_sequence( |
|
new_labels, batch_first=True, padding_value=IGNORE_INDEX |
|
) |
|
|
|
new_attention_mask = new_position_ids.ne(-1) |
|
|
|
assert new_attention_mask.sum() == attention_mask.sum() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return ( |
|
None, |
|
new_position_ids, |
|
new_attention_mask, |
|
past_key_values, |
|
new_inputs_embeds, |
|
new_labels, |
|
sorted_seqlens_in_batch, |
|
) |
|
|
|
def initialize_vision_tokenizer(self, model_args, tokenizer): |
|
if model_args.mm_use_im_patch_token: |
|
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) |
|
self.resize_token_embeddings(len(tokenizer)) |
|
|
|
if model_args.mm_use_im_start_end: |
|
num_new_tokens = tokenizer.add_tokens( |
|
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True |
|
) |
|
self.resize_token_embeddings(len(tokenizer)) |
|
|
|
if num_new_tokens > 0: |
|
input_embeddings = self.get_input_embeddings().weight.data |
|
output_embeddings = self.get_output_embeddings().weight.data |
|
|
|
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( |
|
dim=0, keepdim=True |
|
) |
|
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( |
|
dim=0, keepdim=True |
|
) |
|
|
|
input_embeddings[-num_new_tokens:] = input_embeddings_avg |
|
output_embeddings[-num_new_tokens:] = output_embeddings_avg |
|
|
|
if model_args.pretrain_mm_mlp_adapter: |
|
mm_projector_weights = torch.load( |
|
model_args.pretrain_mm_mlp_adapter, map_location="cpu" |
|
) |
|
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] |
|
assert num_new_tokens == 2 |
|
if input_embeddings.shape == embed_tokens_weight.shape: |
|
input_embeddings[-num_new_tokens:] = embed_tokens_weight[ |
|
-num_new_tokens: |
|
] |
|
elif embed_tokens_weight.shape[0] == num_new_tokens: |
|
input_embeddings[-num_new_tokens:] = embed_tokens_weight |
|
else: |
|
                        raise ValueError(
                            f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. "
                            f"Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}."
                        )
|
elif model_args.mm_use_im_patch_token: |
|
if model_args.mm_projector: |
|
for p in self.get_input_embeddings().parameters(): |
|
p.requires_grad = False |
|
for p in self.get_output_embeddings().parameters(): |
|
p.requires_grad = False |
|
|