'''
Cherry picked from Roshan's PR https://github.com/nahidalam/LLaVA/blob/1ecc141d7f20f16518f38a0d99320268305c17c3/llava/eval/maya/eval_utils.py
'''
import os
import sys
from io import BytesIO
from typing import Optional, Literal

import requests
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoConfig, TextStreamer
from transformers.models.cohere.tokenization_cohere_fast import CohereTokenizerFast

from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig
from constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from conversation import conv_templates, SeparatorStyle
from mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

def load_maya_model(model_base: str, model_path: str, projector_path: Optional[str] = None, mode: Literal['pretrained', 'finetuned'] = 'finetuned'):
""" Function that helps load a trained Maya model | |
Trained Maya model can be of two flavors : | |
1. Pretrained : The model has only gone through pretraining and the changes are restricted to the projector layer | |
2. Finetuned : Model has gone through instruction finetuning post pretraining stage. This affects the whole model | |
This is a replication of the load_pretrained_model function from llava.model.builder thats specific to Cohere/Maya | |
Args: | |
model_base : Path of the base LLM model in HF. Eg: 'CohereForAI/aya-23-8B', 'meta-llama/Meta-Llama-3-8B-Instruct'. | |
This is used to instantiate the tokenizer and the model (in case of loading the pretrained model) | |
model_path : Path of the trained model repo in HF. Eg : 'nahidalam/Maya' | |
This is used to load the config file. So this path/directory should have the config.json file | |
For the finetuned model, this is used to load the final model weights as well | |
projector_path : For the pretrained model, this represents the path to the local directory which holds the mm_projector.bin file | |
model : Helps specify if this is loading a pretrained only model or a finetuned model | |
Returns: | |
model: LlavaCohereForCausalLM object | |
tokenizer: CohereTokenizerFast object | |
image_processor: | |
content_len: | |
""" | |
    device_map = 'auto'
    kwargs = {"device_map": device_map}
    kwargs['torch_dtype'] = torch.float32
    # kwargs['attn_implementation'] = 'flash_attention_2'

    ## Instantiating tokenizer and model base
    tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
    cfg_pretrained = LlavaCohereConfig.from_pretrained(model_path)

    if mode == 'pretrained':
        model = LlavaCohereForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

        ## Loading projector layer weights
        mm_projector_weights = torch.load(projector_path, map_location='cpu')
        mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
        model.load_state_dict(mm_projector_weights, strict=False)
    else:
        # Load model with ignore_mismatched_sizes to handle vision tower weights
        model = LlavaCohereForCausalLM.from_pretrained(
            model_path,
            config=cfg_pretrained,
            ignore_mismatched_sizes=True,
            **kwargs
        )

    ## Loading image processor
    image_processor = None

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # Get and load vision tower
    vision_tower = model.get_vision_tower()
    if vision_tower is None:
        raise ValueError("Vision tower not found in model config")

    print(f"Loading vision tower... Is loaded: {vision_tower.is_loaded}")
    if not vision_tower.is_loaded:
        try:
            vision_tower.load_model()
            print("Vision tower loaded successfully")
        except Exception as e:
            print(f"Error loading vision tower: {str(e)}")
            raise

    if device_map != 'auto':
        vision_tower.to(device=device_map, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    # maya = MayaModel(model, tokenizer, image_processor, context_len)
    return model, tokenizer, image_processor, context_len
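
# Example usage (a minimal sketch): the HF repo ids below come from the docstring
# above; the local projector path is hypothetical and only for illustration.
#
#   # Finetuned model: weights come entirely from the model_path repo
#   model, tokenizer, image_processor, context_len = load_maya_model(
#       model_base='CohereForAI/aya-23-8B',
#       model_path='nahidalam/Maya',
#       mode='finetuned',
#   )
#
#   # Pretrained-only model: base LLM plus locally saved projector weights
#   model, tokenizer, image_processor, context_len = load_maya_model(
#       model_base='CohereForAI/aya-23-8B',
#       model_path='nahidalam/Maya',
#       projector_path='./checkpoints/mm_projector.bin',  # hypothetical path
#       mode='pretrained',
#   )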

class MayaModel(object):
    def __init__(self, model: LlavaCohereForCausalLM, tokenizer: CohereTokenizerFast, image_processor, context_length):
        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.context_length = context_length

    def validate_inputs(self):
        """
        Method to validate the inputs. Currently a stub.
        """
        pass
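
# A MayaModel can be assembled directly from the loader's return values
# (sketch, with the same repo ids as in the docstring above):
#
#   model, tokenizer, image_processor, context_len = load_maya_model(
#       'CohereForAI/aya-23-8B', 'nahidalam/Maya', mode='finetuned')
#   maya = MayaModel(model, tokenizer, image_processor, context_len)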

def load_image(image_input):
    """
    Convert various image inputs to a PIL Image object.

    :param image_input: Can be a URL string, a file path string, or image bytes
    :return: PIL Image object
    """
    try:
        if isinstance(image_input, str):
            if image_input.startswith(('http://', 'https://')):
                # Input is a URL
                response = requests.get(image_input)
                response.raise_for_status()  # Raise an exception for bad responses
                return Image.open(BytesIO(response.content))
            elif os.path.isfile(image_input):
                # Input is a file path
                return Image.open(image_input)
            else:
                raise ValueError("Invalid input: string is neither a valid URL nor a file path")
        elif isinstance(image_input, bytes):
            # Input is bytes
            return Image.open(BytesIO(image_input))
        else:
            raise ValueError("Invalid input type. Expected URL string, file path string, or bytes.")
    except requests.RequestException as e:
        raise ValueError(f"Error fetching image from URL: {e}")
    except IOError as e:
        raise ValueError(f"Error opening image file: {e}")
    except Exception as e:
        raise ValueError(f"An unexpected error occurred: {e}")

def get_single_sample_prediction(maya_model, image_file, user_question, temperature=0.0, max_new_tokens=100, conv_mode='aya'):
    """Generates the prediction for a single image-user question pair.

    Args:
        maya_model (MayaModel): Trained Maya model
        image_file : One of the following: online image URL, local image path, or image bytes
        user_question (str): Question to be shared with the LLM
        temperature (float, optional): Temperature param for LLMs. Defaults to 0.0.
        max_new_tokens (int, optional): Max number of new tokens generated. Defaults to 100.
        conv_mode (str, optional): Conversation template to be used. Defaults to 'aya'.

    Returns:
        output (str): Model's response to the user question
    """
    conv = conv_templates[conv_mode].copy()
    roles = conv.roles
    model = maya_model.model
    tokenizer = maya_model.tokenizer
    image_processor = maya_model.image_processor

    image = load_image(image_file)
    image_size = image.size
    image_tensor = process_images([image], image_processor, model.config)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    inp = user_question
    if image is not None:
        # First message: prepend the image token(s) to the user question
        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
        # image = None

    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]  # retained from the LLaVA CLI flow; not passed to generate() here
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image_size],
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True)

    outputs = tokenizer.decode(output_ids[0]).strip()
    return outputs
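
# End-to-end example (sketch; the image URL is a placeholder):
#
#   model, tokenizer, image_processor, context_len = load_maya_model(
#       'CohereForAI/aya-23-8B', 'nahidalam/Maya', mode='finetuned')
#   maya = MayaModel(model, tokenizer, image_processor, context_len)
#   answer = get_single_sample_prediction(
#       maya,
#       'https://example.com/street_scene.jpg',
#       'What is happening in this image?',
#       temperature=0.2,
#       max_new_tokens=128,
#   )
#   print(answer)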