import os
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine import print_log
from torch.nn import CrossEntropyLoss
from transformers import (AutoConfig, AutoModel, AutoTokenizer,
                          BitsAndBytesConfig, GenerationConfig,
                          LlamaForCausalLM, LlamaTokenizer)
from transformers.modeling_outputs import CausalLMOutputWithPast

from xtuner.model import InternVL_V1_5
from xtuner.model.utils import (find_all_linear_names, get_peft_model_state_dict,
                                guess_load_checkpoint, make_inputs_require_grad)


def get_rank_and_world_size():
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    return rank, world_size


# This function is used to split a large model across the available GPUs.
def split_model(model_name):
    import math
    device_map = {}
    num_gpus = torch.cuda.device_count()
    rank, world_size = get_rank_and_world_size()
    num_gpus = num_gpus // world_size

    num_layers = {
        'InternVL2-8B': 32,
        'InternVL2-26B': 48,
        'InternVL2-40B': 60,
        'InternVL2-Llama3-76B': 80,
    }[model_name]
    # Since the first GPU will be used for ViT, treat it as 0.8 GPU.
    num_layers_per_gpu = math.ceil(num_layers / (num_gpus - 0.2))
    num_layers_per_gpu = [num_layers_per_gpu] * num_gpus
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.8)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = rank + world_size * i
            layer_cnt += 1
    device_map['vision_model'] = rank
    device_map['mlp1'] = rank
    device_map['language_model.model.tok_embeddings'] = rank
    device_map['language_model.model.embed_tokens'] = rank
    device_map['language_model.output'] = rank
    device_map['language_model.model.norm'] = rank
    device_map['language_model.lm_head'] = rank
    device_map[f'language_model.model.layers.{num_layers - 1}'] = rank
    return device_map


class InternVL_Slowfast(InternVL_V1_5):

    def __init__(self,
                 model_path,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 quantization_vit=False,
                 quantization_llm=False,
                 pretrained_pth=None,
                 special_tokens=None,
                 model_split=False):
        print_log('Start to load InternVL_V1_5 model.', logger='current')
        super(InternVL_V1_5, self).__init__()
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None
        self.quantization_vit = quantization_vit
        self.quantization_llm = quantization_llm
        if quantization_vit:
            assert visual_encoder_lora is not None
        if quantization_llm:
            assert llm_lora is not None

        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        if config.llm_config.model_type == 'internlm2':
            config.llm_config.attn_implementation = 'flash_attention_2'
        else:
            config.llm_config._attn_implementation = 'flash_attention_2'

        if quantization_vit is False and quantization_llm is False:
            quantization = None
        else:
            llm_int8_skip_modules = ['mlp1']
            if quantization_llm and not quantization_vit:
                llm_int8_skip_modules.append('vision_model')
            if quantization_vit and not quantization_llm:
                llm_int8_skip_modules.append('language_model')
            quantization_config = dict(
                type=BitsAndBytesConfig,
                llm_int8_skip_modules=llm_int8_skip_modules,
                load_in_4bit=True,
                load_in_8bit=False,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4')
            quantization_clazz = quantization_config.pop('type')
            quantization = quantization_clazz(**quantization_config)
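
        # When `model_split` is set, the HF checkpoint is sharded across the local GPUs
        # using the `device_map` built by `split_model` above: every transformer layer of
        # the language model is pinned to a GPU, while the vision tower, `mlp1` projector,
        # embeddings, and output head stay on the rank's first device.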
        if model_split:
            # print("\n\nDone Model Split !!!!!!!!!!!\n\n")
            device_map = split_model("InternVL2-26B")
            # print(device_map)
            self.device = 'cuda'
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=device_map).eval()
        else:
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                quantization_config=quantization,
                config=config,
                trust_remote_code=True)

        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.tokenizer = tokenizer
        if special_tokens is not None:
            self._add_special_tokens(special_tokens)

        img_context_token_id = tokenizer.convert_tokens_to_ids('<IMG_CONTEXT>')
        self.model.img_context_token_id = img_context_token_id

        if self.freeze_llm:
            self.model.language_model.requires_grad_(False)
        if self.freeze_visual_encoder:
            self.model.vision_model.requires_grad_(False)

        if hasattr(self.model.language_model, 'enable_input_require_grads'):
            self.model.language_model.enable_input_require_grads()
        else:
            self.model.language_model.get_input_embeddings().register_forward_hook(
                make_inputs_require_grad)

        self.gradient_checkpointing_enable()

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora)
        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(visual_encoder_lora)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Loaded pretrained weights from {pretrained_pth}')

        self._count = 0
        print_log(self, logger='current')
        print_log('InternVL_V1_5 construction is complete', logger='current')
        self.transfer_to_hf = False

    def _add_special_tokens(self, special_tokens):
        num_new_tokens = self.tokenizer.add_tokens(special_tokens, special_tokens=True)
        if num_new_tokens > 0:
            self.model.language_model.resize_token_embeddings(len(self.tokenizer))

    def _post_init(self, fast_pool_size=4, fast_pool=True):
        if fast_pool:
            self.fast_pool = nn.AdaptiveAvgPool2d((fast_pool_size, fast_pool_size))
        return

    def forward(self, data, data_samples=None, mode='loss', fast_token_idx=None):
        if 'fast_pixel_values' in data.keys():
            assert fast_token_idx is not None
            fast_pixel_values = data['fast_pixel_values']
            if type(fast_pixel_values) is list or fast_pixel_values.ndim == 5:
                if type(fast_pixel_values) is list:
                    fast_pixel_values = [
                        x.unsqueeze(0) if x.ndim == 3 else x for x in fast_pixel_values
                    ]
                # b*n, c, h, w
                fast_concat_images = torch.cat(
                    [image.to(self.model.vision_model.dtype) for image in fast_pixel_values],
                    dim=0)
            else:
                raise NotImplementedError()
        else:
            fast_pixel_values = None
            fast_concat_images = None

        pixel_values = data['pixel_values']
        if type(pixel_values) is list or pixel_values.ndim == 5:
            if type(pixel_values) is list:
                pixel_values = [
                    x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                ]
            # b*n, c, h, w
            concat_images = torch.cat(
                [image.to(self.model.vision_model.dtype) for image in pixel_values], dim=0)
        else:
            raise NotImplementedError()

        input_ids = data['input_ids']
        position_ids = data['position_ids']
        attention_mask = data['attention_mask']
        # Tiles whose pixel values sum to 0 are padding from text-only samples.
        image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
        image_flags = image_flags.long()

        labels = data['labels']
        use_cache = False

        if 'vp_overall_mask' not in data.keys():
            vp_overall_mask = None
        else:
            vp_overall_mask = data['vp_overall_mask']

        if 'prompt_masks' in data.keys():
            prompt_masks = data['prompt_masks']
        else:
            prompt_masks = None

        outputs = self._llm_forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            image_flags=image_flags,
            pixel_values=concat_images,
            labels=labels,
            use_cache=use_cache,
            output_hidden_states=True,
            fast_pixel_values=fast_concat_images,
            fast_token_idx=fast_token_idx,
            vp_overall_mask=vp_overall_mask,
            prompt_masks=prompt_masks,
        )
        return outputs
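
    # `_llm_forward` scatters the ViT features into the text embedding sequence:
    # positions holding `img_context_token_id` receive per-tile (or visual-prompt)
    # features, positions holding `fast_token_idx` receive the pooled "fast" frame
    # features, and the combined sequence is fed to the language model.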
    def _llm_forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        fast_pixel_values=None,
        fast_token_idx=None,
        vp_overall_mask=None,
        prompt_masks=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None \
            else self.model.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        # The clone() is only here to avoid an in-place modification error when the
        # image features are written into the embeddings below.
        input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()

        if fast_pixel_values is not None:
            n_fast_images = fast_pixel_values.shape[0]
            whole_pixel_values = torch.cat([fast_pixel_values, pixel_values], dim=0)
            vit_embeds = self.model.extract_feature(whole_pixel_values)
            vit_embeds = vit_embeds.to(input_embeds.dtype)  # FIXME: why is vit_embeds float16?
            fast_vit_embeds = vit_embeds[:n_fast_images]  # (n_fast_images, hw, c)
            _size = int(fast_vit_embeds.shape[1] ** 0.5)
            fast_vit_embeds = fast_vit_embeds.reshape(
                fast_vit_embeds.shape[0], _size, _size, fast_vit_embeds.shape[-1])
            # pooling
            fast_vit_embeds = fast_vit_embeds.permute(0, 3, 1, 2)  # (n_fast_images, c, h, w)
            fast_vit_embeds = self.fast_pool(fast_vit_embeds).flatten(2)  # (n_fast_images, c, hw)
            fast_vit_embeds = fast_vit_embeds.permute(0, 2, 1)
            vit_embeds = vit_embeds[n_fast_images:]
        else:
            vit_embeds = self.model.extract_feature(pixel_values)
            vit_embeds = vit_embeds.to(input_embeds.dtype)  # FIXME: why is vit_embeds float16?
            fast_vit_embeds = None
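
        # Keep only the tiles flagged as real images; tiles whose pixels summed to zero
        # (text-only padding, see `forward`) are dropped before token replacement.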
        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        self._count += 1

        if vp_overall_mask is not None and prompt_masks is not None:
            vp_embeds = []
            vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
            prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]

            vp_overall_mask = vp_overall_mask[image_flags == 1]
            overall_tile_vit_embeds = vit_embeds[vp_overall_mask]  # (n_img, hw, c)

            i_vp_img = 0
            for i_img in range(len(vit_embeds)):
                vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
                if vp_overall_mask[i_img]:
                    tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C)  # (hw, C)
                    objects_prompt_masks = prompt_masks[i_vp_img]
                    n_obj = len(objects_prompt_masks)
                    tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
                    objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                    vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                    i_vp_img += 1
            vp_embeds = torch.cat(vp_embeds, dim=0)
        else:
            vp_embeds = None

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.model.img_context_token_id)
        if vp_embeds is None:
            try:
                input_embeds[selected] = vit_embeds.reshape(-1, C)
            except Exception as e:
                vit_embeds = vit_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'vit_embeds.shape={vit_embeds.shape}')
                n_token = selected.sum()
                if n_token > len(vit_embeds):
                    print(f'Wrong !!! {n_token} image tokens in text but only '
                          f'{len(vit_embeds)} vit embeds !!!')
                    expand_ratio = n_token // len(vit_embeds) + 1
                    vit_embeds = torch.cat([vit_embeds] * expand_ratio, dim=0)
                input_embeds[selected] = vit_embeds[:n_token]
        else:
            try:
                input_embeds[selected] = vp_embeds.reshape(-1, C)
            except Exception as e:
                vp_embeds = vp_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'vp_embeds.shape={vp_embeds.shape}')
                n_token = selected.sum()
                if n_token > len(vp_embeds):
                    print(f'Wrong !!! {n_token} image tokens in text but only '
                          f'{len(vp_embeds)} vit embeds !!!')
                    expand_ratio = n_token // len(vp_embeds) + 1
                    vp_embeds = torch.cat([vp_embeds] * expand_ratio, dim=0)
                input_embeds[selected] = vp_embeds[:n_token]

        if fast_vit_embeds is not None:
            selected = (input_ids == fast_token_idx)
            selected_tot = selected.sum().item()
            if selected_tot > fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1]:
                assert selected_tot % (fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1]) == 0
                repeat_times = selected_tot / (fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1])
                fast_vit_embeds = fast_vit_embeds.repeat(int(repeat_times), 1, 1)
            try:
                input_embeds[selected] = fast_vit_embeds.reshape(-1, C)
            except Exception as e:
                fast_vit_embeds = fast_vit_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[fast_selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'fast_vit_embeds.shape={fast_vit_embeds.shape}')
                n_token = selected.sum()
                input_embeds[selected] = fast_vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.model.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.model.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        fast_token_idx=None,
        fast_pixel_values=None,
        prompt_masks=None,
        vp_overall_mask=None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        device = self.model.device
        assert self.model.img_context_token_id is not None

        if fast_pixel_values is not None:
            assert fast_token_idx is not None
            if type(fast_pixel_values) is list or fast_pixel_values.ndim == 5:
                if type(fast_pixel_values) is list:
                    fast_pixel_values = [
                        x.unsqueeze(0) if x.ndim == 3 else x for x in fast_pixel_values
                    ]
                # b*n, c, h, w
                fast_pixel_values = torch.cat(
                    [image.to(self.model.vision_model.dtype) for image in fast_pixel_values],
                    dim=0)

        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                if type(pixel_values) is list or pixel_values.ndim == 5:
                    if type(pixel_values) is list:
                        pixel_values = [
                            x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                        ]
                    # b*n, c, h, w
                    pixel_values = torch.cat(
                        [image.to(self.model.vision_model.dtype) for image in pixel_values],
                        dim=0)

                if fast_pixel_values is not None:
                    n_fast_images = fast_pixel_values.shape[0]
                    whole_pixel_values = torch.cat([fast_pixel_values, pixel_values], dim=0)
                    vit_embeds = self.model.extract_feature(whole_pixel_values.to(device))
                    # vit_embeds = vit_embeds.to(input_embeds.dtype)  # FIXME: why is vit_embeds float16?
                    fast_vit_embeds = vit_embeds[:n_fast_images]  # (n_fast_images, hw, c)
                    _size = int(fast_vit_embeds.shape[1] ** 0.5)
                    fast_vit_embeds = fast_vit_embeds.reshape(
                        fast_vit_embeds.shape[0], _size, _size, fast_vit_embeds.shape[-1])
                    # pooling
                    fast_vit_embeds = fast_vit_embeds.permute(0, 3, 1, 2)  # (n_fast_images, c, h, w)
                    fast_vit_embeds = self.fast_pool(fast_vit_embeds).flatten(2)  # (n_fast_images, c, hw)
                    fast_vit_embeds = fast_vit_embeds.permute(0, 2, 1)
                    vit_embeds = vit_embeds[n_fast_images:]
                else:
                    fast_vit_embeds = None
                    vit_embeds = self.model.extract_feature(pixel_values.to(device))

                image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
                image_flags = image_flags.long()
                vit_embeds = vit_embeds[image_flags == 1]

            input_embeds = self.model.language_model.get_input_embeddings()(input_ids.to(device))
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            if vp_overall_mask is not None and prompt_masks is not None:
                vp_embeds = []
                vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
                prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]

                vp_overall_mask = vp_overall_mask[image_flags == 1]
                overall_tile_vit_embeds = vit_embeds[vp_overall_mask]  # (n_img, hw, c)

                i_vp_img = 0
                for i_img in range(len(vit_embeds)):
                    vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
                    if vp_overall_mask[i_img]:
                        tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C)  # (hw, C)
                        objects_prompt_masks = prompt_masks[i_vp_img]
                        n_obj = len(objects_prompt_masks)
                        tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
                        objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                        vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                        i_vp_img += 1
                vp_embeds = torch.cat(vp_embeds, dim=0)
            else:
                vp_embeds = None

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.model.img_context_token_id)
            assert selected.sum() != 0
            if vp_embeds is None:
                input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            else:
                if len(input_embeds[selected]) != len(vp_embeds.reshape(-1, C)):
                    print('Shape mismatch, selected is {}, vp embeds is {} !!!'.format(
                        len(input_embeds[selected]), len(vp_embeds.reshape(-1, C))))
                    min_tokens = min(len(input_embeds[selected]), len(vp_embeds.reshape(-1, C)))
                    # Boolean-mask indexing returns a copy, so write through explicit
                    # indices to actually fill the first `min_tokens` selected positions.
                    selected_idx = selected.nonzero(as_tuple=True)[0][:min_tokens]
                    input_embeds[selected_idx] = \
                        vp_embeds.reshape(-1, C)[:min_tokens].to(input_embeds.device)
                else:
                    input_embeds[selected] = vp_embeds.reshape(-1, C).to(input_embeds.device)

            if fast_vit_embeds is not None:
                selected = (input_ids == fast_token_idx)
                # FIXME: add repeat.
                assert selected.sum() != 0
                input_embeds[selected] = fast_vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.model.language_model.get_input_embeddings()(input_ids)

        outputs = self.model.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask.to(device),
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs

    def state_dict(self, *args, **kwargs):
        if self.transfer_to_hf:
            state_dict = super(InternVL_V1_5, self).state_dict(*args, **kwargs)
            return state_dict
        else:
            return super().state_dict(*args, **kwargs)
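

# A minimal usage sketch (illustrative only; the checkpoint path and the extra special
# token below are assumptions and are not defined in this file):
#
#   model = InternVL_Slowfast(
#       model_path='OpenGVLab/InternVL2-26B',       # hypothetical HF checkpoint
#       freeze_llm=True,
#       freeze_visual_encoder=True,
#       special_tokens=['<FAST_IMG_CONTEXT>'],      # hypothetical fast-frame placeholder token
#   )
#   model._post_init(fast_pool_size=4, fast_pool=True)
#   fast_token_idx = model.tokenizer.convert_tokens_to_ids('<FAST_IMG_CONTEXT>')
#   # model.forward(data, fast_token_idx=fast_token_idx) then expects `data` to carry
#   # 'pixel_values', 'input_ids', 'position_ids', 'attention_mask', 'labels', and
#   # optionally 'fast_pixel_values', 'vp_overall_mask', and 'prompt_masks'.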