facenet-embeddings / handler.py
squanchd's picture
should work
7a2a4a4
raw
history blame contribute delete
No virus
1.04 kB
import base64
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
class EndpointHandler:
    """Inference handler that returns FaceNet embeddings for a face in an image.

    Faces are detected and cropped with ``MTCNN`` and embedded with an
    ``InceptionResnetV1`` model using ``pretrained='vggface2'`` weights.
    """

    def __init__(self, path: str = ""):
        """Build the face detector and the embedding model.

        Args:
            path: Model directory passed in by the serving toolkit. It is not
                used here; the weights come from ``pretrained='vggface2'``.
        """
        # Prefer the first GPU when one is available, otherwise fall back to CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.mtcnn = MTCNN(device=self.device)
        # eval() puts the network in inference mode (e.g. dropout disabled).
        self.resnet = InceptionResnetV1(pretrained="vggface2", device=self.device).eval()

    @staticmethod
    def _extract_image_data(data: Dict[str, Any]) -> Any:
        """Return the base64 image payload from ``{"inputs": {"image": ...}}``.

        Raises:
            ValueError: if ``inputs`` is not an object or has no ``image`` entry.
        """
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            raise ValueError('Request body must contain an "inputs" object.')
        image_data = inputs.get("image")
        if image_data is None:
            raise ValueError('"inputs" must contain a base64-encoded "image".')
        return image_data

    @staticmethod
    def _decode_image(image_data: Any) -> "Image.Image":
        """Decode a base64 string/bytes payload into an RGB PIL image."""
        with Image.open(BytesIO(base64.b64decode(image_data))) as img:
            # MTCNN's networks take 3-channel input; RGBA, greyscale or palette
            # uploads would otherwise reach it with the wrong channel count.
            # convert() returns a fully loaded copy, so closing img is safe.
            return img.convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Optional[List[List[float]]]:
        """Embed the face detected in a base64-encoded image.

        Args:
            data: Request payload of the form ``{"inputs": {"image": <base64>}}``.

        Returns:
            A list holding one embedding (a list of floats) for each image in
            which MTCNN found a face. Only a single image is submitted, so this
            is at most one embedding. Returns ``None`` when no face is found.

        Raises:
            ValueError: if the payload does not contain ``inputs.image``.
        """
        image = self._decode_image(self._extract_image_data(data))

        # Pure inference: skipping autograd bookkeeping saves memory and time.
        with torch.no_grad():
            # MTCNN takes a list of images and yields one entry per image:
            # a cropped face tensor, or None when no face was detected.
            faces = [face for face in self.mtcnn([image]) if face is not None]
            if not faces:
                return None
            aligned = torch.stack(faces).to(self.device)
            embeddings = self.resnet(aligned).cpu()
        return embeddings.tolist()