# Source: Hugging Face Hub, handler.py (commit 080a129, "Update handler.py", by luisresende13; 1.07 kB).
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
class EndpointHandler():
    """Custom endpoint handler that captions / answers questions about an image.

    The handler wraps a ``transformers`` ``image-to-text`` pipeline. Each
    request supplies an image URL and a text prompt; the image is downloaded,
    the prompt is wrapped in the ``user<image>\\n...\\nassistant:`` chat
    template, and the generated text is returned with the echoed prompt
    removed.
    """

    # Seconds allowed for connecting and for each read while downloading the
    # image. Without a timeout a stalled host would block the worker forever.
    REQUEST_TIMEOUT = 30

    # Generation budget used when the request does not specify one.
    DEFAULT_MAX_NEW_TOKENS = 1000

    def __init__(self, path=""):
        """Load the ``image-to-text`` pipeline from the model at *path*."""
        self.pipe = pipeline("image-to-text", model=path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one inference request.

        data args:
            inputs (:obj:`dict`): request parameters. If the ``inputs`` key is
                absent, ``data`` itself is used as the parameter dict.
                Recognised keys:
                    url (:obj:`str`): location of the image to download (required)
                    prompt (:obj:`str`): user text; empty when omitted
                    max_new_tokens (:obj:`int`): generation limit, default 1000
        Return:
            A :obj:`dict`: the first pipeline result, with the echoed prompt
            stripped from ``generated_text``; will be serialized and returned
        Raises:
            ValueError: if the parameters are not a dict or ``url`` is missing.
            requests.exceptions.RequestException: if the download fails or the
                server answers with an HTTP error status.
        """
        # Read without pop() so the caller's payload is left untouched.
        inputs = data.get('inputs', data)
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be a dict with at least a 'url' key")

        url = inputs.get('url')
        if not url:
            raise ValueError("'url' is required to locate the input image")

        prompt = self._format_prompt(inputs.get('prompt'))
        max_new_tokens = inputs.get('max_new_tokens', self.DEFAULT_MAX_NEW_TOKENS)

        image = self._fetch_image(url)
        results = self.pipe(
            image,
            prompt=prompt,
            generate_kwargs={"max_new_tokens": max_new_tokens},
        )

        result = results[0]
        result['generated_text'] = self._strip_prompt_echo(
            result['generated_text'], prompt
        )
        return result

    @staticmethod
    def _format_prompt(prompt: Any) -> str:
        """Wrap the user text in the chat template passed to the pipeline."""
        # A missing prompt must not be rendered as the literal text "None".
        if prompt is None:
            prompt = ''
        return f'user<image>\n{prompt}\nassistant:'

    @staticmethod
    def _strip_prompt_echo(generated_text: str, prompt: str) -> str:
        """Remove the echoed prompt from the generated text.

        The echo is matched as the prompt without its ``<image>`` placeholder,
        followed by a single space. Only the first occurrence is removed, so
        an answer that happens to repeat the prompt is not altered further on.
        """
        echoed = prompt.replace('<image>', '') + ' '
        return generated_text.replace(echoed, '', 1)

    def _fetch_image(self, url: str):
        """Download *url* and return it as a fully loaded ``PIL.Image``."""
        # NOTE(review): the URL comes straight from the request body; if this
        # endpoint is exposed to untrusted callers, consider restricting the
        # hosts that may be fetched.
        with requests.get(url, stream=True, timeout=self.REQUEST_TIMEOUT) as response:
            # Fail on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # Have urllib3 undo gzip/deflate transfer encoding on raw reads.
            response.raw.decode_content = True
            image = Image.open(response.raw)
            # Force the pixel data to be read before the connection is closed.
            image.load()
        return image