import os
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine import print_log
from torch.nn import CrossEntropyLoss
from transformers import (AutoConfig, AutoModel, AutoTokenizer,
                          BitsAndBytesConfig, GenerationConfig,
                          LlamaForCausalLM, LlamaTokenizer)
from transformers.modeling_outputs import CausalLMOutputWithPast

from xtuner.model import InternVL_V1_5
from xtuner.model.utils import (find_all_linear_names, get_peft_model_state_dict,
                                guess_load_checkpoint, make_inputs_require_grad)


# RANK / WORLD_SIZE are set by torch.distributed launchers such as torchrun;
# single-process runs fall back to rank 0 / world size 1.
def get_rank_and_world_size():
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    return rank, world_size


# Build a device_map that shards the language-model layers of a large
# InternVL2 checkpoint across the GPUs visible to this rank.
def split_model(model_name):
    import math

    device_map = {}
    num_gpus = torch.cuda.device_count()
    rank, world_size = get_rank_and_world_size()
    num_gpus = num_gpus // world_size

    num_layers = {'InternVL2-8B': 32, 'InternVL2-26B': 48,
                  'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]
    # Since the first GPU also hosts the ViT, treat it as 0.8 of a GPU.
    num_layers_per_gpu = math.ceil(num_layers / (num_gpus - 0.2))
    num_layers_per_gpu = [num_layers_per_gpu] * num_gpus
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.8)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = rank + world_size * i
            layer_cnt += 1
    # Keep the vision tower, projector, embeddings, final norm/head and the
    # last LLM layer on this rank's first GPU.
    device_map['vision_model'] = rank
    device_map['mlp1'] = rank
    device_map['language_model.model.tok_embeddings'] = rank
    device_map['language_model.model.embed_tokens'] = rank
    device_map['language_model.output'] = rank
    device_map['language_model.model.norm'] = rank
    device_map['language_model.lm_head'] = rank
    device_map[f'language_model.model.layers.{num_layers - 1}'] = rank
    return device_map
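
# Usage sketch: when all visible GPUs belong to a single rank, the returned
# mapping can be passed straight to `from_pretrained`, e.g.
#   device_map = split_model('InternVL2-26B')
#   model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
#                                     trust_remote_code=True,
#                                     device_map=device_map).eval()
# which is what `InternVL_Slowfast` does below when `model_split=True`.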


class InternVL_Slowfast(InternVL_V1_5):

    def __init__(self,
                 model_path,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 quantization_vit=False,
                 quantization_llm=False,
                 pretrained_pth=None,
                 special_tokens=None,
                 model_split=False,
                 ):
        print_log('Start to load InternVL_V1_5 model.', logger='current')
        # Skip InternVL_V1_5.__init__ and initialise the grandparent directly,
        # since this class builds the inner model itself.
        super(InternVL_V1_5, self).__init__()
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None
        self.quantization_vit = quantization_vit
        self.quantization_llm = quantization_llm
        if quantization_vit:
            assert visual_encoder_lora is not None
        if quantization_llm:
            assert llm_lora is not None

        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        if config.llm_config.model_type == 'internlm2':
            config.llm_config.attn_implementation = 'flash_attention_2'
        else:
            config.llm_config._attn_implementation = 'flash_attention_2'

        if quantization_vit is False and quantization_llm is False:
            quantization = None
        else:
            llm_int8_skip_modules = ['mlp1']
            if quantization_llm and not quantization_vit:
                llm_int8_skip_modules.append('vision_model')
            if quantization_vit and not quantization_llm:
                llm_int8_skip_modules.append('language_model')
            quantization_config = dict(
                type=BitsAndBytesConfig,
                llm_int8_skip_modules=llm_int8_skip_modules,
                load_in_4bit=True,
                load_in_8bit=False,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4')
            quantization_clazz = quantization_config.pop('type')
            quantization = quantization_clazz(**quantization_config)

        if model_split:
            # NOTE: the layer split is hard-coded for the InternVL2-26B layout.
            device_map = split_model('InternVL2-26B')
            self.device = 'cuda'
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=device_map).eval()
        else:
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                quantization_config=quantization,
                config=config,
                trust_remote_code=True)

        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)
        self.tokenizer = tokenizer

        if special_tokens is not None:
            self._add_special_tokens(special_tokens)

        img_context_token_id = tokenizer.convert_tokens_to_ids('<IMG_CONTEXT>')
        self.model.img_context_token_id = img_context_token_id

        if self.freeze_llm:
            self.model.language_model.requires_grad_(False)
        if self.freeze_visual_encoder:
            self.model.vision_model.requires_grad_(False)

        if hasattr(self.model.language_model, 'enable_input_require_grads'):
            self.model.language_model.enable_input_require_grads()
        else:
            self.model.language_model.get_input_embeddings(
            ).register_forward_hook(make_inputs_require_grad)

        self.gradient_checkpointing_enable()

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora)
        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(visual_encoder_lora)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Loaded pretrained weights from {pretrained_pth}')

        self._count = 0
        print_log(self, logger='current')
        print_log('InternVL_V1_5 construction is complete', logger='current')
        self.transfer_to_hf = False

    def _add_special_tokens(self, special_tokens):
        num_new_tokens = self.tokenizer.add_tokens(
            special_tokens, special_tokens=True)
        if num_new_tokens > 0:
            self.model.language_model.resize_token_embeddings(len(self.tokenizer))

    def _post_init(self, fast_pool_size=4, fast_pool=True):
        if fast_pool:
            # Adaptive pool that shrinks each fast-frame ViT feature map to
            # fast_pool_size x fast_pool_size tokens.
            self.fast_pool = nn.AdaptiveAvgPool2d((fast_pool_size, fast_pool_size))
        return
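
    # With InternVL2-style 448x448 tiles, `extract_feature` typically yields a
    # 16x16 grid (256 tokens) per frame, so the default 4x4 pool keeps only 16
    # tokens per fast frame before they are spliced in at `fast_token_idx`.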

    def forward(self, data, data_samples=None, mode='loss', fast_token_idx=None):
        if 'fast_pixel_values' in data.keys():
            assert fast_token_idx is not None
            fast_pixel_values = data['fast_pixel_values']
            if type(fast_pixel_values) is list or fast_pixel_values.ndim == 5:
                if type(fast_pixel_values) is list:
                    fast_pixel_values = [
                        x.unsqueeze(0) if x.ndim == 3 else x for x in fast_pixel_values
                    ]
                # (b*n, c, h, w)
                fast_concat_images = torch.cat(
                    [image.to(self.model.vision_model.dtype) for image in fast_pixel_values], dim=0)
            else:
                raise NotImplementedError()
        else:
            fast_pixel_values = None
            fast_concat_images = None

        pixel_values = data['pixel_values']
        if type(pixel_values) is list or pixel_values.ndim == 5:
            if type(pixel_values) is list:
                pixel_values = [
                    x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                ]
            # (b*n, c, h, w)
            concat_images = torch.cat(
                [image.to(self.model.vision_model.dtype) for image in pixel_values], dim=0)
        else:
            raise NotImplementedError()

        input_ids = data['input_ids']
        position_ids = data['position_ids']
        attention_mask = data['attention_mask']
        # Tiles whose pixel values sum to 0 are padding for text-only samples.
        image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
        image_flags = image_flags.long()

        labels = data['labels']
        use_cache = False

        if 'vp_overall_mask' not in data.keys():
            vp_overall_mask = None
        else:
            vp_overall_mask = data['vp_overall_mask']

        if 'prompt_masks' in data.keys():
            prompt_masks = data['prompt_masks']
        else:
            prompt_masks = None

        outputs = self._llm_forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            image_flags=image_flags,
            pixel_values=concat_images,
            labels=labels,
            use_cache=use_cache,
            output_hidden_states=True,
            fast_pixel_values=fast_concat_images,
            fast_token_idx=fast_token_idx,
            vp_overall_mask=vp_overall_mask,
            prompt_masks=prompt_masks,
        )
        return outputs
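
    # `_llm_forward` scatters visual features into the token embeddings:
    # every <IMG_CONTEXT> position receives one "slow" ViT token (or a
    # visual-prompt token when `vp_overall_mask`/`prompt_masks` are given),
    # and every `fast_token_idx` position receives one pooled "fast" token,
    # before the sequence is passed to the language model.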

    def _llm_forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            fast_pixel_values=None,
            fast_token_idx=None,
            vp_overall_mask=None,
            prompt_masks=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None \
            else self.model.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        # Clone so that the in-place assignments below do not write back into
        # the embedding layer's output.
        input_embeds = self.model.language_model.get_input_embeddings()(
            input_ids).clone()

        if fast_pixel_values is not None:
            # Run fast frames and slow tiles through the ViT in one batch,
            # then split the result again.
            n_fast_images = fast_pixel_values.shape[0]
            whole_pixel_values = torch.cat([fast_pixel_values, pixel_values], dim=0)
            vit_embeds = self.model.extract_feature(whole_pixel_values)
            vit_embeds = vit_embeds.to(input_embeds.dtype)  # FIXME: why is vit_embeds float16?
            fast_vit_embeds = vit_embeds[:n_fast_images]  # (n_fast_images, hw, c)
            _size = int(fast_vit_embeds.shape[1] ** 0.5)
            fast_vit_embeds = fast_vit_embeds.reshape(
                fast_vit_embeds.shape[0], _size, _size, fast_vit_embeds.shape[-1])
            # Pool each fast frame down to fast_pool_size x fast_pool_size tokens.
            fast_vit_embeds = fast_vit_embeds.permute(0, 3, 1, 2)  # (n_fast_images, c, h, w)
            fast_vit_embeds = self.fast_pool(fast_vit_embeds).flatten(2)  # (n_fast_images, c, hw)
            fast_vit_embeds = fast_vit_embeds.permute(0, 2, 1)
            vit_embeds = vit_embeds[n_fast_images:]
        else:
            vit_embeds = self.model.extract_feature(pixel_values)
            vit_embeds = vit_embeds.to(input_embeds.dtype)  # FIXME: why is vit_embeds float16?
            fast_vit_embeds = None

        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        self._count += 1

        if vp_overall_mask is not None and prompt_masks is not None:
            # Visual-prompt mode: after each image's tile features, append the
            # overall-tile features gathered under each object's prompt mask.
            vp_embeds = []
            vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
            prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]

            vp_overall_mask = vp_overall_mask[image_flags == 1]
            overall_tile_vit_embeds = vit_embeds[vp_overall_mask]  # (n_img, hw, c)

            i_vp_img = 0
            for i_img in range(len(vit_embeds)):
                vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
                if vp_overall_mask[i_img]:
                    tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C)  # (hw, C)
                    objects_prompt_masks = prompt_masks[i_vp_img]
                    n_obj = len(objects_prompt_masks)
                    tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
                    objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                    vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                    i_vp_img += 1
            vp_embeds = torch.cat(vp_embeds, dim=0)
        else:
            vp_embeds = None

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.model.img_context_token_id)
        if vp_embeds is None:
            try:
                input_embeds[selected] = vit_embeds.reshape(-1, C)
            except Exception as e:
                vit_embeds = vit_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'vit_embeds.shape={vit_embeds.shape}')
                n_token = selected.sum()
                if n_token > len(vit_embeds):
                    print(f'Warning: {n_token} image tokens in the text but only '
                          f'{len(vit_embeds)} ViT embeddings!')
                    expand_ratio = n_token // len(vit_embeds) + 1
                    vit_embeds = torch.cat([vit_embeds] * expand_ratio, dim=0)
                input_embeds[selected] = vit_embeds[:n_token]
        else:
            try:
                input_embeds[selected] = vp_embeds.reshape(-1, C)
            except Exception as e:
                vp_embeds = vp_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'vp_embeds.shape={vp_embeds.shape}')
                n_token = selected.sum()
                if n_token > len(vp_embeds):
                    print(f'Warning: {n_token} image tokens in the text but only '
                          f'{len(vp_embeds)} visual-prompt embeddings!')
                    expand_ratio = n_token // len(vp_embeds) + 1
                    vp_embeds = torch.cat([vp_embeds] * expand_ratio, dim=0)
                input_embeds[selected] = vp_embeds[:n_token]

        if fast_vit_embeds is not None:
            selected = (input_ids == fast_token_idx)
            selected_tot = selected.sum().item()
            if selected_tot > fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1]:
                assert selected_tot % (fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1]) == 0
                repeat_times = selected_tot / (fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1])
                fast_vit_embeds = fast_vit_embeds.repeat(int(repeat_times), 1, 1)
            try:
                input_embeds[selected] = fast_vit_embeds.reshape(-1, C)
            except Exception as e:
                fast_vit_embeds = fast_vit_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[fast_selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'fast_vit_embeds.shape={fast_vit_embeds.shape}')
                n_token = selected.sum()
                input_embeds[selected] = fast_vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.model.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(
                -1, self.model.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
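
    # `generate` mirrors the embedding injection done in `_llm_forward`
    # (slow tiles into <IMG_CONTEXT> positions, pooled fast frames into
    # `fast_token_idx` positions, optional visual-prompt features), then
    # delegates decoding to the wrapped language model's `generate`.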

    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            fast_token_idx=None,
            fast_pixel_values=None,
            prompt_masks=None,
            vp_overall_mask=None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        device = self.model.device
        assert self.model.img_context_token_id is not None

        if fast_pixel_values is not None:
            assert fast_token_idx is not None
            if type(fast_pixel_values) is list or fast_pixel_values.ndim == 5:
                if type(fast_pixel_values) is list:
                    fast_pixel_values = [
                        x.unsqueeze(0) if x.ndim == 3 else x for x in fast_pixel_values
                    ]
                # (b*n, c, h, w)
                fast_pixel_values = torch.cat(
                    [image.to(self.model.vision_model.dtype) for image in fast_pixel_values], dim=0)

        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
                # No fast pathway when precomputed features are supplied.
                fast_vit_embeds = None
            else:
                if type(pixel_values) is list or pixel_values.ndim == 5:
                    if type(pixel_values) is list:
                        pixel_values = [
                            x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                        ]
                    # (b*n, c, h, w)
                    pixel_values = torch.cat(
                        [image.to(self.model.vision_model.dtype) for image in pixel_values], dim=0)

                if fast_pixel_values is not None:
                    n_fast_images = fast_pixel_values.shape[0]
                    whole_pixel_values = torch.cat([fast_pixel_values, pixel_values], dim=0)
                    vit_embeds = self.model.extract_feature(whole_pixel_values.to(device))
                    fast_vit_embeds = vit_embeds[:n_fast_images]  # (n_fast_images, hw, c)
                    _size = int(fast_vit_embeds.shape[1] ** 0.5)
                    fast_vit_embeds = fast_vit_embeds.reshape(
                        fast_vit_embeds.shape[0], _size, _size, fast_vit_embeds.shape[-1])
                    # Pool each fast frame down to fast_pool_size x fast_pool_size tokens.
                    fast_vit_embeds = fast_vit_embeds.permute(0, 3, 1, 2)  # (n_fast_images, c, h, w)
                    fast_vit_embeds = self.fast_pool(fast_vit_embeds).flatten(2)  # (n_fast_images, c, hw)
                    fast_vit_embeds = fast_vit_embeds.permute(0, 2, 1)
                    vit_embeds = vit_embeds[n_fast_images:]
                else:
                    fast_vit_embeds = None
                    vit_embeds = self.model.extract_feature(pixel_values.to(device))

            image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
            image_flags = image_flags.long()
            vit_embeds = vit_embeds[image_flags == 1]

            input_embeds = self.model.language_model.get_input_embeddings()(input_ids.to(device))
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            if vp_overall_mask is not None and prompt_masks is not None:
                # Visual-prompt mode, mirroring _llm_forward.
                vp_embeds = []
                vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
                prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]

                vp_overall_mask = vp_overall_mask[image_flags == 1]
                overall_tile_vit_embeds = vit_embeds[vp_overall_mask]  # (n_img, hw, c)

                i_vp_img = 0
                for i_img in range(len(vit_embeds)):
                    vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
                    if vp_overall_mask[i_img]:
                        tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C)  # (hw, C)
                        objects_prompt_masks = prompt_masks[i_vp_img]
                        n_obj = len(objects_prompt_masks)
                        tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
                        objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                        vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                        i_vp_img += 1
                vp_embeds = torch.cat(vp_embeds, dim=0)
            else:
                vp_embeds = None

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.model.img_context_token_id)
            assert selected.sum() != 0
            if vp_embeds is None:
                input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            else:
                if len(input_embeds[selected]) != len(vp_embeds.reshape(-1, C)):
                    print('Shape mismatch: selected has {} tokens but vp_embeds has {}!'.format(
                        len(input_embeds[selected]), len(vp_embeds.reshape(-1, C))))
                    min_tokens = min(len(input_embeds[selected]), len(vp_embeds.reshape(-1, C)))
                    input_embeds[selected][:min_tokens] = vp_embeds.reshape(-1, C)[:min_tokens].to(input_embeds.device)
                else:
                    input_embeds[selected] = vp_embeds.reshape(-1, C).to(input_embeds.device)

            if fast_vit_embeds is not None:
                selected = (input_ids == fast_token_idx)
                # FIXME: add repeat.
                assert selected.sum() != 0
                input_embeds[selected] = fast_vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.model.language_model.get_input_embeddings()(input_ids)

        outputs = self.model.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask.to(device),
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs
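
    # When `transfer_to_hf` is set (e.g. before exporting weights), the
    # override below bypasses InternVL_V1_5's `state_dict` and returns the
    # full nn.Module state dict of this wrapper.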

    def state_dict(self, *args, **kwargs):
        if self.transfer_to_hf:
            state_dict = super(InternVL_V1_5, self).state_dict(*args, **kwargs)
            return state_dict
        else:
            return super().state_dict(*args, **kwargs)
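

# Minimal construction sketch. The checkpoint id and the extra special token
# below are placeholders; point them at the actual InternVL2-style checkpoint
# and token set used for training or inference.
if __name__ == '__main__':
    slowfast = InternVL_Slowfast(
        model_path='OpenGVLab/InternVL2-8B',      # placeholder checkpoint
        freeze_llm=True,
        freeze_visual_encoder=True,
        special_tokens=['<FAST_IMG_CONTEXT>'],    # placeholder special token
    )
    # Attach the adaptive pool used by the fast-frame pathway.
    slowfast._post_init(fast_pool_size=4, fast_pool=True)
    print(slowfast.tokenizer.convert_tokens_to_ids('<FAST_IMG_CONTEXT>'))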