from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates

from PIL import Image
import torch
import spaces

#model_path = "/scratch/TecManDep/A_Models/llava-v1.6-vicuna-7b"
#conv_template = "vicuna_v1"  # Make sure to use the matching chat template for each model

def load_llava_model():
    model_path = "Lin-Chen/open-llava-next-llama3-8b"
    conv_template = "llama_v3"
    model_name = get_model_name_from_path(model_path)
    device_map = "auto"
    # Pass any additional llava_model_args to load_pretrained_model here.
    tokenizer, model, image_processor, max_length = load_pretrained_model(
        model_path, None, model_name, device_map=device_map)

    model.eval()
    model.tie_weights()
    return tokenizer, model, image_processor, conv_template

tokenizer_llava, model_llava, image_processor_llava, conv_template_llava = load_llava_model()
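# Note: the model is loaded once at import time, so the weights are fetched
# when the Space starts. On a ZeroGPU Space (an assumption about the
# deployment target), the `@spaces.GPU` decorator below attaches a GPU only
# while the decorated function is running.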

@spaces.GPU
def inference():
    image = Image.open("assets/example.jpg").convert("RGB")
    device = "cuda"
    image_tensor = process_images([image], image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device=device)

    # Build the prompt with the conversation template; the image placeholder
    # is swapped for IMAGE_TOKEN_INDEX during tokenization below.
    prompt = f"{DEFAULT_IMAGE_TOKEN}\nWhat is in the figure?"
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt_question, tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    image_sizes = [image.size]  # original (width, height), needed for anyres tiling
    print(input_ids.shape, image_tensor.shape)  # debug: check tensor shapes
    with torch.inference_mode():
        cont = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
            max_new_tokens=256,
            use_cache=True,
        )
    text_outputs = tokenizer_llava.batch_decode(cont, skip_special_tokens=True)
    print(text_outputs)
    return text_outputs
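
# A minimal sketch (not part of the original script) of how the same pipeline
# could be parameterized, e.g. for a Gradio callback on the Space. The name
# `answer` and its signature are illustrative assumptions; it reuses only the
# calls already used in `inference` above.
@spaces.GPU
def answer(image_path: str, question: str) -> str:
    image = Image.open(image_path).convert("RGB")
    image_tensor = process_images([image], image_processor_llava, model_llava.config)
    image_tensor = image_tensor.to(dtype=torch.float16, device="cuda")

    # Same template flow as inference(), but with a caller-supplied question.
    conv = conv_templates[conv_template_llava].copy()
    conv.append_message(conv.roles[0], f"{DEFAULT_IMAGE_TOKEN}\n{question}")
    conv.append_message(conv.roles[1], None)
    input_ids = tokenizer_image_token(conv.get_prompt(), tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to("cuda")

    with torch.inference_mode():
        out = model_llava.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image.size],
            do_sample=False,
            max_new_tokens=256,
            use_cache=True,
        )
    return tokenizer_llava.batch_decode(out, skip_special_tokens=True)[0]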

if __name__ == "__main__":
    inference()