|
from llava.model.builder import load_pretrained_model |
|
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token |
|
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN |
|
from llava.conversation import conv_templates |
|
|
|
from PIL import Image |
|
import requests |
|
import copy |
|
import torch |
|
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path |
|
import spaces |
|
|
|
|
|
|
|
def load_llava_model(
    model_path="Lin-Chen/open-llava-next-llama3-8b",
    conv_template="llama_v3",
    device_map="auto",
):
    """Load the LLaVA checkpoint and switch it to evaluation mode.

    Args:
        model_path: Checkpoint identifier handed to ``load_pretrained_model``
            and to ``get_model_name_from_path``.
        conv_template: Conversation-template key; returned unchanged so the
            caller can look it up in ``conv_templates``.
        device_map: Forwarded to ``load_pretrained_model`` as ``device_map``.

    Returns:
        A ``(tokenizer, model, image_processor, conv_template)`` tuple.
    """
    model_name = get_model_name_from_path(model_path)

    # ``None`` is the loader's second positional argument (presumably "no base
    # model" -- confirm against the loader signature). The fourth returned
    # value is not needed by this script, so it is discarded.
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path, None, model_name, device_map=device_map
    )

    model.eval()
    model.tie_weights()
    return tokenizer, model, image_processor, conv_template
|
|
|
# Load the model once at import time; ``inference`` reads these module-level
# names on every call.
(
    tokenizer_llava,
    model_llava,
    image_processor_llava,
    conv_template_llava,
) = load_llava_model()
|
|
|
@spaces.GPU
def inference(image_path="assets/example.jpg", question="What is in the figure?"):
    """Ask the module-level LLaVA model a question about one image.

    Args:
        image_path: Path of the image file to open. Defaults to
            ``"assets/example.jpg"``.
        question: Text appended after the ``<image>`` placeholder to form the
            user message.

    Returns:
        The value of ``tokenizer_llava.batch_decode`` applied to the generated
        ids with special tokens skipped.
    """
    device = "cuda"

    # Convert inside the context manager so the underlying file handle is
    # closed right away instead of lingering until garbage collection;
    # ``convert`` returns a separate, fully loaded image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    image_tensor = process_images([image], image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)

    # Build the prompt from a private copy of the conversation template.
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], f"<image>{question}")
    # A ``None`` message for the second role -- presumably leaves the reply
    # slot open for the model to fill; confirm against the template class.
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    input_ids = (
        tokenizer_image_token(
            prompt_question, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt"
        )
        .unsqueeze(0)  # add a leading batch dimension
        .to(device)
    )
    image_sizes = [image.size]
    print(input_ids.shape, image_tensor.shape)

    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=256,
            use_cache=True,
        )

    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    print(text_outputs)
    return text_outputs
|
|
|
# Script entry point: run a single inference pass with the default arguments.
if __name__ == "__main__":
    inference()