SunderAli17 committed on
Commit 4ffd2b7
1 Parent(s): 7f1b096

Create utils.py

Files changed (1):
  1. utils/utils.py +51 -0
utils/utils.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ import numpy as np
+ from einops import rearrange
+ from kornia.geometry.transform.crop2d import warp_affine
+
+ from utils.matlab_cp2tform import get_similarity_transform_for_cv2
+ from torchvision.transforms import Pad
+
+ # Canonical five-point facial template (eye centers, nose tip, mouth corners).
+ # The original points target a 112x96 crop; 8 was added to the x axis to make
+ # it 112x112, and dividing by 112 makes the points relative to the crop size.
+ REFERENCE_FACIAL_POINTS_RELATIVE = np.array([[38.29459953, 51.69630051],
+                                              [72.53179932, 51.50139999],
+                                              [56.02519989, 71.73660278],
+                                              [41.54930115, 92.3655014],
+                                              [70.72990036, 92.20410156]]) / 112
+
+
+ @torch.no_grad()
+ def detect_face(images: torch.Tensor, mtcnn: torch.nn.Module) -> torch.Tensor:
+     """
+     Detect facial landmarks in the images using MTCNN. Entries are None for
+     images in which no face is detected.
+     """
+     images = rearrange(images, "b c h w -> b h w c")
+     if images.dtype != torch.uint8:
+         images = ((images * 0.5 + 0.5) * 255).type(torch.uint8)  # Unnormalize from [-1, 1]
+
+     _, _, landmarks = mtcnn(images, landmarks=True)
+
+     return landmarks
+
+
+ def extract_faces_and_landmarks(images: torch.Tensor, output_size=112, mtcnn: torch.nn.Module = None, reference_points=REFERENCE_FACIAL_POINTS_RELATIVE):
+     """
+     Detect faces in the images and crop them (in a differentiable way) to
+     output_size x output_size using MTCNN. Also return the indices of images
+     in which no face was detected.
+     """
+     images = Pad(200)(images)  # Pad so faces near the border survive the affine warp
+     landmarks_batched = detect_face(images, mtcnn=mtcnn)
+     affine_transformations = []
+     invalid_indices = []
+     for i, landmarks in enumerate(landmarks_batched):
+         if landmarks is None:
+             # No face found: fall back to the identity transform and record the index.
+             invalid_indices.append(i)
+             affine_transformations.append(np.eye(2, 3).astype(np.float32))
+         else:
+             # Similarity transform mapping the first detected face onto the template.
+             affine_transformations.append(get_similarity_transform_for_cv2(landmarks[0].astype(np.float32),
+                                                                            reference_points.astype(np.float32) * output_size))
+     affine_transformations = torch.from_numpy(np.stack(affine_transformations).astype(np.float32)).to(device=images.device, dtype=torch.float32)
+
+     invalid_indices = torch.tensor(invalid_indices).to(device=images.device)
+
+     # warp_affine is differentiable, so gradients can flow back through the crop.
+     fp_images = images.to(torch.float32)
+     return warp_affine(fp_images, affine_transformations, dsize=(output_size, output_size)).to(dtype=images.dtype), invalid_indices
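For context, a minimal usage sketch follows; it is not part of the commit. It assumes facenet_pytorch is installed and that the mtcnn module passed in returns (boxes, probs, landmarks) when called with landmarks=True, which is what detect_face above expects. facenet_pytorch's MTCNN exposes this through its detect() method, so the hypothetical MTCNNWrapper below simply forwards to it.

import torch
from facenet_pytorch import MTCNN

from utils.utils import extract_faces_and_landmarks


class MTCNNWrapper(torch.nn.Module):
    """Hypothetical adapter, named for this sketch only: makes MTCNN callable
    with landmarks=True as detect_face expects."""

    def __init__(self, device: str = "cpu"):
        super().__init__()
        self.mtcnn = MTCNN(device=device)

    def forward(self, images, landmarks: bool = False):
        # MTCNN.detect accepts a (B, H, W, 3) uint8 batch and returns
        # (boxes, probs, points) when landmarks=True.
        return self.mtcnn.detect(images, landmarks=landmarks)


mtcnn = MTCNNWrapper()
images = torch.rand(2, 3, 512, 512) * 2 - 1  # batch normalized to [-1, 1]
faces, invalid_indices = extract_faces_and_landmarks(images, mtcnn=mtcnn)
print(faces.shape)       # torch.Size([2, 3, 112, 112])
print(invalid_indices)   # indices of images where no face was found

Images where no face is detected are still warped (with the identity transform), so the caller is expected to filter them out using the returned invalid_indices.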