''' Cherry-picked from Roshan's PR:
https://github.com/nahidalam/LLaVA/blob/1ecc141d7f20f16518f38a0d99320268305c17c3/llava/eval/maya/eval_utils.py
'''
import os
import sys
import torch
import requests
from io import BytesIO
from PIL import Image
from transformers import AutoTokenizer, AutoConfig, TextStreamer
from transformers.models.cohere.tokenization_cohere_fast import CohereTokenizerFast

from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig
from constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from conversation import conv_templates, SeparatorStyle
from mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from typing import Optional, Literal


def load_maya_model(model_base: str, model_path: str, projector_path: Optional[str] = None,
                    mode: Literal['pretrained', 'finetuned'] = 'finetuned'):
    """
    Loads a trained Maya model.

    A trained Maya model comes in two flavors:
        1. Pretrained: the model has only gone through pretraining, and the changes are
           restricted to the projector layer.
        2. Finetuned: the model has gone through instruction finetuning after the
           pretraining stage, which affects the whole model.

    This is a replication of the load_pretrained_model function from llava.model.builder
    that is specific to Cohere/Maya.

    Args:
        model_base: Path of the base LLM model on HF, e.g. 'CohereForAI/aya-23-8B' or
            'meta-llama/Meta-Llama-3-8B-Instruct'. This is used to instantiate the
            tokenizer and, when loading the pretrained model, the model itself.
        model_path: Path of the trained model repo on HF, e.g. 'nahidalam/Maya'. This is
            used to load the config file, so this path/directory must contain config.json.
            For the finetuned model, the final model weights are loaded from here as well.
        projector_path: For the pretrained model, path to the local directory that holds
            the mm_projector.bin file.
        mode: Specifies whether to load a pretrained-only model or a finetuned model.

    Returns:
        model: LlavaCohereForCausalLM object
        tokenizer: CohereTokenizerFast object
        image_processor: image processor taken from the loaded vision tower
        context_len: maximum context length supported by the model
    """

    device_map = 'auto'
    kwargs = {"device_map": device_map}
    kwargs['torch_dtype'] = torch.float32  # note: inference code below casts image tensors to float16
    # kwargs['attn_implementation'] = 'flash_attention_2'

    ## Instantiating tokenizer and model base
    tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
    cfg_pretrained = LlavaCohereConfig.from_pretrained(model_path)

    if mode == 'pretrained':
        model = LlavaCohereForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

        ## Loading projector layer weights
        mm_projector_weights = torch.load(projector_path, map_location='cpu')
        mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
        model.load_state_dict(mm_projector_weights, strict=False)
    else:
        # Load model with ignore_mismatched_sizes to handle vision tower weights
        model = LlavaCohereForCausalLM.from_pretrained(
            model_path,
            config=cfg_pretrained,
            ignore_mismatched_sizes=True,
            **kwargs
        )

    ## Loading image processor
    image_processor = None

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # Get and load vision tower
    vision_tower = model.get_vision_tower()
    if vision_tower is None:
        raise ValueError("Vision tower not found in model config")

    print(f"Loading vision tower... Is loaded: {vision_tower.is_loaded}")
    if not vision_tower.is_loaded:
        try:
            vision_tower.load_model()
            print("Vision tower loaded successfully")
        except Exception as e:
            print(f"Error loading vision tower: {str(e)}")
            raise

    if device_map != 'auto':
        vision_tower.to(device=device_map, dtype=torch.float16)

    image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return model, tokenizer, image_processor, context_len
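

def _demo_load_maya_model():
    """Usage sketch for load_maya_model (added for illustration; not in the original PR).

    The HF repo ids below are the ones named in the docstring above; the local
    'mm_projector.bin' path in the pretrained variant is a hypothetical example
    of where projector weights might have been saved.
    """
    # Finetuned flavor: full weights come from model_path.
    model, tokenizer, image_processor, context_len = load_maya_model(
        model_base='CohereForAI/aya-23-8B',
        model_path='nahidalam/Maya',
        mode='finetuned',
    )
    print(f"Loaded Maya with context length {context_len}")

    # Pretrained-only flavor: base weights from model_base, projector from a local file.
    model, tokenizer, image_processor, context_len = load_maya_model(
        model_base='CohereForAI/aya-23-8B',
        model_path='nahidalam/Maya',
        projector_path='./checkpoints/mm_projector.bin',  # hypothetical local path
        mode='pretrained',
    )
    return model, tokenizer, image_processor, context_len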


class MayaModel(object):
    def __init__(self, model: LlavaCohereForCausalLM, tokenizer: CohereTokenizerFast, image_processor, context_length):
        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.context_length = context_length

    def validate_inputs(self):
        """Method to validate the inputs."""
        pass


def load_image(image_input):
    """
    Convert various image inputs to a PIL Image object.

    :param image_input: Can be a URL string, a file path string, or image bytes
    :return: PIL Image object
    """
    try:
        if isinstance(image_input, str):
            if image_input.startswith(('http://', 'https://')):
                # Input is a URL
                response = requests.get(image_input)
                response.raise_for_status()  # Raise an exception for bad responses
                return Image.open(BytesIO(response.content))
            elif os.path.isfile(image_input):
                # Input is a file path
                return Image.open(image_input)
            else:
                raise ValueError("Invalid input: string is neither a valid URL nor a file path")
        elif isinstance(image_input, bytes):
            # Input is bytes
            return Image.open(BytesIO(image_input))
        else:
            raise ValueError("Invalid input type. Expected URL string, file path string, or bytes.")
    except requests.RequestException as e:
        raise ValueError(f"Error fetching image from URL: {e}")
    except IOError as e:
        raise ValueError(f"Error opening image file: {e}")
    except ValueError:
        # Re-raise our own validation errors instead of re-wrapping them below
        raise
    except Exception as e:
        raise ValueError(f"An unexpected error occurred: {e}")
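

def _demo_load_image():
    """Usage sketch for load_image (added for illustration; URL and paths are placeholders).

    load_image accepts a URL, a local path, or raw bytes and always returns a
    PIL.Image.Image, so downstream code does not need to care about the source.
    """
    from_url = load_image('https://example.com/cat.png')  # placeholder URL
    from_path = load_image('./samples/cat.png')           # placeholder local path
    with open('./samples/cat.png', 'rb') as f:
        from_bytes = load_image(f.read())
    return from_url, from_path, from_bytes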


def get_single_sample_prediction(maya_model, image_file, user_question, temperature=0.0, max_new_tokens=100, conv_mode='aya'):
    """Generates the prediction for a single image-question pair.

    Args:
        maya_model (MayaModel): Trained Maya model
        image_file: One of the following: online image URL, local image path, or image bytes
        user_question (str): Question to be shared with the LLM
        temperature (float, optional): Temperature param for the LLM. Defaults to 0.0.
        max_new_tokens (int, optional): Max number of new tokens generated. Defaults to 100.
        conv_mode (str, optional): Conversation template to be used. Defaults to 'aya'.

    Returns:
        output (str): Model's response to the user question
    """
    conv = conv_templates[conv_mode].copy()
    model = maya_model.model
    tokenizer = maya_model.tokenizer
    image_processor = maya_model.image_processor

    image = load_image(image_file)
    image_size = image.size
    image_tensor = process_images([image], image_processor, model.config)
    if isinstance(image_tensor, list):
        image_tensor = [img.to(model.device, dtype=torch.float16) for img in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # Prepend the image token(s) to the first user message
    if model.config.mm_use_im_start_end:
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + user_question
    else:
        inp = DEFAULT_IMAGE_TOKEN + '\n' + user_question

    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image_size],
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True)

    outputs = tokenizer.decode(output_ids[0]).strip()
    return outputs
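

def _demo_single_prediction():
    """End-to-end usage sketch (added for illustration; repo ids and image URL are placeholders).

    Wires load_maya_model, MayaModel, and get_single_sample_prediction together
    the way the signatures above suggest.
    """
    model, tokenizer, image_processor, context_len = load_maya_model(
        model_base='CohereForAI/aya-23-8B',
        model_path='nahidalam/Maya',
        mode='finetuned',
    )
    maya = MayaModel(model, tokenizer, image_processor, context_len)
    answer = get_single_sample_prediction(
        maya,
        image_file='https://example.com/street.jpg',  # placeholder image URL
        user_question='What is happening in this image?',
        temperature=0.0,
        max_new_tokens=100,
    )
    print(answer)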