# idefics2-8b-ocr / handler.py
import logging
from typing import Any, Dict

import requests
import torch
from PIL import Image
from transformers import Idefics2Processor, Idefics2ForConditionalGeneration


class EndpointHandler:
    def __init__(self, path=""):
        # Preload all the elements needed at inference time.
        self.logger = logging.getLogger(__name__)
        self.logger.addHandler(logging.StreamHandler())
        # The root logger defaults to WARNING, which would swallow the
        # .info() calls below, so set the level explicitly.
        self.logger.setLevel(logging.INFO)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = Idefics2Processor.from_pretrained(path)
        self.model = Idefics2ForConditionalGeneration.from_pretrained(path)
        self.model.to(self.device)
        self.logger.info("Initialisation finished!")
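        # Note (a sketch, not part of the original handler): an 8B model in
        # float32 needs roughly 32 GB of memory for weights alone, which can
        # exceed a single endpoint GPU. If half precision is acceptable for
        # this checkpoint, loading could instead look like:
        #   self.model = Idefics2ForConditionalGeneration.from_pretrained(
        #       path, torch_dtype=torch.float16
        #   )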

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            inputs (:obj:`PIL.Image` | `np.array`): the user-supplied image
        Return:
            A :obj:`str`: the generated text, which will be serialized and returned
        """
"""image = data.pop("inputs", data)
self.logger.info("image")
# process image
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
self.logger.info("inputs")
self.logger.info(f"{inputs.input_ids}")
generated_ids = self.model.generate(**inputs)
self.logger.info("generated")
# run prediction
generated_text = self.processor.batch_decode(
generated_ids, skip_special_tokens=True
)
self.logger.info("decoded")"""
        # User-supplied image from the request payload, plus a fixed second
        # image fetched from COCO for the two-image comparison prompt.
        image_1 = data.pop("inputs", data)
        url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
        image_2 = Image.open(requests.get(url_2, stream=True).raw)
        images = [image_1, image_2]
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "What’s the difference between these two images?",
                    },
                    {"type": "image"},
                    {"type": "image"},
                ],
            }
        ]
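        # Each {"type": "image"} placeholder above becomes an <image> token in
        # the rendered prompt and is matched positionally with an entry of
        # `images` when the processor builds the model inputs.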
        # At inference time, pass `add_generation_prompt=True` so the template
        # ends with an assistant turn for the model to complete.
        text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        self.logger.info(text)
        # 'User: What’s the difference between these two images?<image><image><end_of_utterance>\nAssistant:'
        inputs = self.processor(images=images, text=text, return_tensors="pt").to(
            self.device
        )
        self.logger.info("Inputs prepared")
        # Run prediction.
        generated_ids = self.model.generate(**inputs, max_new_tokens=500)
        self.logger.info("Generation finished")
        # Decode the generated token ids back to text.
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        self.logger.info(f"Generated text: {generated_text}")
        return generated_text
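

# A minimal local smoke test (a hedged sketch, not part of the Inference
# Endpoints contract). It assumes model weights are available at MODEL_PATH,
# a hypothetical value here, and that "inputs" carries a PIL image, mirroring
# what __call__ expects.
if __name__ == "__main__":
    MODEL_PATH = "HuggingFaceM4/idefics2-8b"  # assumption: substitute the actual checkpoint path
    handler = EndpointHandler(path=MODEL_PATH)
    url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"  # a public COCO image
    image_1 = Image.open(requests.get(url_1, stream=True).raw)
    print(handler({"inputs": image_1}))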