from io import BytesIO
import base64

from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        # Load separate text and vision towers so each request only runs the
        # model it actually needs.
        self.text_model = CLIPTextModel.from_pretrained(
            "rbanfield/clip-vit-large-patch14"
        ).to(device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(
            "rbanfield/clip-vit-large-patch14"
        ).to(device)
        self.processor = CLIPProcessor.from_pretrained(
            "rbanfield/clip-vit-large-patch14"
        )

    def __call__(self, data):
        text_input = data.pop("text", None)
        image_input = data.pop("image", None)

        if text_input:
            # Tokenize the text and return the pooled text representation.
            inputs = self.processor(
                text=text_input, return_tensors="pt", padding=True
            ).to(device)
            with torch.no_grad():
                return self.text_model(**inputs).pooler_output.tolist()
        elif image_input:
            # Decode the base64-encoded image, preprocess it, and return the
            # projected image embedding.
            image = Image.open(BytesIO(base64.b64decode(image_input)))
            inputs = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                return self.image_model(**inputs).image_embeds.tolist()
        else:
            return None
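
# --- Usage sketch (not part of the handler) --------------------------------
# A minimal local smoke test for the handler above, assuming it is run as a
# script rather than inside an inference endpoint. "cat.jpg" is a hypothetical
# placeholder path, not a file shipped with this module.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Text embedding: pass a raw string (or list of strings) under "text".
    print(handler({"text": "a photo of a cat"}))

    # Image embedding: the handler expects base64-encoded image bytes.
    with open("cat.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    print(handler({"image": image_b64}))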