File size: 5,894 Bytes
8e542dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import cv2
import copy
import re
import torch
import numpy as np

from pathlib import Path
from facelib.detection.yolov5face.models.yolo import Model
from facelib.detection.yolov5face.utils.datasets import letterbox
from facelib.detection.yolov5face.utils.general import (
    check_img_size,
    non_max_suppression_face,
    scale_coords,
    scale_coords_landmarks,
)

# IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9)
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
    torch.__version__)[0][:3])] >= [1, 9, 0]


def isListempty(inList):
    if isinstance(inList, list): # Is a list
        return all(map(isListempty, inList))
    return False # Not a list

class YoloDetector:
    def __init__(
        self,
        config_name,
        min_face=10,
        target_size=None,
        device='cuda',
    ):
        """
        config_name: name of .yaml config with network configuration from models/ folder.
        min_face : minimal face size in pixels.
        target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080.
                    None for original resolution.
        """
        self._class_path = Path(__file__).parent.absolute()
        self.target_size = target_size
        self.min_face = min_face
        self.detector = Model(cfg=config_name)
        self.device = device


    def _preprocess(self, imgs):
        """
        Preprocessing image before passing through the network. Resize and conversion to torch tensor.
        """
        pp_imgs = []
        for img in imgs:
            h0, w0 = img.shape[:2]  # orig hw
            if self.target_size:
                r = self.target_size / min(h0, w0)  # resize image to img_size
                if r < 1:
                    img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)

            imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max())  # check img_size
            img = letterbox(img, new_shape=imgsz)[0]
            pp_imgs.append(img)
        pp_imgs = np.array(pp_imgs)
        pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
        pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
        pp_imgs = pp_imgs.float()  # uint8 to fp16/32
        return pp_imgs / 255.0  # 0 - 255 to 0.0 - 1.0

    def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
        """
        Postprocessing of raw pytorch model output.
        Returns:
            bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
            points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
        """
        bboxes = [[] for _ in range(len(origimgs))]
        landmarks = [[] for _ in range(len(origimgs))]

        pred = non_max_suppression_face(pred, conf_thres, iou_thres)

        for image_id, origimg in enumerate(origimgs):
            img_shape = origimg.shape
            image_height, image_width = img_shape[:2]
            gn = torch.tensor(img_shape)[[1, 0, 1, 0]]  # normalization gain whwh
            gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]]  # normalization gain landmarks
            det = pred[image_id].cpu()
            scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
            scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()

            for j in range(det.size()[0]):
                box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
                box = list(
                    map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
                )
                if box[3] - box[1] < self.min_face:
                    continue
                lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
                lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
                lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
                bboxes[image_id].append(box)
                landmarks[image_id].append(lm)
        return bboxes, landmarks

    def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
        """
        Get bbox coordinates and keypoints of faces on original image.
        Params:
            imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference)
            conf_thres: confidence threshold for each prediction
            iou_thres: threshold for NMS (filter of intersecting bboxes)
        Returns:
            bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
            points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
        """
        # Pass input images through face detector
        images = imgs if isinstance(imgs, list) else [imgs]
        images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
        origimgs = copy.deepcopy(images)

        images = self._preprocess(images)
        
        if IS_HIGH_VERSION:
            with torch.inference_mode():  # for pytorch>=1.9 
                pred = self.detector(images)[0]
        else:
            with torch.no_grad():  # for pytorch<1.9
                pred = self.detector(images)[0]

        bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)

        # return bboxes, points
        if not isListempty(points):
            bboxes = np.array(bboxes).reshape(-1,4)
            points = np.array(points).reshape(-1,10)
            padding = bboxes[:,0].reshape(-1,1)
            return np.concatenate((bboxes, padding, points), axis=1)
        else:
            return None

    def __call__(self, *args):
        return self.predict(*args)