'''
Cherry-picked from Roshan's PR: https://github.com/nahidalam/LLaVA/blob/1ecc141d7f20f16518f38a0d99320268305c17c3/llava/eval/maya/eval_utils.py
'''
import os
import sys
import torch
import requests
from io import BytesIO
from PIL import Image
from transformers import AutoTokenizer, AutoConfig, TextStreamer
from transformers.models.cohere.tokenization_cohere_fast import CohereTokenizerFast
from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig
from constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
from conversation import conv_templates, SeparatorStyle
from mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from typing import Optional, Literal
def load_maya_model(model_base: str, model_path: str, projector_path: Optional[str] = None, mode: Literal['pretrained', 'finetuned'] = 'finetuned'):
    """Helps load a trained Maya model.

    A trained Maya model can be of two flavors:
        1. Pretrained: the model has only gone through pretraining, and the changes are restricted to the projector layer
        2. Finetuned: the model has gone through instruction finetuning after the pretraining stage, which affects the whole model

    This is a replication of the load_pretrained_model function from llava.model.builder that is specific to Cohere/Maya.

    Args:
        model_base: Path of the base LLM model in HF. Eg: 'CohereForAI/aya-23-8B', 'meta-llama/Meta-Llama-3-8B-Instruct'.
            This is used to instantiate the tokenizer and the model (in the case of loading the pretrained model).
        model_path: Path of the trained model repo in HF. Eg: 'nahidalam/Maya'.
            This is used to load the config file, so this path/directory should contain the config.json file.
            For the finetuned model, this is used to load the final model weights as well.
        projector_path: For the pretrained model, the path to the local directory that holds the mm_projector.bin file.
        mode: Specifies whether this is loading a pretrained-only model or a finetuned model. Defaults to 'finetuned'.

    Returns:
        model: LlavaCohereForCausalLM object
        tokenizer: CohereTokenizerFast object
        image_processor: Image processor associated with the model's vision tower
        context_len: Maximum context length of the model (int)
    """
    device_map = 'auto'
    kwargs = {"device_map": device_map}
    kwargs['torch_dtype'] = torch.float32
    # kwargs['attn_implementation'] = 'flash_attention_2'

    ## Instantiating tokenizer and model base
    tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
    cfg_pretrained = LlavaCohereConfig.from_pretrained(model_path)

    if mode == 'pretrained':
        if projector_path is None:
            raise ValueError("projector_path must be provided when mode='pretrained'")
        model = LlavaCohereForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

        ## Loading projector layer weights
        mm_projector_weights = torch.load(projector_path, map_location='cpu')
        mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
        model.load_state_dict(mm_projector_weights, strict=False)
    else:
        # Load the full finetuned weights; ignore_mismatched_sizes handles vision tower weights
        model = LlavaCohereForCausalLM.from_pretrained(
            model_path,
            config=cfg_pretrained,
            ignore_mismatched_sizes=True,
            **kwargs
        )
    ## Loading image processor
    image_processor = None

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # Get and load vision tower
    vision_tower = model.get_vision_tower()
    if vision_tower is None:
        raise ValueError("Vision tower not found in model config")

    print(f"Loading vision tower... Is loaded: {vision_tower.is_loaded}")
    if not vision_tower.is_loaded:
        try:
            vision_tower.load_model()
            print("Vision tower loaded successfully")
        except Exception as e:
            print(f"Error loading vision tower: {str(e)}")
            raise

    if device_map != 'auto':
        vision_tower.to(device=device_map, dtype=torch.float16)

    image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    # maya = MayaModel(model, tokenizer, image_processor, context_len)
    return model, tokenizer, image_processor, context_len
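
# Example usage (a sketch, not part of the original module; the repo IDs come from
# the docstring above, while the projector path is a hypothetical local file):
#
#   model, tokenizer, image_processor, context_len = load_maya_model(
#       model_base='CohereForAI/aya-23-8B',
#       model_path='nahidalam/Maya',
#       projector_path='./mm_projector.bin',  # hypothetical local path
#       mode='pretrained',
#   )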
class MayaModel(object):
    def __init__(self, model: LlavaCohereForCausalLM, tokenizer: CohereTokenizerFast, image_processor, context_length):
        self.model = model
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.context_length = context_length

    def validate_inputs(self):
        """
        Method to validate the inputs
        """
        pass
def load_image(image_input):
    """
    Convert various image inputs to a PIL Image object.

    :param image_input: Can be a URL string, a file path string, or image bytes
    :return: PIL Image object
    """
    try:
        if isinstance(image_input, str):
            if image_input.startswith(('http://', 'https://')):
                # Input is a URL
                response = requests.get(image_input)
                response.raise_for_status()  # Raise an exception for bad responses
                return Image.open(BytesIO(response.content))
            elif os.path.isfile(image_input):
                # Input is a file path
                return Image.open(image_input)
            else:
                raise ValueError("Invalid input: string is neither a valid URL nor a file path")
        elif isinstance(image_input, bytes):
            # Input is bytes
            return Image.open(BytesIO(image_input))
        else:
            raise ValueError("Invalid input type. Expected URL string, file path string, or bytes.")
    except requests.RequestException as e:
        raise ValueError(f"Error fetching image from URL: {e}")
    except IOError as e:
        raise ValueError(f"Error opening image file: {e}")
    except Exception as e:
        raise ValueError(f"An unexpected error occurred: {e}")
def get_single_sample_prediction(maya_model, image_file, user_question, temperature=0.0, max_new_tokens=100, conv_mode='aya'):
    """Generates the prediction for a single image-user question pair.

    Args:
        maya_model (MayaModel): Trained Maya model
        image_file: One of the following: online image URL, local image path, or image bytes
        user_question (str): Question to be shared with the LLM
        temperature (float, optional): Temperature param for the LLM. Defaults to 0.0.
        max_new_tokens (int, optional): Max number of new tokens generated. Defaults to 100.
        conv_mode (str, optional): Conversation template to be used. Defaults to 'aya'.

    Returns:
        output (str): Model's response to the user question
    """
    conv = conv_templates[conv_mode].copy()
    roles = conv.roles
    model = maya_model.model
    tokenizer = maya_model.tokenizer
    image_processor = maya_model.image_processor

    image = load_image(image_file)
    image_size = image.size
    image_tensor = process_images([image], image_processor, model.config)
    if isinstance(image_tensor, list):
        image_tensor = [img.to(model.device, dtype=torch.float16) for img in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    inp = user_question

    if image is not None:
        # Prepend the image token(s) to the first user message
        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp

    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image_size],
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True)

    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return outputs
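
# End-to-end example (a sketch; assumes a finetuned checkpoint at 'nahidalam/Maya'
# as in the load_maya_model docstring, and uses a placeholder image URL):
#
#   model, tokenizer, image_processor, context_len = load_maya_model(
#       model_base='CohereForAI/aya-23-8B',
#       model_path='nahidalam/Maya',
#       mode='finetuned',
#   )
#   maya = MayaModel(model, tokenizer, image_processor, context_len)
#   answer = get_single_sample_prediction(
#       maya,
#       image_file='https://example.com/sample.jpg',
#       user_question='What is shown in this image?',
#   )
#   print(answer)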