File size: 1,070 Bytes
3a5d5bb
 
 
 
 
 
 
c828564
3a5d5bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
080a129
3a5d5bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import io
from typing import Dict, List, Any

import requests
from PIL import Image
from transformers import pipeline

class EndpointHandler():
    """Inference handler that runs an ``image-to-text`` pipeline on a remote image.

    The pipeline is loaded from ``path`` once at construction time and reused
    for every request.
    """

    # Seconds to wait for the image download; without a timeout a stalled
    # server would block the request (and the worker) indefinitely.
    request_timeout = 30

    def __init__(self, path=""):
        # Load the model once so every request reuses the same pipeline.
        self.pipe = pipeline("image-to-text", model=path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Download an image, query the model with a chat-style prompt, return its answer.

        Args:
            data: request payload. Parameters are read from ``data["inputs"]``
                when that key is present, otherwise from ``data`` itself.
                ``data`` is not modified. Recognised keys:

                - ``url`` (str, required): address of the image to download.
                - ``prompt`` (str, optional): user text; a missing or ``None``
                  prompt is treated as an empty string.
                - ``max_new_tokens`` (optional, default 1000): forwarded to the
                  pipeline's ``generate_kwargs``.

        Returns:
            The first element of the pipeline output (a dict), with the echoed
            prompt removed from its ``generated_text`` entry.

        Raises:
            TypeError: if the inputs are not a dict.
            ValueError: if no ``url`` is provided.
            requests.HTTPError: if the image download returns an error status.
        """
        inputs = data.get('inputs', data)
        if not isinstance(inputs, dict):
            raise TypeError(
                "expected 'inputs' to be an object with 'url', 'prompt' and "
                f"optional 'max_new_tokens' keys, got {type(inputs).__name__}"
            )

        url = inputs.get('url')
        if not url:
            raise ValueError("'url' is required in the request inputs")

        user_prompt = inputs.get('prompt')
        if user_prompt is None:
            # Avoid rendering the literal text "None" into the model prompt.
            user_prompt = ''
        max_new_tokens = inputs.get('max_new_tokens', 1000)

        image = self._download_image(url)
        prompt = f'user<image>\n{user_prompt}\nassistant:'

        results = self.pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": max_new_tokens})
        result = results[0]
        # generated_text is expected to repeat the prompt (minus the <image>
        # placeholder) before the answer; strip that first echo only.
        echo = prompt.replace('<image>', '') + ' '
        result['generated_text'] = result['generated_text'].replace(echo, '', 1)
        return result

    def _download_image(self, url):
        """Fetch ``url`` and open the response body as a PIL image."""
        # The context manager guarantees the connection is released, and
        # ``.content`` (unlike ``.raw``) honours Content-Encoding such as gzip.
        with requests.get(url, timeout=self.request_timeout) as response:
            response.raise_for_status()
            payload = response.content
        return Image.open(io.BytesIO(payload))