import torch
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)

from .model import LlavaPhiForCausalLM
from .conversation import conv_templates, SeparatorStyle

IGNORE_INDEX = -100
# Negative sentinel id written into input_ids where an image belongs; it cannot
# collide with a real token id. Presumably LlavaPhiForCausalLM substitutes the
# image features at these positions (TODO confirm in .model).
IMAGE_TOKEN_INDEX = -200
# Prompt markers. They must be non-empty: prompts are split on
# DEFAULT_IMAGE_TOKEN, and str.split("") raises ValueError.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class AudioLanguageConnector:
    """Tokenize text with the microsoft/phi-2 tokenizer.

    Not used by the other classes in this module at the moment.
    """

    def __init__(self, projection_dim):
        model_name = "microsoft/phi-2"
        self.phi2_tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        # Reuse EOS as the padding token so padded batches can be produced.
        self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token
        # NOTE(review): this sets an ad-hoc attribute; the tokenizer's own
        # length limit is `model_max_length`, so this is probably not consulted.
        self.phi2_tokenizer.max_length = projection_dim

    def __call__(self, text):
        """Return the tokenizer output (PyTorch tensors, no attention mask) for ``text``."""
        # NOTE(review): the text is only padded with spaces; if marker tokens
        # were meant to surround it, they are missing here — confirm.
        text = f" {text} "
        tokens = self.phi2_tokenizer(
            text, return_tensors="pt", return_attention_mask=False
        )
        return tokens


class WhisperWithProjection:
    """Transcribe audio to text with openai/whisper-tiny.

    ``projection_dim`` is accepted for interface compatibility but is currently
    unused: no projection is applied, only the transcription is returned.
    """

    def __init__(self, projection_dim, device):
        self.device = device
        # The processor does feature extraction / decoding and is not a
        # torch module, so only the model is given a device_map.
        self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        self.model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-tiny", device_map=device
        )
        # Do not pin generation to a fixed forced decoder prompt.
        self.model.config.forced_decoder_ids = None

    def __call__(self, audio):
        """Transcribe ``audio``.

        Args:
            audio: mapping with ``"array"`` (the waveform passed to the
                processor) and ``"sampling_rate"`` keys.

        Returns:
            The list of decoded strings produced by ``batch_decode``.
        """
        input_features = self.processor(
            audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt"
        ).input_features
        predicted_ids = self.model.generate(input_features.to(self.device))
        transcription = self.processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )
        return transcription


class MultiModalPhi2:
    """Generate replies from text plus an optional image and optional audio.

    Images go through LlavaPhiForCausalLM's ``images=`` input; audio is
    transcribed with Whisper and the transcript is added to the user's message.
    """

    def __init__(
        self,
        modelname_or_path="RaviNaik/Llava-Phi2",
        temperature=0.2,
        max_new_tokens=1024,
        device="cuda:0",
    ):
        self.model_name = modelname_or_path
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.disable_torch_init()
        self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)
        self.load_pretrained_model()

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        The weights are loaded with ``from_pretrained`` afterwards, so the random
        init would be overwritten anyway. NOTE: this patches torch.nn.Linear and
        torch.nn.LayerNorm for the whole process, not just for this instance.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def load_pretrained_model(self):
        """Load the LLaVA-Phi2 model, its tokenizer and its CLIP image processor.

        Registers the image patch / start / end markers with the tokenizer as
        requested by the model config.
        """
        self.model = LlavaPhiForCausalLM.from_pretrained(
            self.model_name, device_map=self.device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model_name)

        mm_use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(
            self.model.config, "mm_use_im_patch_token", True
        )
        if mm_use_im_patch_token:
            self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            self.tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )

    def tokenizer_image_token(
        self,
        prompt,
        tokenizer,
        image_token_index=IMAGE_TOKEN_INDEX,
        return_tensors=None,
    ):
        """Tokenize ``prompt``, replacing every DEFAULT_IMAGE_TOKEN with ``image_token_index``.

        The text between image markers is tokenized chunk by chunk. If the
        tokenizer prepends a BOS token, it is kept once at the very start.

        Returns:
            A list of ints, or a 1-D ``torch.long`` tensor if ``return_tensors == "pt"``.

        Raises:
            ValueError: if ``return_tensors`` is neither None nor ``"pt"``.
        """
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)
        ]

        def insert_separator(X, sep):
            # [a, b, c] -> [a, sep, b, sep, c]
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        if (
            len(prompt_chunks) > 0
            and len(prompt_chunks[0]) > 0
            and prompt_chunks[0][0] == tokenizer.bos_token_id
        ):
            # Emit BOS once, then drop the first token of every piece below.
            # The separator is padded by one so x[offset:] leaves it intact.
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        if return_tensors is not None:
            if return_tensors == "pt":
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f"Unsupported tensor type: {return_tensors}")
        return input_ids

    def _build_user_message(self, text, audio, image):
        """Return the user-turn text: optional image marker, typed text, audio transcript."""
        if audio is not None:
            # The transcript is user input, so it goes into the user turn. It is
            # not appended after the prompt: there it would follow the assistant
            # role marker and be treated as the start of the model's own reply.
            transcript = " ".join(t.strip() for t in self.whisper_w_proj(audio))
            text = f"{text}\n{transcript}" if text else transcript

        if image is None:
            return text

        # Only wrap in start/end markers when load_pretrained_model registered
        # them; otherwise they would be tokenized as plain text.
        if getattr(self.model.config, "mm_use_im_start_end", False):
            image_marker = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_marker = DEFAULT_IMAGE_TOKEN
        return image_marker + "\n" + text

    def __call__(self, text, audio, image):
        """Generate a reply to ``text`` with optional ``audio`` and ``image`` context.

        Args:
            text: the user prompt; None is treated as an empty prompt.
            audio: None, or a mapping accepted by WhisperWithProjection
                (``"array"`` and ``"sampling_rate"`` keys).
            image: None, or an image accepted by ``CLIPImageProcessor.preprocess``.

        Returns:
            The decoded reply. Surrounding whitespace and a trailing
            conversation stop string are removed.
        """
        if text is None:
            text = ""

        conv = conv_templates["phi-2_v0"].copy()
        conv.append_message(conv.roles[0], self._build_user_message(text, audio, image))
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        generate_kwargs = {}
        if image is not None:
            input_ids = self.tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            ).unsqueeze(0)
            generate_kwargs["images"] = self.image_processor.preprocess(
                image, return_tensors="pt"
            )["pixel_values"].to(self.device)
        else:
            input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        input_ids = input_ids.to(self.device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                do_sample=True,
                temperature=self.temperature,
                max_new_tokens=self.max_new_tokens,
                eos_token_id=self.tokenizer.eos_token_id,  # End of sequence token
                pad_token_id=self.tokenizer.eos_token_id,  # Pad token
                use_cache=True,
                **generate_kwargs,
            )

        # The decode below assumes generate returns prompt + continuation.
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (
            (input_ids != output_ids[:, :input_token_len]).sum().item()
        )
        if n_diff_input_output > 0:
            print(
                f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
            )

        outputs = self.tokenizer.batch_decode(
            output_ids[:, input_token_len:], skip_special_tokens=True
        )[0]
        outputs = outputs.strip()
        # Guard on a non-empty stop string: outputs[:-0] would drop everything.
        if stop_str and outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]
        outputs = outputs.strip()
        return outputs