Ziv Pollak committed on
Commit 2acfef6
Parent(s): 924a30f

adding model

Files changed (3)
  1. .gitattributes +1 -0
  2. app.py +116 -3
  3. requirements.txt +4 -1
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+face_landmarker_v2_with_blendshapes.task filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -9,7 +9,9 @@ from mediapipe.tasks import python
 from mediapipe.tasks.python import vision
 from mediapipe.framework.formats import landmark_pb2
 from mediapipe import solutions
+from PIL import Image
 
+import torch, torchvision
 
 import matplotlib
 matplotlib.use("Agg")
@@ -17,6 +19,15 @@ import matplotlib.pyplot as plt
 
 cropped_image = []
 analyzed_image = []
+
+# colors for visualization
+COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
+          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
+
+finetuned_classes = [
+    'iris',
+]
+
 # take a photo
 # run face landmark on it to crop image
 # run our model on it
@@ -30,6 +41,8 @@ options = vision.FaceLandmarkerOptions(base_options=base_options,
                                        num_faces=1)
 detector = vision.FaceLandmarker.create_from_options(options)
 
+model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
+model.eval();
 
 def video_identity(video):
     return video
@@ -40,6 +53,15 @@ def video_identity(video):
 # "playable_video")
 
 
+import torchvision.transforms as T
+
+# standard PyTorch mean-std input image normalization
+transform = T.Compose([
+    T.Resize(800),
+    T.ToTensor(),
+    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
 def handle_image(input_image):
     global cropped_image, analyzed_image
     cv2.imwrite("image.jpg", input_image)
@@ -63,15 +85,104 @@ def handle_image(input_image):
     cv2.circle(input_image, (p1[0], p1[1]), 10, (0, 0, 255), -1)
     p2 = [int(face_landmarks_proto.landmark[346].x * width), int(face_landmarks_proto.landmark[346].y * height)]
     cv2.circle(input_image, (p2[0], p2[1]), 10, (0, 0, 255), -1)
-    print(p1[0], p1[1], p2[0], p2[1], height, width)
     cropped_image = cropped_image[p1[1]:p2[1], p1[0]:p2[0]]
-    # [row starting from the top]
-    #return ([input_image, cropped_image])
+
+    run_worflow(cropped_image, model)
+
     return (cropped_image)
 
+def load_model():
+    print('load model')
+    '''
+    model = torch.hub.load('facebookresearch/detr',
+                           'detr_resnet50',
+                           pretrained=False,
+                           num_classes=1)
+
+    checkpoint = torch.load('outputs/checkpoint.pth',
+                            map_location='cpu')
+
+    model.load_state_dict(checkpoint['model'],
+                          strict=False)
+
+    model.eval();
+    '''
+def filter_bboxes_from_outputs(img,
+                               outputs,
+                               threshold=0.7
+                               ):
+
+    # keep only predictions with confidence above threshold
+    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
+    keep = probas.max(-1).values > threshold
+
+    probas_to_keep = probas[keep]
+
+    # convert boxes from [0; 1] to image scales
+    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], img.size)
+
+    return probas_to_keep, bboxes_scaled
+
+
+def plot_finetuned_results(pil_img, prob=None, boxes=None):
+    plt.figure(figsize=(16, 10))
+    plt.imshow(pil_img)
+    ax = plt.gca()
+    colors = COLORS * 100
+    if prob is not None and boxes is not None:
+        for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
+            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                                       fill=False, color=c, linewidth=3))
+            cl = p.argmax()
+            #text = f'{finetuned_classes[cl]}: {p[cl]:0.2f}'
+            text = 'results'
+            ax.text(xmin, ymin, text, fontsize=15,
+                    bbox=dict(facecolor='yellow', alpha=0.5))
+    plt.axis('off')
+    plt.show()
+
+def rescale_bboxes(out_bbox, size):
+    print(size)
+    img_w, img_h = size
+    b = box_cxcywh_to_xyxy(out_bbox)
+    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+    return b
+
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=1)
+
+
+def run_worflow(my_image, my_model):
+
+    # write image to disk and read it back as PIL
+    cv2.imwrite("img1.jpg", my_image)
+    my_image = Image.open("img1.jpg")
+
+    # mean-std normalize the input image (batch-size: 1)
+    img = transform(my_image).unsqueeze(0)
+
+    # propagate through the model
+    outputs = my_model(img)
+
+    for threshold in [0.2, 0.2]:
+
+        probas_to_keep, bboxes_scaled = filter_bboxes_from_outputs(my_image,
+                                                                   outputs,
+                                                                   threshold=threshold)
+
+        plot_finetuned_results(my_image,
+                               probas_to_keep,
+                               bboxes_scaled)
+
+
+
 
 
 with gr.Blocks() as demo:
+
     gr.Markdown(
         """
         # Iris detection
@@ -94,6 +205,8 @@ with gr.Blocks() as demo:
     out = [cropped_image]
     b.click(fn=handle_image, inputs=image1, outputs=out)
 
+
+
 demo.launch()
 
 
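For readers unfamiliar with DETR's output format: the model returns per-query class logits plus boxes encoded as normalized (center_x, center_y, width, height), which is why this commit adds box_cxcywh_to_xyxy and rescale_bboxes before plotting. Below is a minimal self-contained sketch of that conversion, mirroring the helpers added above; the dummy box and image size are made up for illustration and are not taken from the app.

    # standalone sketch (not part of the commit) of the box conversion used in app.py
    import torch

    def box_cxcywh_to_xyxy(x):
        # split normalized (cx, cy, w, h) columns and build the two corners
        x_c, y_c, w, h = x.unbind(1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
             (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1)

    def rescale_bboxes(out_bbox, size):
        # scale the normalized corners up to the pixel size of the (PIL) image
        img_w, img_h = size
        scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return box_cxcywh_to_xyxy(out_bbox) * scale

    # dummy prediction: one box centered in a 640x480 crop, 20% wide and 10% tall
    boxes = torch.tensor([[0.5, 0.5, 0.2, 0.1]])
    print(rescale_bboxes(boxes, (640, 480)))  # tensor([[256., 216., 384., 264.]])
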
requirements.txt CHANGED
@@ -3,4 +3,7 @@ numpy
 pandas
 Pillow
 opencv-python
-mediapipe
+mediapipe
+torch
+torchvision
+scipy
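
The commit also broadens the runtime dependencies. A quick, hypothetical sanity check (not part of the repo) to confirm the newly added packages resolve in the Space's environment:

    # hypothetical check: confirm the packages added to requirements.txt are
    # importable in the runtime and report their installed versions
    import importlib.metadata as md

    for pkg in ("mediapipe", "torch", "torchvision", "scipy"):
        try:
            print(f"{pkg}=={md.version(pkg)}")
        except md.PackageNotFoundError:
            print(f"{pkg} is NOT installed")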