# Scraped page header (not part of the program): uploaded by luisresende13,
# commit 3a5d5bb ("Upload 2 files", verified), file size 1.13 kB.
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
class EndpointHandler:
    """Endpoint handler that runs an image-to-text pipeline on a remote image.

    The pipeline is built once in ``__init__`` and reused for every call.
    """

    # Seconds to wait on the image host (used for both connect and read).
    # Without a timeout a stalled download would block the worker forever.
    request_timeout = 30

    def __init__(self, path=""):
        """Create the ``image-to-text`` pipeline.

        Args:
            path: Accepted but not used; the model id is hard-coded below.
        """
        model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
        self.pipe = pipeline("image-to-text", model=model_id)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Download the image at ``url`` and run the pipeline on it with ``prompt``.

        Args:
            data: Request payload. The parameters are read from ``data["inputs"]``
                when that key exists, otherwise from ``data`` itself. Keys read:

                * ``url`` -- address of the image to download (required).
                * ``prompt`` -- text placed in the user turn; a missing or
                  ``None`` prompt is treated as an empty string.
                * ``max_new_tokens`` -- forwarded to generation, default 1000.

        Returns:
            The first element of the pipeline output.

        Raises:
            TypeError: If the resolved inputs are not a dict.
            ValueError: If ``url`` is missing or empty.
            requests.RequestException: If the download fails, times out, or
                the server answers with an HTTP error status.
        """
        # ``get`` rather than ``pop`` so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            raise TypeError(
                "'inputs' must be a dict with 'url' and 'prompt' keys, "
                f"got {type(inputs).__name__}"
            )

        url = inputs.get("url")
        if not url:
            raise ValueError("'url' is required and must be a non-empty string")

        # Avoid formatting the literal text "None" into the prompt.
        prompt = inputs.get("prompt") or ""
        max_new_tokens = inputs.get("max_new_tokens", 1000)

        with requests.get(url, stream=True, timeout=self.request_timeout) as response:
            # Fail on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # The raw stream is not content-decoded (gzip/deflate) by default.
            response.raw.decode_content = True
            image = Image.open(response.raw)
            # Image.open is lazy; read the pixels before the connection closes.
            image.load()

        # NOTE(review): the role markers are reproduced exactly as found;
        # confirm they match the chat template this model expects.
        prompt = f'user<image>\n{prompt}\nassistant:'
        results = self.pipe(
            image,
            prompt=prompt,
            generate_kwargs={"max_new_tokens": max_new_tokens},
        )
        # NOTE(review): presumably 'generated_text' still contains the prompt;
        # an earlier attempt to strip it was left disabled -- confirm with callers.
        return results[0]