# File size: 1,026 Bytes
# Commit: 7c2538e
import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
)
class EndpointHandler:
    """Inference-endpoint handler for a LLaVA-NeXT vision-language model.

    The processor and a 4-bit quantized model are loaded once in
    ``__init__``. Each call decodes a base64 image, generates a completion
    for the accompanying text prompt, and returns the decoded text.
    """

    def __init__(self, path=""):
        """Load the processor and the 4-bit quantized model.

        Args:
            path: Model repository directory (or hub id) passed to
                ``from_pretrained``.
        """
        # No `disable_torch_init()` here: that LLaVA helper is not imported in
        # this file, and `low_cpu_mem_usage=True` already skips the random
        # weight initialisation it exists to avoid.
        self.processor = LlavaNextProcessor.from_pretrained(path, use_fast=False)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                # Match the fp16 activations instead of the fp32 default.
                bnb_4bit_compute_dtype=torch.float16,
            ),
            # Quantized weights must be placed at load time; calling
            # `.to("cuda:0")` on a 4-bit bitsandbytes model afterwards is
            # rejected by transformers with older bitsandbytes releases.
            device_map={"": 0},
        )

    def __call__(self, data):
        """Generate a completion for one image + prompt request.

        Args:
            data: Request payload containing ``"inputs"`` (base64-encoded
                image, optionally as a ``data:`` URI), ``"text"`` (the prompt)
                and an optional ``"max_new_tokens"`` (default 500).

        Returns:
            The generated sequence decoded with special tokens removed.

        Raises:
            KeyError: If ``"inputs"`` or ``"text"`` is missing from *data*.
        """
        # Read instead of pop: the caller's payload is left untouched, and a
        # missing image fails clearly rather than passing the dict onward.
        image_encoded = data["inputs"]
        prompt = data["text"]
        max_new_tokens = int(data.get("max_new_tokens", 500))

        image = self.decode_base64_image(image_encoded)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Keyword arguments: the positional order of the processor's
        # `images`/`text` parameters differs between transformers versions.
        inputs = self.processor(
            images=image, text=prompt, return_tensors="pt"
        ).to(self.model.device)

        # autoregressively complete the prompt
        output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # NOTE(review): for this decoder-only model the decoded text
        # presumably echoes the prompt before the completion — confirm.
        return self.processor.decode(output[0], skip_special_tokens=True)

    @staticmethod
    def decode_base64_image(image_string):
        """Decode base64-encoded image data into a PIL image.

        Args:
            image_string: Base64 text or bytes. A leading
                ``data:<mime>;base64,`` URI header is stripped if present.

        Returns:
            The opened ``PIL.Image.Image``.
        """
        if isinstance(image_string, str) and image_string.startswith("data:"):
            # Keep only the payload after the data-URI header.
            image_string = image_string.partition(",")[2]
        return Image.open(BytesIO(base64.b64decode(image_string)))