ElenaRyumina committed · Commit 9e58e71 · Parent(s): ef7a94a

run_webcam.ipynb ADDED (+282 -0)
```python
import cv2
import mediapipe as mp  # face detector (MediaPipe Face Mesh)
import math
import numpy as np
import time

import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
from PIL import Image
from torchvision import transforms
```
#### Sub functions
```python
def pth_processing(fp):
    class PreprocessInput(torch.nn.Module):
        def __init__(self):
            super(PreprocessInput, self).__init__()

        def forward(self, x):
            x = x.to(torch.float32)
            x = torch.flip(x, dims=(0,))  # RGB -> BGR channel order
            x[0, :, :] -= 91.4953
            x[1, :, :] -= 103.8827
            x[2, :, :] -= 131.0912
            return x

    def get_img_torch(img):
        ttransform = transforms.Compose([
            transforms.PILToTensor(),
            PreprocessInput()
        ])
        img = img.resize((224, 224), Image.Resampling.NEAREST)
        img = ttransform(img)
        img = torch.unsqueeze(img, 0)
        return img

    return get_img_torch(fp)

def tf_processing(fp):
    # NOTE: this TensorFlow variant is not used by the webcam loop below and
    # additionally requires `import tensorflow as tf`.
    def preprocess_input(x):
        x_temp = np.copy(x)
        x_temp = x_temp[..., ::-1]
        x_temp[..., 0] -= 91.4953
        x_temp[..., 1] -= 103.8827
        x_temp[..., 2] -= 131.0912
        return x_temp

    def get_img_tf(img):
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_NEAREST)
        img = tf.keras.utils.img_to_array(img)
        img = preprocess_input(img)
        img = np.array([img])
        return img

    return get_img_tf(fp)

def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px

def get_box(fl, w, h):
    idx_to_coors = {}
    for idx, landmark in enumerate(fl.landmark):
        landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)

        if landmark_px:
            idx_to_coors[idx] = landmark_px

    x_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 0])
    y_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 1])
    endX = np.max(np.asarray(list(idx_to_coors.values()))[:, 0])
    endY = np.max(np.asarray(list(idx_to_coors.values()))[:, 1])

    (startX, startY) = (max(0, x_min), max(0, y_min))
    (endX, endY) = (min(w - 1, endX), min(h - 1, endY))

    return startX, startY, endX, endY

def display_EMO_PRED(img, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), line_width=2):
    lw = line_width or max(round(sum(img.shape) / 2 * 0.003), 2)
    text2_color = (255, 0, 255)
    p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
    cv2.rectangle(img, p1, p2, text2_color, thickness=lw, lineType=cv2.LINE_AA)
    font = cv2.FONT_HERSHEY_SIMPLEX

    tf = max(lw - 1, 1)
    text_fond = (0, 0, 0)
    text_width_2, text_height_2 = cv2.getTextSize(label, font, lw / 3, tf)
    text_width_2 = text_width_2[0] + round(((p2[0] - p1[0]) * 10) / 360)
    center_face = p1[0] + round((p2[0] - p1[0]) / 2)

    cv2.putText(img, label,
                (center_face - round(text_width_2 / 2), p1[1] - round(((p2[0] - p1[0]) * 20) / 360)), font,
                lw / 3, text_fond, thickness=tf, lineType=cv2.LINE_AA)
    cv2.putText(img, label,
                (center_face - round(text_width_2 / 2), p1[1] - round(((p2[0] - p1[0]) * 20) / 360)), font,
                lw / 3, text2_color, thickness=tf, lineType=cv2.LINE_AA)
    return img

def display_FPS(img, text, margin=1.0, box_scale=1.0):
    img_h, img_w, _ = img.shape
    line_width = int(min(img_h, img_w) * 0.001)  # line width
    thickness = max(int(line_width / 3), 1)  # font thickness

    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_color = (0, 0, 0)
    font_scale = thickness / 1.5

    t_w, t_h = cv2.getTextSize(text, font_face, font_scale, None)[0]

    margin_n = int(t_h * margin)
    sub_img = img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
                  img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n]

    white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255

    img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),
        img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n] = cv2.addWeighted(sub_img, 0.5, white_rect, .5, 1.0)

    cv2.putText(img=img,
                text=text,
                org=(img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,
                     0 + margin_n + t_h + int(2 * t_h * box_scale) // 2),
                fontFace=font_face,
                fontScale=font_scale,
                color=font_color,
                thickness=thickness,
                lineType=cv2.LINE_AA,
                bottomLeftOrigin=False)

    return img
```
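The torch preprocessing path can be sanity-checked in isolation. A minimal sketch, not part of the original notebook, that feeds `pth_processing` a blank PIL image and checks the tensor shape the backbone expects (the 224×224 size, channel flip and per-channel mean subtraction come from the cell above):

```python
# Hypothetical smoke test (not in the original notebook): pth_processing should
# return a batched 3x224x224 float tensor, channel-flipped and mean-subtracted.
dummy = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
x = pth_processing(dummy)
print(x.shape)  # expected: torch.Size([1, 3, 224, 224])
print(x.dtype)  # expected: torch.float32
```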
#### Testing models by webcam
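The cell below loads the two TorchScript checkpoints by bare filename, so they are expected to sit next to the notebook. A small hedged pre-flight check, not part of the original notebook, that reports whether the files for the selected configuration are present before starting the webcam loop:

```python
# Hypothetical pre-flight check (not in the original notebook): confirm the
# checkpoints referenced below exist in the current working directory.
import os

for fname in ['FER_static_ResNet50_AffectNet.pth', 'FER_dinamic_LSTM_Aff-Wild2.pth']:
    status = 'found' if os.path.isfile(fname) else 'MISSING - place it next to this notebook'
    print(f'{fname}: {status}')
```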
```python
mp_face_mesh = mp.solutions.face_mesh

name_backbone_model = 'FER_static_ResNet50_AffectNet.pth'
# name_LSTM_model = 'IEMOCAP'
# name_LSTM_model = 'CREMA-D'
# name_LSTM_model = 'RAMAS'
# name_LSTM_model = 'RAVDESS'
# name_LSTM_model = 'SAVEE'
name_LSTM_model = 'Aff-Wild2'

# torch
pth_backbone_model = torch.jit.load(name_backbone_model)
pth_backbone_model.eval()

pth_LSTM_model = torch.jit.load('FER_dinamic_LSTM_{0}.pth'.format(name_LSTM_model))
pth_LSTM_model.eval()

DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}

cap = cv2.VideoCapture(0)
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = np.round(cap.get(cv2.CAP_PROP_FPS))

path_save_video = 'result.mp4'
vid_writer = cv2.VideoWriter(path_save_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

lstm_features = []

with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5) as face_mesh:

    while cap.isOpened():
        t1 = time.time()
        success, frame = cap.read()
        if frame is None: break

        frame_copy = frame.copy()
        frame_copy.flags.writeable = False
        frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
        results = face_mesh.process(frame_copy)
        frame_copy.flags.writeable = True

        if results.multi_face_landmarks:
            for fl in results.multi_face_landmarks:
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = frame_copy[startY:endY, startX:endX]

                cur_face = pth_processing(Image.fromarray(cur_face))
                features = torch.nn.functional.relu(pth_backbone_model.extract_features(cur_face)).detach().numpy()

                # sliding window of the last 10 per-frame feature vectors for the LSTM
                if len(lstm_features) == 0:
                    lstm_features = [features]*10
                else:
                    lstm_features = lstm_features[1:] + [features]

                lstm_f = torch.from_numpy(np.vstack(lstm_features))
                lstm_f = torch.unsqueeze(lstm_f, 0)
                output = pth_LSTM_model(lstm_f).detach().numpy()

                cl = np.argmax(output)
                label = DICT_EMO[cl]
                frame = display_EMO_PRED(frame, (startX, startY, endX, endY), label+' {0:.1%}'.format(output[0][cl]), line_width=3)

        t2 = time.time()

        frame = display_FPS(frame, 'FPS: {0:.1f}'.format(1 / (t2 - t1)), box_scale=.5)

        vid_writer.write(frame)

        cv2.imshow('Webcam', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):  # press 'q' to stop
            break

    vid_writer.release()
    cap.release()
    cv2.destroyAllWindows()
```
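The same models can also be exercised without a camera. A sketch, not part of the original notebook, that runs the static backbone plus the LSTM head on a single image file by reusing the helpers above; it assumes the previous cell has already loaded `pth_backbone_model`, `pth_LSTM_model` and `DICT_EMO`, and `'face.jpg'` is a placeholder path:

```python
# Hypothetical single-image variant (not in the original notebook). With no
# temporal context, the static features are repeated 10 times, exactly as the
# webcam loop does when its feature buffer is empty.
img = cv2.cvtColor(cv2.imread('face.jpg'), cv2.COLOR_BGR2RGB)  # placeholder path
h_img, w_img, _ = img.shape

with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1,
                           min_detection_confidence=0.5) as fm:
    res = fm.process(img)

if res.multi_face_landmarks:
    fl = res.multi_face_landmarks[0]
    startX, startY, endX, endY = get_box(fl, w_img, h_img)
    face = pth_processing(Image.fromarray(img[startY:endY, startX:endX]))
    feat = torch.nn.functional.relu(pth_backbone_model.extract_features(face)).detach().numpy()
    probs = pth_LSTM_model(torch.from_numpy(np.vstack([feat] * 10)).unsqueeze(0)).detach().numpy()
    print(DICT_EMO[int(np.argmax(probs))], '{0:.1%}'.format(float(np.max(probs))))
```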
Notebook metadata: kernel "Python 3 (ipykernel)", language Python 3.9.13, nbformat 4.5.