ElenaRyumina committed
Commit
9e58e71
1 Parent(s): ef7a94a
Files changed (1)
  1. run_webcam.ipynb +282 -0
run_webcam.ipynb ADDED
@@ -0,0 +1,282 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "3e1b0206-2912-4385-97b9-5948ed70dfc8",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import cv2\n",
+     "import mediapipe as mp  # face detector\n",
+     "import math\n",
+     "import numpy as np\n",
+     "import time\n",
+     "\n",
+     "import warnings\n",
+     "warnings.simplefilter(\"ignore\", UserWarning)\n",
+     "\n",
+     "import torch\n",
+     "from PIL import Image\n",
+     "from torchvision import transforms"
+    ]
+   },
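+   {
+    "cell_type": "markdown",
+    "id": "added-preflight-note",
+    "metadata": {},
+    "source": [
+     "*Added sketch (not in the original commit):* a quick pre-flight check that the TorchScript weight files loaded further down are present next to the notebook. The file names are taken from the webcam cell below, assuming the default `Aff-Wild2` LSTM variant."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "added-preflight-check",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Optional pre-flight check (added sketch): confirm the model files exist\n",
+     "# before starting the webcam loop. File names mirror the cell below.\n",
+     "import os\n",
+     "\n",
+     "for f in ['FER_static_ResNet50_AffectNet.pth', 'FER_dinamic_LSTM_Aff-Wild2.pth']:\n",
+     "    print(f, 'found' if os.path.exists(f) else 'MISSING')"
+    ]
+   },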
+   {
+    "cell_type": "markdown",
+    "id": "fcbcf9fa-a7cc-4d4c-b723-6d7efd49b94b",
+    "metadata": {},
+    "source": [
+     "#### Helper functions"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "6d0fc324-98a8-4efc-bb11-4bec8a015790",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def pth_processing(fp):\n",
+     "    # Resize a PIL image to 224x224 and convert it to the backbone's input tensor.\n",
+     "    class PreprocessInput(torch.nn.Module):\n",
+     "        def __init__(self):\n",
+     "            super(PreprocessInput, self).__init__()\n",
+     "\n",
+     "        def forward(self, x):\n",
+     "            x = x.to(torch.float32)\n",
+     "            x = torch.flip(x, dims=(0,))  # RGB -> BGR\n",
+     "            x[0, :, :] -= 91.4953  # subtract per-channel means\n",
+     "            x[1, :, :] -= 103.8827\n",
+     "            x[2, :, :] -= 131.0912\n",
+     "            return x\n",
+     "\n",
+     "    def get_img_torch(img):\n",
+     "        ttransform = transforms.Compose([\n",
+     "            transforms.PILToTensor(),\n",
+     "            PreprocessInput()\n",
+     "        ])\n",
+     "        img = img.resize((224, 224), Image.Resampling.NEAREST)\n",
+     "        img = ttransform(img)\n",
+     "        img = torch.unsqueeze(img, 0)\n",
+     "        return img\n",
+     "    return get_img_torch(fp)\n",
+     "\n",
+     "def tf_processing(fp):\n",
+     "    # TensorFlow counterpart of pth_processing; requires `import tensorflow as tf`\n",
+     "    # and is not called elsewhere in this notebook.\n",
+     "    def preprocess_input(x):\n",
+     "        x_temp = np.copy(x)\n",
+     "        x_temp = x_temp[..., ::-1]  # RGB -> BGR\n",
+     "        x_temp[..., 0] -= 91.4953\n",
+     "        x_temp[..., 1] -= 103.8827\n",
+     "        x_temp[..., 2] -= 131.0912\n",
+     "        return x_temp\n",
+     "\n",
+     "    def get_img_tf(img):\n",
+     "        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_NEAREST)\n",
+     "        img = tf.keras.utils.img_to_array(img)\n",
+     "        img = preprocess_input(img)\n",
+     "        img = np.array([img])\n",
+     "        return img\n",
+     "\n",
+     "    return get_img_tf(fp)\n",
+     "\n",
+     "def norm_coordinates(normalized_x, normalized_y, image_width, image_height):\n",
+     "    # Convert normalized MediaPipe landmark coordinates to pixel coordinates.\n",
+     "    x_px = min(math.floor(normalized_x * image_width), image_width - 1)\n",
+     "    y_px = min(math.floor(normalized_y * image_height), image_height - 1)\n",
+     "    return x_px, y_px\n",
+     "\n",
+     "def get_box(fl, w, h):\n",
+     "    # Face bounding box: min/max over all landmark pixel coordinates, clipped to the frame.\n",
+     "    idx_to_coors = {}\n",
+     "    for idx, landmark in enumerate(fl.landmark):\n",
+     "        landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)\n",
+     "        if landmark_px:\n",
+     "            idx_to_coors[idx] = landmark_px\n",
+     "\n",
+     "    coords = np.asarray(list(idx_to_coors.values()))\n",
+     "    x_min, y_min = np.min(coords[:, 0]), np.min(coords[:, 1])\n",
+     "    endX, endY = np.max(coords[:, 0]), np.max(coords[:, 1])\n",
+     "\n",
+     "    (startX, startY) = (max(0, x_min), max(0, y_min))\n",
+     "    (endX, endY) = (min(w - 1, endX), min(h - 1, endY))\n",
+     "    return startX, startY, endX, endY\n",
+     "\n",
+     "def display_EMO_PRED(img, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), line_width=2):\n",
+     "    # Draw the face box and the predicted emotion label above it.\n",
+     "    lw = line_width or max(round(sum(img.shape) / 2 * 0.003), 2)\n",
+     "    text2_color = (255, 0, 255)\n",
+     "    p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))\n",
+     "    cv2.rectangle(img, p1, p2, text2_color, thickness=lw, lineType=cv2.LINE_AA)\n",
+     "    font = cv2.FONT_HERSHEY_SIMPLEX\n",
+     "\n",
+     "    font_thickness = max(lw - 1, 1)\n",
+     "    shadow_color = (0, 0, 0)\n",
+     "    (text_w, _), _ = cv2.getTextSize(label, font, lw / 3, font_thickness)\n",
+     "    text_w += round(((p2[0] - p1[0]) * 10) / 360)\n",
+     "    center_face = p1[0] + round((p2[0] - p1[0]) / 2)\n",
+     "\n",
+     "    org = (center_face - round(text_w / 2), p1[1] - round(((p2[0] - p1[0]) * 20) / 360))\n",
+     "    # Black shadow first, then the magenta label on top.\n",
+     "    cv2.putText(img, label, org, font, lw / 3, shadow_color, thickness=font_thickness, lineType=cv2.LINE_AA)\n",
+     "    cv2.putText(img, label, org, font, lw / 3, text2_color, thickness=font_thickness, lineType=cv2.LINE_AA)\n",
+     "    return img\n",
+     "\n",
+     "def display_FPS(img, text, margin=1.0, box_scale=1.0):\n",
+     "    # Draw the FPS counter on a translucent white box in the top-right corner.\n",
+     "    img_h, img_w, _ = img.shape\n",
+     "    line_width = int(min(img_h, img_w) * 0.001)  # line width\n",
+     "    thickness = max(int(line_width / 3), 1)  # font thickness\n",
+     "\n",
+     "    font_face = cv2.FONT_HERSHEY_SIMPLEX\n",
+     "    font_color = (0, 0, 0)\n",
+     "    font_scale = thickness / 1.5\n",
+     "\n",
+     "    t_w, t_h = cv2.getTextSize(text, font_face, font_scale, thickness)[0]\n",
+     "\n",
+     "    margin_n = int(t_h * margin)\n",
+     "    sub_img = img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),\n",
+     "                  img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n]\n",
+     "\n",
+     "    white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255\n",
+     "\n",
+     "    img[0 + margin_n: 0 + margin_n + t_h + int(2 * t_h * box_scale),\n",
+     "        img_w - t_w - margin_n - int(2 * t_h * box_scale): img_w - margin_n] = \\\n",
+     "        cv2.addWeighted(sub_img, 0.5, white_rect, 0.5, 1.0)\n",
+     "\n",
+     "    cv2.putText(img=img,\n",
+     "                text=text,\n",
+     "                org=(img_w - t_w - margin_n - int(2 * t_h * box_scale) // 2,\n",
+     "                     0 + margin_n + t_h + int(2 * t_h * box_scale) // 2),\n",
+     "                fontFace=font_face,\n",
+     "                fontScale=font_scale,\n",
+     "                color=font_color,\n",
+     "                thickness=thickness,\n",
+     "                lineType=cv2.LINE_AA,\n",
+     "                bottomLeftOrigin=False)\n",
+     "\n",
+     "    return img"
+    ]
+   },
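+   {
+    "cell_type": "markdown",
+    "id": "added-sanity-note",
+    "metadata": {},
+    "source": [
+     "*Added sketch (not in the original commit):* a minimal sanity check of `pth_processing` on a synthetic image, to confirm it yields the `(1, 3, 224, 224)` float32 tensor the ResNet50 backbone expects."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "added-sanity-check",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Sanity check (added sketch): run pth_processing on a synthetic RGB image.\n",
+     "dummy = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))\n",
+     "batch = pth_processing(dummy)\n",
+     "print(batch.shape, batch.dtype)  # expected: torch.Size([1, 3, 224, 224]) torch.float32"
+    ]
+   },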
+   {
+    "cell_type": "markdown",
+    "id": "bae915fd-cc3d-4dc1-83fc-c9c32e1b12a8",
+    "metadata": {},
+    "source": [
+     "#### Testing the models with a webcam"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "id": "c05ed967-a30e-47f5-96ed-b32bab0c6879",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "mp_face_mesh = mp.solutions.face_mesh\n",
+     "\n",
+     "name_backbone_model = 'FER_static_ResNet50_AffectNet.pth'\n",
+     "# name_LSTM_model = 'IEMOCAP'\n",
+     "# name_LSTM_model = 'CREMA-D'\n",
+     "# name_LSTM_model = 'RAMAS'\n",
+     "# name_LSTM_model = 'RAVDESS'\n",
+     "# name_LSTM_model = 'SAVEE'\n",
+     "name_LSTM_model = 'Aff-Wild2'\n",
+     "\n",
+     "# torch\n",
+     "pth_backbone_model = torch.jit.load(name_backbone_model)\n",
+     "pth_backbone_model.eval()\n",
+     "\n",
+     "pth_LSTM_model = torch.jit.load('FER_dinamic_LSTM_{0}.pth'.format(name_LSTM_model))\n",
+     "pth_LSTM_model.eval()\n",
+     "\n",
+     "DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}\n",
+     "\n",
+     "cap = cv2.VideoCapture(0)\n",
+     "w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
+     "h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
+     "fps = np.round(cap.get(cv2.CAP_PROP_FPS))\n",
+     "\n",
+     "path_save_video = 'result.mp4'\n",
+     "vid_writer = cv2.VideoWriter(path_save_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))\n",
+     "\n",
+     "lstm_features = []\n",
+     "\n",
+     "with mp_face_mesh.FaceMesh(\n",
+     "        max_num_faces=1,\n",
+     "        refine_landmarks=False,\n",
+     "        min_detection_confidence=0.5,\n",
+     "        min_tracking_confidence=0.5) as face_mesh:\n",
+     "\n",
+     "    while cap.isOpened():\n",
+     "        t1 = time.time()\n",
+     "        success, frame = cap.read()\n",
+     "        if frame is None:\n",
+     "            break\n",
+     "\n",
+     "        frame_copy = frame.copy()\n",
+     "        frame_copy.flags.writeable = False\n",
+     "        frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)\n",
+     "        results = face_mesh.process(frame_copy)\n",
+     "        frame_copy.flags.writeable = True\n",
+     "\n",
+     "        if results.multi_face_landmarks:\n",
+     "            for fl in results.multi_face_landmarks:\n",
+     "                startX, startY, endX, endY = get_box(fl, w, h)\n",
+     "                cur_face = frame_copy[startY:endY, startX:endX]\n",
+     "\n",
+     "                cur_face = pth_processing(Image.fromarray(cur_face))\n",
+     "                features = torch.nn.functional.relu(pth_backbone_model.extract_features(cur_face)).detach().numpy()\n",
+     "\n",
+     "                # Sliding window of the 10 most recent frame features for the LSTM.\n",
+     "                if len(lstm_features) == 0:\n",
+     "                    lstm_features = [features] * 10\n",
+     "                else:\n",
+     "                    lstm_features = lstm_features[1:] + [features]\n",
+     "\n",
+     "                lstm_f = torch.from_numpy(np.vstack(lstm_features))\n",
+     "                lstm_f = torch.unsqueeze(lstm_f, 0)\n",
+     "                output = pth_LSTM_model(lstm_f).detach().numpy()\n",
+     "\n",
+     "                cl = np.argmax(output)\n",
+     "                label = DICT_EMO[cl]\n",
+     "                frame = display_EMO_PRED(frame, (startX, startY, endX, endY), label + ' {0:.1%}'.format(output[0][cl]), line_width=3)\n",
+     "\n",
+     "        t2 = time.time()\n",
+     "\n",
+     "        frame = display_FPS(frame, 'FPS: {0:.1f}'.format(1 / (t2 - t1)), box_scale=.5)\n",
+     "\n",
+     "        vid_writer.write(frame)\n",
+     "\n",
+     "        cv2.imshow('Webcam', frame)\n",
+     "        if cv2.waitKey(1) & 0xFF == ord('q'):\n",
+     "            break\n",
+     "\n",
+     "    vid_writer.release()\n",
+     "    cap.release()\n",
+     "    cv2.destroyAllWindows()"
+    ]
+   },
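+   {
+    "cell_type": "markdown",
+    "id": "added-video-file-note",
+    "metadata": {},
+    "source": [
+     "*Added sketch (not in the original commit):* the same loop also works on a recorded video by passing a file path to `cv2.VideoCapture` instead of device index `0`; `'video.mp4'` below is a placeholder path, not a file shipped with this repo."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "added-video-file-variant",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Added sketch: open a video file instead of the webcam. Replace the\n",
+     "# placeholder path, then reuse the loop above with this `cap`.\n",
+     "cap = cv2.VideoCapture('video.mp4')  # placeholder path\n",
+     "print('opened:', cap.isOpened(), 'fps:', cap.get(cv2.CAP_PROP_FPS))\n",
+     "cap.release()"
+    ]
+   }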
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.9.13"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }