|
from typing import Dict, List, Any |
|
from ultralytics import YOLO |
|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as T |
|
from PIL import Image |
|
|
|
class LinearClassifier(torch.nn.Module): |
|
def __init__(self, input_dim=384, output_dim=7): |
|
super(LinearClassifier, self).__init__() |
|
|
|
self.linear = torch.nn.Linear(input_dim, output_dim) |
|
self.linear.weight.data.normal_(mean=0.0, std=0.01) |
|
self.linear.bias.data.zero_() |
|
|
|
def forward(self, x): |
|
return self.linear(x) |
|
|
|
class EndpointHandler(): |
|
def __init__(self, path=""): |
|
|
|
self.dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') |
|
self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu") |
|
self.dinov2_vits14.to(self.device) |
|
print('Successfully load dinov2_vits14 model') |
|
|
|
self.yolov8_model = YOLO(os.path.join(path, 'yolov8_2023-07-19_yolov8m.pt')) |
|
|
|
self.linear_model = LinearClassifier() |
|
self.linear_model.load_state_dict(torch.load(os.path.join(path, 'linear_2023-07-18_v0.2.pt'))) |
|
self.linear_model.eval() |
|
|
|
self.transform_image = T.Compose([ |
|
T.ToTensor(), |
|
T.Resize(244), |
|
T.CenterCrop(224), |
|
T.Normalize([0.5], [0.5]) |
|
]) |
|
|
|
with open(os.path.join(path, 'labels.txt'), 'r') as f: |
|
self.labels = f.read().split(',') |
|
|
|
self.name_en2vi = { |
|
"loggerhead": "Quản đồng", |
|
"green": "Vích", |
|
"leatherback": "Rùa da", |
|
"hawksbill": "Đồi mồi", |
|
"kemp_ridley": "Vích Kemp", |
|
"olive_ridley": "Đồi mồi dứa", |
|
"flatback": "Rùa lưng phẳng" |
|
} |
|
|
|
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: |
|
""" |
|
data args: |
|
inputs (:obj: `str` | `PIL.Image` | `np.array`) |
|
kwargs |
|
Return: |
|
A :obj:`list` | `dict`: will be serialized and returned |
|
""" |
|
|
|
result = self.yolov8_model(data['inputs']) |
|
|
|
img = result[0].orig_img[:,:,::-1] |
|
H, W, _ = img.shape |
|
annotated = img.copy() |
|
|
|
try: |
|
x1, y1, x2, y2 = result[0].boxes.xyxy.numpy().astype('int')[0] |
|
if result[0].boxes.conf[0].item() < 0.75: |
|
return img.tolist(), "🤔 Hmm... Vích AI không thấy bạn rùa nào trong bức ảnh này. Bạn hãy tải lên một bức hình khác nhé." |
|
else: |
|
annotated = result[0].plot(labels=False, conf=False)[:,:,::-1] |
|
except: |
|
|
|
return img.tolist(), "🤔 Hmm... Vích AI không thấy bạn rùa nào trong bức ảnh này. Bạn hãy tải lên một bức hình khác nhé." |
|
|
|
h, w = y2-y1, x2-x1 |
|
offset = abs(h-w) // 2 |
|
if h > w: |
|
x1 = max(x1 - offset, 0) |
|
x2 = min(x2 + offset, W) |
|
else: |
|
y1 = max(y1 - offset, 0) |
|
y2 = min(y2 + offset, H) |
|
cropped = img[y1:y2, x1:x2] |
|
|
|
new_image = self.transform_image(Image.fromarray(cropped))[:3].unsqueeze(0) |
|
embedding = self.dinov2_vits14(new_image.to(self.device)) |
|
prediction = self.linear_model(embedding) |
|
percentage = nn.Softmax(dim=1)(prediction).detach().numpy().round(2)[0].tolist() |
|
result = {} |
|
|
|
for i in range(len(self.labels)): |
|
result[self.name_en2vi[self.labels[i]]] = percentage[i] |
|
|
|
|
|
return annotated.tolist(), result |