Shelly / handler.py
panda1835's picture
Update handler.py
8d956c1
from typing import Dict, List, Any
from ultralytics import YOLO
import os
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
class LinearClassifier(nn.Module):
    """Single fully connected layer mapping an embedding to class logits.

    Args:
        input_dim: Size of the incoming feature vector (default 384).
        output_dim: Number of output classes (default 7).
    """

    def __init__(self, input_dim=384, output_dim=7):
        super().__init__()
        # The attribute must stay named ``linear`` so saved state-dict keys
        # (``linear.weight`` / ``linear.bias``) keep matching.
        self.linear = nn.Linear(input_dim, output_dim)
        # Override the default initialisation: small Gaussian weights and
        # an all-zero bias.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Return the raw (un-normalised) logits for ``x``."""
        logits = self.linear(x)
        return logits
class EndpointHandler():
    """Inference endpoint: detect a turtle, crop it, and classify its species.

    Pipeline per request: a YOLOv8 model proposes boxes, the first box is
    expanded towards a square and cropped, the crop is embedded with
    DINOv2 (ViT-S/14) and a linear head turns the embedding into per-species
    probabilities keyed by Vietnamese species names.
    """

    # Detections whose confidence is below this value are treated as
    # "nothing found".
    CONFIDENCE_THRESHOLD = 0.75

    # User-facing message returned when no usable detection exists.
    NO_TURTLE_MESSAGE = "🤔 Hmm... Vích AI không thấy bạn rùa nào trong bức ảnh này. Bạn hãy tải lên một bức hình khác nhé."

    def __init__(self, path=""):
        """Load every model and lookup table needed at inference time.

        Args:
            path: Directory containing the YOLO weights, the linear-head
                weights and ``labels.txt``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

        self.dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        self.dinov2_vits14.to(self.device)
        # Inference only: make sure any train-time behaviour is switched off.
        self.dinov2_vits14.eval()
        print('Successfully load dinov2_vits14 model')

        self.yolov8_model = YOLO(os.path.join(path, 'yolov8_2023-07-19_yolov8m.pt'))

        self.linear_model = LinearClassifier()
        # map_location lets the checkpoint load regardless of the device it
        # was saved from.
        state_dict = torch.load(
            os.path.join(path, 'linear_2023-07-18_v0.2.pt'),
            map_location=self.device,
        )
        self.linear_model.load_state_dict(state_dict)
        # The head must live on the same device as the embeddings it consumes.
        self.linear_model.to(self.device)
        self.linear_model.eval()

        # NOTE(review): Resize(244) and Normalize(0.5, 0.5) are kept exactly
        # as-is; presumably they mirror the preprocessing used when the
        # linear head was trained -- confirm before changing.
        self.transform_image = T.Compose([
            T.ToTensor(),
            T.Resize(244),
            T.CenterCrop(224),
            T.Normalize([0.5], [0.5])
        ])

        self.labels = self._read_labels(path)  # loggerhead,green,leatherback...

        self.name_en2vi = {
            "loggerhead": "Quản đồng",
            "green": "Vích",
            "leatherback": "Rùa da",
            "hawksbill": "Đồi mồi",
            "kemp_ridley": "Vích Kemp",
            "olive_ridley": "Đồi mồi dứa",
            "flatback": "Rùa lưng phẳng"
        }

    @staticmethod
    def _read_labels(path):
        """Read the comma-separated class names from ``labels.txt`` in ``path``."""
        with open(os.path.join(path, 'labels.txt'), 'r', encoding='utf-8') as f:
            # Strip whitespace (e.g. a trailing newline) so every label matches
            # its key in ``name_en2vi``; drop empty entries from stray commas.
            return [name.strip() for name in f.read().split(',') if name.strip()]

    @staticmethod
    def _square_box(x1, y1, x2, y2, width, height):
        """Grow the shorter side of a box towards a square, clamped to the image."""
        h, w = y2 - y1, x2 - x1
        offset = abs(h - w) // 2
        if h > w:
            x1 = max(x1 - offset, 0)
            x2 = min(x2 + offset, width)
        else:
            y1 = max(y1 - offset, 0)
            y2 = min(y2 + offset, height)
        return x1, y1, x2, y2

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run detection and classification on one request.

        Args:
            data: Request payload; ``data['inputs']`` is forwarded unchanged
                to the YOLO model.

        Returns:
            A 2-tuple ``(image, outcome)``. ``image`` is a nested list of
            pixel values: the annotated image when a detection is accepted,
            otherwise the channel-reversed input image. ``outcome`` is a dict
            mapping Vietnamese species names to probabilities rounded to two
            decimals, or ``NO_TURTLE_MESSAGE`` when there is no detection with
            confidence of at least ``CONFIDENCE_THRESHOLD``. (The declared
            return annotation is kept unchanged for interface compatibility.)

        Raises:
            KeyError: If ``data`` has no ``'inputs'`` entry or a label from
                ``labels.txt`` has no translation in ``name_en2vi``.
        """
        detection = self.yolov8_model(data['inputs'])[0]

        # Reverse the channel axis of the original image (presumably
        # BGR -> RGB, as the detector keeps OpenCV-style arrays).
        img = detection.orig_img[:, :, ::-1]
        H, W, _ = img.shape

        boxes = detection.boxes
        # Explicit checks instead of a bare ``except`` so that genuine
        # failures (e.g. device errors) are no longer reported as
        # "no turtle found".
        if boxes is None or len(boxes.xyxy) == 0:
            return img.tolist(), self.NO_TURTLE_MESSAGE
        if boxes.conf[0].item() < self.CONFIDENCE_THRESHOLD:
            return img.tolist(), self.NO_TURTLE_MESSAGE

        # ``tolist`` copies to host memory, so this also works when the
        # detector ran on the GPU; ``int`` truncates like the former astype.
        x1, y1, x2, y2 = (int(v) for v in boxes.xyxy[0].tolist())
        annotated = detection.plot(labels=False, conf=False)[:, :, ::-1]

        x1, y1, x2, y2 = self._square_box(x1, y1, x2, y2, W, H)
        cropped = img[y1:y2, x1:x2]

        # Keep at most three channels and add the batch dimension.
        new_image = self.transform_image(Image.fromarray(cropped))[:3].unsqueeze(0)
        with torch.no_grad():  # inference only, no autograd bookkeeping
            embedding = self.dinov2_vits14(new_image.to(self.device))
            prediction = self.linear_model(embedding)
            probabilities = torch.softmax(prediction, dim=1)[0]
        percentage = probabilities.cpu().numpy().round(2).tolist()

        scores = {
            self.name_en2vi[label]: score
            for label, score in zip(self.labels, percentage)
        }

        # Return the annotated original image and the per-species scores.
        return annotated.tolist(), scores