import os
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine import print_log
from torch.nn import CrossEntropyLoss
from transformers import (AutoConfig, AutoModel, AutoTokenizer,
                          BitsAndBytesConfig, GenerationConfig,
                          LlamaForCausalLM, LlamaTokenizer)
from transformers.modeling_outputs import CausalLMOutputWithPast

from xtuner.model import InternVL_V1_5
from xtuner.model.utils import (find_all_linear_names,
                                get_peft_model_state_dict,
                                guess_load_checkpoint,
                                make_inputs_require_grad)
def get_rank_and_world_size():
    """Read the distributed rank and world size from the environment
    (set by launchers such as torchrun), defaulting to a single process."""
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    return rank, world_size
# Build a device_map that splits a large model across the visible GPUs,
# reserving part of the first GPU for the vision encoder.
def split_model(model_name):
import math
device_map = {}
num_gpus = torch.cuda.device_count()
rank, world_size = get_rank_and_world_size()
num_gpus = num_gpus // world_size
num_layers = {'InternVL2-8B': 32, 'InternVL2-26B': 48,
'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]
# Since the first GPU will be used for ViT, treat it as 0.8 GPU.
num_layers_per_gpu = math.ceil(num_layers / (num_gpus - 0.2))
num_layers_per_gpu = [num_layers_per_gpu] * num_gpus
num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.8)
layer_cnt = 0
for i, num_layer in enumerate(num_layers_per_gpu):
for j in range(num_layer):
device_map[f'language_model.model.layers.{layer_cnt}'] = rank + world_size * i
layer_cnt += 1
device_map['vision_model'] = rank
device_map['mlp1'] = rank
device_map['language_model.model.tok_embeddings'] = rank
device_map['language_model.model.embed_tokens'] = rank
device_map['language_model.output'] = rank
device_map['language_model.model.norm'] = rank
device_map['language_model.lm_head'] = rank
device_map[f'language_model.model.layers.{num_layers - 1}'] = rank
return device_map
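
# A minimal sketch of how the device_map is consumed, assuming a
# single-process run (RANK=0, WORLD_SIZE=1) on 4 visible GPUs with
# 'InternVL2-26B' (48 LLM layers): GPU 0 keeps the vision tower, the 'mlp1'
# projector, the embeddings, final norm/head and roughly the first 11 layers,
# while the remaining layers are spread over GPUs 1-3.
#
#   device_map = split_model('InternVL2-26B')
#   model = AutoModel.from_pretrained(
#       'OpenGVLab/InternVL2-26B',      # assumed repo id / local path
#       torch_dtype=torch.bfloat16,
#       trust_remote_code=True,
#       device_map=device_map).eval()
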
class InternVL_Slowfast(InternVL_V1_5):
def __init__(self,
model_path,
freeze_llm=False,
freeze_visual_encoder=False,
llm_lora=None,
visual_encoder_lora=None,
quantization_vit=False,
quantization_llm=False,
pretrained_pth=None,
special_tokens=None,
model_split=False,
):
        print_log('Starting to load InternVL_V1_5 model.', logger='current')
        # Intentionally skip InternVL_V1_5.__init__: this class builds the
        # wrapped HF model itself, so only the grandparent initialiser runs.
        super(InternVL_V1_5, self).__init__()
self.freeze_llm = freeze_llm
self.freeze_visual_encoder = freeze_visual_encoder
self.use_llm_lora = llm_lora is not None
self.use_visual_encoder_lora = visual_encoder_lora is not None
self.quantization_vit = quantization_vit
self.quantization_llm = quantization_llm
        if quantization_vit:
            assert visual_encoder_lora is not None
        if quantization_llm:
            assert llm_lora is not None
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if config.llm_config.model_type == 'internlm2':
config.llm_config.attn_implementation = 'flash_attention_2'
else:
config.llm_config._attn_implementation = 'flash_attention_2'
if quantization_vit is False and quantization_llm is False:
quantization = None
else:
llm_int8_skip_modules = ['mlp1']
if quantization_llm and not quantization_vit:
llm_int8_skip_modules.append('vision_model')
if quantization_vit and not quantization_llm:
llm_int8_skip_modules.append('language_model')
quantization_config = dict(
type=BitsAndBytesConfig,
llm_int8_skip_modules=llm_int8_skip_modules,
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4')
quantization_clazz = quantization_config.pop('type')
quantization = quantization_clazz(**quantization_config)
if model_split:
device_map = split_model("InternVL2-26B")
self.device = 'cuda'
self.model = AutoModel.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map=device_map).eval()
else:
self.model = AutoModel.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
quantization_config=quantization,
config=config,
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True)
self.tokenizer = tokenizer
if special_tokens is not None:
self._add_special_tokens(special_tokens)
img_context_token_id = tokenizer.convert_tokens_to_ids('<IMG_CONTEXT>')
self.model.img_context_token_id = img_context_token_id
if self.freeze_llm:
self.model.language_model.requires_grad_(False)
if self.freeze_visual_encoder:
self.model.vision_model.requires_grad_(False)
if hasattr(self.model.language_model, 'enable_input_require_grads'):
self.model.language_model.enable_input_require_grads()
else:
self.model.language_model.get_input_embeddings(
).register_forward_hook(make_inputs_require_grad)
self.gradient_checkpointing_enable()
if self.use_llm_lora:
self._prepare_llm_for_lora(llm_lora)
if self.use_visual_encoder_lora:
self._prepare_visual_encoder_for_lora(visual_encoder_lora)
if pretrained_pth is not None:
pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Loaded pretrained weights from {pretrained_pth}')
self._count = 0
print_log(self, logger='current')
print_log('InternVL_V1_5 construction is complete', logger='current')
self.transfer_to_hf = False
def _add_special_tokens(self, special_tokens):
num_new_tokens = self.tokenizer.add_tokens(
special_tokens, special_tokens=True)
if num_new_tokens > 0:
self.model.language_model.resize_token_embeddings(len(self.tokenizer))
def _post_init(self, fast_pool_size=4, fast_pool=True):
if fast_pool:
self.fast_pool = nn.AdaptiveAvgPool2d((fast_pool_size, fast_pool_size))
return
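
    # A minimal sketch of what fast_pool does to the "fast" frame features
    # (shapes assumed for illustration; the real grid size depends on the
    # ViT / pixel-unshuffle settings):
    #   feats = torch.randn(8, 256, 1024)                      # (n_fast, hw, c)
    #   g = int(feats.shape[1] ** 0.5)                          # 16
    #   x = feats.reshape(8, g, g, 1024).permute(0, 3, 1, 2)    # (n, c, 16, 16)
    #   x = nn.AdaptiveAvgPool2d((4, 4))(x).flatten(2)          # (n, c, 16)
    #   x = x.permute(0, 2, 1)             # (n, 16, c): 16 tokens per fast frame
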
def forward(self, data, data_samples=None, mode='loss', fast_token_idx=None):
if 'fast_pixel_values' in data.keys():
assert fast_token_idx is not None
fast_pixel_values = data['fast_pixel_values']
if type(fast_pixel_values) is list or fast_pixel_values.ndim == 5:
if type(fast_pixel_values) is list:
fast_pixel_values = [
x.unsqueeze(0) if x.ndim == 3 else x for x in fast_pixel_values
]
# b*n, c, h, w
fast_concat_images = torch.cat(
[image.to(self.model.vision_model.dtype) for image in fast_pixel_values], dim=0)
else:
raise NotImplementedError()
else:
fast_pixel_values = None
fast_concat_images = None
pixel_values = data['pixel_values']
if type(pixel_values) is list or pixel_values.ndim == 5:
if type(pixel_values) is list:
pixel_values = [
x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
]
# b*n, c, h, w
concat_images = torch.cat(
[image.to(self.model.vision_model.dtype) for image in pixel_values], dim=0)
else:
raise NotImplementedError()
input_ids = data['input_ids']
position_ids = data['position_ids']
attention_mask = data['attention_mask']
        # Tiles whose pixel values sum to zero are padding tiles from
        # pure-text samples; flag them so their ViT features can be dropped.
        image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
        image_flags = image_flags.long()
labels = data['labels']
use_cache = False
if 'vp_overall_mask' not in data.keys():
vp_overall_mask = None
else:
vp_overall_mask = data['vp_overall_mask']
if 'prompt_masks' in data.keys():
prompt_masks = data['prompt_masks']
else:
prompt_masks = None
outputs = self._llm_forward(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
image_flags=image_flags,
pixel_values=concat_images,
labels=labels,
use_cache=use_cache,
output_hidden_states=True,
fast_pixel_values=fast_concat_images,
fast_token_idx=fast_token_idx,
vp_overall_mask=vp_overall_mask,
prompt_masks=prompt_masks,
)
return outputs
def _llm_forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
image_flags: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
fast_pixel_values=None,
fast_token_idx=None,
vp_overall_mask=None,
prompt_masks=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None \
else self.model.config.use_return_dict
image_flags = image_flags.squeeze(-1)
        # Clone the embeddings so the in-place writes below do not modify the
        # embedding weights themselves.
        input_embeds = self.model.language_model.get_input_embeddings()(
            input_ids).clone()
if fast_pixel_values is not None:
n_fast_images = fast_pixel_values.shape[0]
whole_pixel_values = torch.cat([fast_pixel_values, pixel_values], dim=0)
vit_embeds = self.model.extract_feature(whole_pixel_values)
vit_embeds = vit_embeds.to(input_embeds.dtype) # FIXME: why vit_embeds is float16?
fast_vit_embeds = vit_embeds[:n_fast_images] # (n_fast_images, hw, c)
_size = int(fast_vit_embeds.shape[1] ** 0.5)
fast_vit_embeds = fast_vit_embeds.reshape(fast_vit_embeds.shape[0], _size, _size, fast_vit_embeds.shape[-1])
# pooling
fast_vit_embeds = fast_vit_embeds.permute(0, 3, 1, 2) # (n_fast_images, c, h, w)
fast_vit_embeds = self.fast_pool(fast_vit_embeds).flatten(2) # (n_fast_images, c, hw)
fast_vit_embeds = fast_vit_embeds.permute(0, 2, 1)
vit_embeds = vit_embeds[n_fast_images:]
else:
vit_embeds = self.model.extract_feature(pixel_values)
vit_embeds = vit_embeds.to(input_embeds.dtype) # FIXME: why vit_embeds is float16?
fast_vit_embeds = None
vit_embeds = vit_embeds[image_flags == 1]
vit_batch_size = pixel_values.shape[0]
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
self._count += 1
if vp_overall_mask is not None and prompt_masks is not None:
vp_embeds = []
vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]
vp_overall_mask = vp_overall_mask[image_flags == 1]
overall_tile_vit_embeds = vit_embeds[vp_overall_mask] # (n_img, hw, c)
i_vp_img = 0
for i_img in range(len(vit_embeds)):
vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
if vp_overall_mask[i_img]:
tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C) # (hw, C)
objects_prompt_masks = prompt_masks[i_vp_img]
n_obj = len(objects_prompt_masks)
tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
i_vp_img += 1
vp_embeds = torch.cat(vp_embeds, dim=0)
else:
vp_embeds = None
input_ids = input_ids.reshape(B * N)
selected = (input_ids == self.model.img_context_token_id)
if vp_embeds is None:
try:
input_embeds[selected] = vit_embeds.reshape(-1, C)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, C)
print(f'warning: {e}, input_embeds[selected].shape='
f'{input_embeds[selected].shape}, '
f'vit_embeds.shape={vit_embeds.shape}')
n_token = selected.sum()
if n_token > len(vit_embeds):
print(f"Wrong !!! {n_token} image tokens in text but only {len(vit_embeds)} vit embeds !!!")
expand_ratio = n_token // len(vit_embeds) + 1
vit_embeds = torch.cat([vit_embeds] * expand_ratio, dim=0)
input_embeds[selected] = vit_embeds[:n_token]
else:
try:
input_embeds[selected] = vp_embeds.reshape(-1, C)
except Exception as e:
vp_embeds = vp_embeds.reshape(-1, C)
print(f'warning: {e}, input_embeds[selected].shape='
f'{input_embeds[selected].shape}, '
f'vp_embeds.shape={vp_embeds.shape}')
n_token = selected.sum()
if n_token > len(vp_embeds):
print(f"Wrong !!! {n_token} image tokens in text but only {len(vp_embeds)} vit embeds !!!")
expand_ratio = n_token // len(vp_embeds) + 1
vp_embeds = torch.cat([vp_embeds] * expand_ratio, dim=0)
input_embeds[selected] = vp_embeds[:n_token]
if fast_vit_embeds is not None:
selected = (input_ids == fast_token_idx)
selected_tot = selected.sum().item()
            if selected_tot > fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1]:
                assert selected_tot % (fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1]) == 0
                repeat_times = selected_tot // (fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1])
                fast_vit_embeds = fast_vit_embeds.repeat(repeat_times, 1, 1)
try:
input_embeds[selected] = fast_vit_embeds.reshape(-1, C)
except Exception as e:
fast_vit_embeds = fast_vit_embeds.reshape(-1, C)
print(f'warning: {e}, input_embeds[fast_selected].shape='
f'{input_embeds[selected].shape}, '
f'fast_vit_embeds.shape={fast_vit_embeds.shape}')
n_token = selected.sum()
input_embeds[selected] = fast_vit_embeds[:n_token]
input_embeds = input_embeds.reshape(B, N, C)
outputs = self.model.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = outputs.logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(
-1, self.model.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
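            # A minimal sketch of the causal shift above: with logits for
            # positions [l0, l1, l2, l3] and labels [t0, t1, t2, t3], the loss
            # compares [l0, l1, l2] against [t1, t2, t3]; positions labelled
            # -100 (prompt / image tokens) are ignored by CrossEntropyLoss.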
if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def generate(
self,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
visual_features: Optional[torch.FloatTensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
fast_token_idx=None,
fast_pixel_values=None,
prompt_masks=None,
vp_overall_mask=None,
**generate_kwargs,
) -> torch.LongTensor:
device = self.model.device
assert self.model.img_context_token_id is not None
if fast_pixel_values is not None:
assert fast_token_idx is not None
if type(fast_pixel_values) is list or fast_pixel_values.ndim == 5:
if type(fast_pixel_values) is list:
fast_pixel_values = [
x.unsqueeze(0) if x.ndim == 3 else x for x in fast_pixel_values
]
# b*n, c, h, w
fast_pixel_values = torch.cat(
[image.to(self.model.vision_model.dtype) for image in fast_pixel_values], dim=0)
if pixel_values is not None:
if visual_features is not None:
vit_embeds = visual_features
else:
if type(pixel_values) is list or pixel_values.ndim == 5:
if type(pixel_values) is list:
pixel_values = [
x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
]
# b*n, c, h, w
pixel_values = torch.cat(
[image.to(self.model.vision_model.dtype) for image in pixel_values], dim=0)
if fast_pixel_values is not None:
n_fast_images = fast_pixel_values.shape[0]
whole_pixel_values = torch.cat([fast_pixel_values, pixel_values], dim=0)
vit_embeds = self.model.extract_feature(whole_pixel_values.to(device))
# vit_embeds = vit_embeds.to(input_embeds.dtype) # FIXME: why vit_embeds is float16?
fast_vit_embeds = vit_embeds[:n_fast_images] # (n_fast_images, hw, c)
_size = int(fast_vit_embeds.shape[1] ** 0.5)
fast_vit_embeds = fast_vit_embeds.reshape(fast_vit_embeds.shape[0], _size, _size,
fast_vit_embeds.shape[-1])
# pooling
fast_vit_embeds = fast_vit_embeds.permute(0, 3, 1, 2) # (n_fast_images, c, h, w)
fast_vit_embeds = self.fast_pool(fast_vit_embeds).flatten(2) # (n_fast_images, c, hw)
fast_vit_embeds = fast_vit_embeds.permute(0, 2, 1)
vit_embeds = vit_embeds[n_fast_images:]
else:
fast_vit_embeds = None
vit_embeds = self.model.extract_feature(pixel_values.to(device))
image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
image_flags = image_flags.long()
vit_embeds = vit_embeds[image_flags == 1]
input_embeds = self.model.language_model.get_input_embeddings()(input_ids.to(device))
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
if vp_overall_mask is not None and prompt_masks is not None:
vp_embeds = []
vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]
vp_overall_mask = vp_overall_mask[image_flags == 1]
overall_tile_vit_embeds = vit_embeds[vp_overall_mask] # (n_img, hw, c)
i_vp_img = 0
for i_img in range(len(vit_embeds)):
vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
if vp_overall_mask[i_img]:
tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C) # (hw, C)
objects_prompt_masks = prompt_masks[i_vp_img]
n_obj = len(objects_prompt_masks)
tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
i_vp_img += 1
vp_embeds = torch.cat(vp_embeds, dim=0)
else:
vp_embeds = None
input_ids = input_ids.reshape(B * N)
selected = (input_ids == self.model.img_context_token_id)
assert selected.sum() != 0
if vp_embeds is None:
input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
else:
if len(input_embeds[selected]) != len(vp_embeds.reshape(-1, C)):
print("Shape mismatch, selected is {}, vp embeds is {} !!!"\
.format(len(input_embeds[selected]), len(vp_embeds.reshape(-1, C))))
min_tokens = min(len(input_embeds[selected]), len(vp_embeds.reshape(-1, C)))
input_embeds[selected][:min_tokens] = vp_embeds.reshape(-1, C)[:min_tokens].to(input_embeds.device)
else:
input_embeds[selected] = vp_embeds.reshape(-1, C).to(input_embeds.device)
if fast_vit_embeds is not None:
selected = (input_ids == fast_token_idx)
# FIXME, add repeat.
assert selected.sum() != 0
input_embeds[selected] = fast_vit_embeds.reshape(-1, C).to(input_embeds.device)
input_embeds = input_embeds.reshape(B, N, C)
else:
input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
outputs = self.model.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask.to(device),
generation_config=generation_config,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=True,
**generate_kwargs,
)
return outputs
    def state_dict(self, *args, **kwargs):
        if self.transfer_to_hf:
            # Skip InternVL_V1_5's overridden state_dict and use the default
            # implementation when exporting weights in HuggingFace format.
            state_dict = super(InternVL_V1_5, self).state_dict(*args, **kwargs)
            return state_dict
        else:
            return super().state_dict(*args, **kwargs)
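

# A minimal usage sketch (hypothetical paths and token names; the actual
# prompt format, special tokens and fast_token_idx come from the surrounding
# training / inference code, not from this file):
#
#   model = InternVL_Slowfast(
#       model_path='OpenGVLab/InternVL2-26B',    # assumed repo id
#       special_tokens=['<FAST_IMG_CONTEXT>'],   # hypothetical token
#       model_split=True)
#   model._post_init(fast_pool_size=4)
#   generation_config = GenerationConfig(max_new_tokens=128, do_sample=False)
#   out_ids = model.generate(
#       pixel_values=pixel_values,       # (n_tiles, 3, H, W), preprocessed tiles
#       input_ids=input_ids,             # contains <IMG_CONTEXT> placeholder tokens
#       attention_mask=attention_mask,
#       generation_config=generation_config)
#   print(model.tokenizer.batch_decode(out_ids, skip_special_tokens=True))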