rbanfield committed
Commit 418414b
Parent: 8182cb1

Update handler.py

Files changed (1): handler.py (+16, -12)
handler.py CHANGED
@@ -13,20 +13,24 @@ class EndpointHandler():
         self.processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")

     def __call__(self, data):
-        inputs = data.pop("inputs", None)
-        text_input = inputs["text"] if "text" in inputs else None
-        image_input = inputs["image"] if "image" in inputs else None

+        text_input = None
+        if isinstance(data, dict):
+            inputs = data.pop("inputs", None)
+            text_input = inputs.get('text',None)
+            image_data = BytesIO(base64.b64decode(inputs['image'])) if 'image' in inputs else None
+        else:
+            # assuming its an image sent via binary
+            image_data = BytesIO(data)
+
         if text_input:
-            processor = self.processor(text=text_input, return_tensors="pt", padding=True)
-            processor.to("cpu")
+            processor = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
             with torch.no_grad():
-                return self.model.get_text_features(**processor).tolist()
-        elif image_input:
-            image = Image.open(BytesIO(base64.b64decode(image_input)))
-            processor = self.processor(images=image, return_tensors="pt")
-            processor.to("cpu")
+                return {"embeddings": self.model.get_text_features(**processor).tolist()}
+        elif image_data:
+            image = Image.open(image_data)
+            processor = self.processor(images=image, return_tensors="pt").to(device)
             with torch.no_grad():
-                return self.model.get_image_features(**processor).tolist()
+                return {"embeddings": self.model.get_image_features(**processor).tolist()}
         else:
-            return None
+            return {"embeddings": None}
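For readers who want the full picture, here is a minimal sketch of what handler.py plausibly looks like after this commit. Only __call__ appears in the hunk above, so everything else is an assumption: the imports, the module-level device handle that the new .to(device) calls rely on, and the CLIPModel setup in __init__ (mirroring the processor line visible in the hunk's context). The `or {}` guard on inputs is also an addition, so a dict payload without an "inputs" key falls through to the final branch instead of raising.

import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Assumption: the commit's .to(device) calls imply a module-level device handle.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler():
    def __init__(self, path=""):
        # Assumption: __init__ mirrors the processor line shown as context in the hunk.
        self.model = CLIPModel.from_pretrained("rbanfield/clip-vit-large-patch14").to(device)
        self.processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")

    def __call__(self, data):
        text_input = None
        image_data = None
        if isinstance(data, dict):
            # Guard with `or {}` (not in the commit) so a payload without an
            # "inputs" key returns the empty result instead of raising.
            inputs = data.pop("inputs", None) or {}
            text_input = inputs.get('text', None)
            image_data = BytesIO(base64.b64decode(inputs['image'])) if 'image' in inputs else None
        else:
            # Non-dict payload: treat the raw request body as image bytes.
            image_data = BytesIO(data)

        if text_input:
            processor = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                return {"embeddings": self.model.get_text_features(**processor).tolist()}
        elif image_data:
            image = Image.open(image_data)
            processor = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                return {"embeddings": self.model.get_image_features(**processor).tolist()}
        else:
            return {"embeddings": None}

Under these assumptions, the endpoint accepts either a JSON body such as {"inputs": {"text": "a photo of a cat"}} or {"inputs": {"image": "<base64-encoded bytes>"}}, or raw image bytes as the request body, and every path now answers with a dict keyed by "embeddings" rather than a bare list, which is the observable behavior change in this +16/-12 diff.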