facenet-embeddings / handler.py
squanchd's picture
should work
e7ed489
raw
history blame
1.03 kB
import base64
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
class EndpointHandler:
    """Inference-endpoint handler that returns FaceNet embeddings for a face.

    The ``EndpointHandler(path)`` / ``__call__(data)`` shape follows the
    Hugging Face custom-handler convention. Faces are detected and aligned
    with facenet_pytorch's MTCNN, then embedded with an InceptionResnetV1
    model loaded with ``pretrained='vggface2'``.
    """

    def __init__(self, path=""):
        """Build the face detector and embedding model.

        Args:
            path: Model directory supplied by the endpoint runtime. Not used:
                the embedding weights are selected via facenet_pytorch's
                ``pretrained='vggface2'`` option instead.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.mtcnn = MTCNN(device=self.device)
        # eval() puts the network in inference mode (e.g. fixed batch-norm stats).
        self.resnet = InceptionResnetV1(pretrained="vggface2", device=self.device).eval()

    def __call__(self, data: Dict[str, Any]) -> Optional[List[List[float]]]:
        """Embed the face detected in a base64-encoded image.

        Args:
            data: Request payload. Its ``"image"`` entry must hold the
                base64-encoded bytes of an image file that PIL can open.
                The entry is popped (removed) from ``data``.

        Returns:
            A list of embeddings, one list of floats per face tensor MTCNN
            returned, or ``None`` if no face was detected.

        Raises:
            TypeError: if ``data`` has no ``"image"`` entry, or the entry is
                not something ``base64.b64decode`` accepts.
        """
        image_data = data.pop("image", None)
        if image_data is None:
            raise TypeError("missing 'image' field: expected base64-encoded image data")

        # MTCNN expects 3-channel input; RGBA, greyscale and palette images
        # would otherwise fail inside face detection.
        image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")

        # With MTCNN's default keep_all=False, each entry of the batch result
        # is one face tensor, or None when no face was found in that image.
        faces = [face for face in self.mtcnn([image]) if face is not None]
        if not faces:
            return None

        # .to() is a no-op when the target device is already the CPU.
        aligned = torch.stack(faces).to(self.device)
        # Inference only: don't build an autograd graph for the forward pass.
        with torch.no_grad():
            embeddings = self.resnet(aligned).cpu()
        return embeddings.tolist()