{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5a1d3e2c-8f4b-4e6a-9c7d-2b0e1f3a4c5d",
   "metadata": {},
   "source": [
    "# Dynamic facial emotion recognition from a webcam\n",
    "\n",
    "Reads frames from a webcam and locates one face with MediaPipe Face Mesh. A TorchScript ResNet-50 backbone (`FER_static_ResNet50_AffectNet.pth`) extracts per-frame features. A TorchScript LSTM (`FER_dinamic_LSTM_{dataset}.pth`) then classifies the emotion from a sliding window of those features. The annotated stream is shown in a window (press `q` to stop) and written to `result.mp4`.\n",
    "\n",
    "**Requires:** both `.pth` model files in the working directory and a connected camera."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e1b0206-2912-4385-97b9-5948ed70dfc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import cv2\n",
    "import mediapipe as mp  # face landmark detection (Face Mesh)\n",
    "import numpy as np\n",
    "\n",
    "# Hide UserWarnings. This runs before the imports below, so warnings raised while importing them are hidden too.\n",
    "warnings.simplefilter(\"ignore\", UserWarning)\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcbcf9fa-a7cc-4d4c-b723-6d7efd49b94b",
   "metadata": {},
   "source": [
    "#### Helper functions\n",
    "\n",
    "Model input preprocessing, face bounding box from landmarks, and drawing utilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d0fc324-98a8-4efc-bb11-4bec8a015790",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-channel means subtracted after the RGB -> BGR flip (channel 0 = B, 1 = G, 2 = R).\n",
    "BGR_MEANS = (91.4953, 103.8827, 131.0912)\n",
    "# Spatial size (width, height) of the backbone input.\n",
    "MODEL_INPUT_SIZE = (224, 224)\n",
    "\n",
    "\n",
    "class PreprocessInput(torch.nn.Module):\n",
    "    \"\"\"Convert a (3, H, W) RGB tensor to float32 BGR and subtract BGR_MEANS per channel.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.flip(x.to(torch.float32), dims=(0,))  # RGB -> BGR along the channel axis\n",
    "        for channel, mean in enumerate(BGR_MEANS):\n",
    "            x[channel, :, :] -= mean\n",
    "        return x\n",
    "\n",
    "\n",
    "# Built once and reused for every frame instead of being recreated per call.\n",
    "_PTH_TRANSFORM = transforms.Compose([\n",
    "    transforms.PILToTensor(),\n",
    "    PreprocessInput(),\n",
    "])\n",
    "\n",
    "\n",
    "def pth_processing(fp):\n",
    "    \"\"\"Prepare a face crop for the PyTorch backbone.\n",
    "\n",
    "    Args:\n",
    "        fp: PIL image of the face crop (assumed RGB, as passed by the webcam loop).\n",
    "\n",
    "    Returns:\n",
    "        float32 tensor of shape (1, 3, 224, 224) in BGR order with BGR_MEANS removed.\n",
    "    \"\"\"\n",
    "    img = fp.resize(MODEL_INPUT_SIZE, Image.Resampling.NEAREST)\n",
    "    img = _PTH_TRANSFORM(img)\n",
    "    return torch.unsqueeze(img, 0)\n",
    "\n",
    "\n",
    "def tf_processing(fp):\n",
    "    \"\"\"Prepare a face crop for a Keras-style model (not used by the webcam loop).\n",
    "\n",
    "    TensorFlow is not imported in this notebook, so numpy does the float32\n",
    "    conversion that `tf.keras.utils.img_to_array` would otherwise perform.\n",
    "\n",
    "    Args:\n",
    "        fp: (H, W, 3) numpy image, assumed RGB.\n",
    "\n",
    "    Returns:\n",
    "        float32 array of shape (1, 224, 224, 3) in BGR order with BGR_MEANS removed.\n",
    "    \"\"\"\n",
    "    img = cv2.resize(fp, MODEL_INPUT_SIZE, interpolation=cv2.INTER_NEAREST)\n",
    "    img = np.asarray(img, dtype=np.float32)[..., ::-1]  # RGB -> BGR\n",
    "    img = img - np.asarray(BGR_MEANS, dtype=np.float32)\n",
    "    return img[np.newaxis, ...]\n",
    "\n",
    "\n",
    "def norm_coordinates(normalized_x, normalized_y, image_width, image_height):\n",
    "    \"\"\"Map normalized landmark coordinates to integer pixel coordinates.\n",
    "\n",
    "    Values are floored and capped at the last pixel index. Negative values are\n",
    "    not clamped here; get_box clamps them.\n",
    "    \"\"\"\n",
    "    x_px = min(math.floor(normalized_x * image_width), image_width - 1)\n",
    "    y_px = min(math.floor(normalized_y * image_height), image_height - 1)\n",
    "    return x_px, y_px\n",
    "\n",
    "\n",
    "def get_box(fl, w, h):\n",
    "    \"\"\"Return the pixel bounding box (startX, startY, endX, endY) of all landmarks in `fl`.\n",
    "\n",
    "    Args:\n",
    "        fl: MediaPipe face landmarks; each item of `fl.landmark` has normalized `x`/`y`.\n",
    "        w, h: image width and height in pixels.\n",
    "\n",
    "    The box is clamped to the image. It can still be empty (start >= end) when the\n",
    "    landmarks lie at or beyond the frame edge, so callers should check the crop.\n",
    "    \"\"\"\n",
    "    coords = np.array([norm_coordinates(lm.x, lm.y, w, h) for lm in fl.landmark])\n",
    "    start_x, start_y = coords.min(axis=0)\n",
    "    end_x, end_y = coords.max(axis=0)\n",
    "    return max(0, start_x), max(0, start_y), min(w - 1, end_x), min(h - 1, end_y)\n",
    "\n",
    "\n",
    "def display_EMO_PRED(img, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), line_width=2, ):\n",
    "    \"\"\"Draw a face box and a label centered above it on `img` (in place).\n",
    "\n",
    "    Args:\n",
    "        img: image (numpy array) to draw on; modified in place.\n",
    "        box: (startX, startY, endX, endY) in pixels.\n",
    "        label: text drawn above the box.\n",
    "        color, txt_color: accepted for backward compatibility but unused; the box\n",
    "            and label are always drawn in (255, 0, 255) with a black outline.\n",
    "        line_width: box thickness; if falsy it is derived from the image size.\n",
    "\n",
    "    Returns:\n",
    "        The same `img`.\n",
    "    \"\"\"\n",
    "    lw = line_width or max(round(sum(img.shape) / 2 * 0.003), 2)\n",
    "    draw_color = (255, 0, 255)\n",
    "    outline_color = (0, 0, 0)\n",
    "    font = cv2.FONT_HERSHEY_SIMPLEX\n",
    "    font_scale = lw / 3\n",
    "    thickness = max(lw - 1, 1)\n",
    "\n",
    "    p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))\n",
    "    cv2.rectangle(img, p1, p2, draw_color, thickness=lw, lineType=cv2.LINE_AA)\n",
    "\n",
    "    face_w = p2[0] - p1[0]\n",
    "    (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness)\n",
    "    label_w += round(face_w * 10 / 360)  # widen the centering span with the face size\n",
    "    center_x = p1[0] + round(face_w / 2)\n",
    "    # Clamp so the label stays visible when the face touches the left or top edge.\n",
    "    org = (max(0, center_x - round(label_w / 2)),\n",
    "           max(label_h, p1[1] - round(face_w * 20 / 360)))\n",
    "\n",
    "    # Draw a thicker black pass first so it shows as an outline around the colored text.\n",
    "    cv2.putText(img, label, org, font, font_scale, outline_color, thickness=thickness + 2, lineType=cv2.LINE_AA)\n",
    "    cv2.putText(img, label, org, font, font_scale, draw_color, thickness=thickness, lineType=cv2.LINE_AA)\n",
    "    return img\n",
    "\n",
    "\n",
    "def display_FPS(img, text, margin=1.0, box_scale=1.0):\n",
    "    \"\"\"Draw `text` on a semi-transparent white box in the top-right corner of `img`.\n",
    "\n",
    "    Args:\n",
    "        img: (H, W, C) image, assumed uint8; modified in place.\n",
    "        text: text to draw.\n",
    "        margin: distance from the top and right edges, in multiples of the text height.\n",
    "        box_scale: padding around the text, in multiples of the text height.\n",
    "\n",
    "    Returns:\n",
    "        The same `img`.\n",
    "    \"\"\"\n",
    "    img_h, img_w, _ = img.shape\n",
    "    line_width = int(min(img_h, img_w) * 0.001)\n",
    "    thickness = max(int(line_width / 3), 1)  # font thickness\n",
    "    font_face = cv2.FONT_HERSHEY_SIMPLEX\n",
    "    font_color = (0, 0, 0)\n",
    "    font_scale = thickness / 1.5\n",
    "\n",
    "    # Measure with the same thickness that putText uses below.\n",
    "    (t_w, t_h), _ = cv2.getTextSize(text, font_face, font_scale, thickness)\n",
    "    margin_n = int(t_h * margin)\n",
    "    pad = int(2 * t_h * box_scale)\n",
    "\n",
    "    top, bottom = margin_n, margin_n + t_h + pad\n",
    "    left, right = max(0, img_w - t_w - margin_n - pad), img_w - margin_n\n",
    "\n",
    "    # Blend the box region 50/50 with white.\n",
    "    sub_img = img[top:bottom, left:right]\n",
    "    white_rect = np.full(sub_img.shape, 255, dtype=np.uint8)\n",
    "    img[top:bottom, left:right] = cv2.addWeighted(sub_img, 0.5, white_rect, 0.5, 1.0)\n",
    "\n",
    "    cv2.putText(img=img,\n",
    "                text=text,\n",
    "                org=(img_w - t_w - margin_n - pad // 2, top + t_h + pad // 2),\n",
    "                fontFace=font_face,\n",
    "                fontScale=font_scale,\n",
    "                color=font_color,\n",
    "                thickness=thickness,\n",
    "                lineType=cv2.LINE_AA,\n",
    "                bottomLeftOrigin=False)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b2e4f61-3c7a-4d85-a0e9-6f1b2c3d4e5f",
   "metadata": {},
   "source": [
    "#### Loading the models\n",
    "\n",
    "Both models are TorchScript files loaded from the working directory; `name_LSTM_model` selects which LSTM weights are used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c8a1b7-2d3f-4a6e-8b9c-0d1e2f3a4b5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "mp_face_mesh = mp.solutions.face_mesh\n",
    "\n",
    "# Per-frame feature extractor.\n",
    "name_backbone_model = 'FER_static_ResNet50_AffectNet.pth'\n",
    "# Temporal LSTM head; other weights referenced by this notebook: 'IEMOCAP', 'CREMA-D', 'RAMAS', 'RAVDESS', 'SAVEE'.\n",
    "name_LSTM_model = 'Aff-Wild2'\n",
    "\n",
    "CAMERA_INDEX = 0              # cv2.VideoCapture device index\n",
    "PATH_SAVE_VIDEO = 'result.mp4'\n",
    "LSTM_WINDOW = 10              # number of per-frame feature vectors in the LSTM input\n",
    "FALLBACK_FPS = 30.0           # assumed frame rate when the camera reports 0 (TODO confirm for your camera)\n",
    "\n",
    "pth_backbone_model = torch.jit.load(name_backbone_model)\n",
    "pth_backbone_model.eval()\n",
    "\n",
    "pth_LSTM_model = torch.jit.load('FER_dinamic_LSTM_{0}.pth'.format(name_LSTM_model))\n",
    "pth_LSTM_model.eval()\n",
    "\n",
    "DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bae915fd-cc3d-4dc1-83fc-c9c32e1b12a8",
   "metadata": {},
   "source": [
    "#### Testing the models with a webcam\n",
    "\n",
    "Press `q` in the video window to stop. The annotated stream is also saved to `result.mp4`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c05ed967-a30e-47f5-96ed-b32bab0c6879",
   "metadata": {},
   "outputs": [],
   "source": [
    "cap = cv2.VideoCapture(CAMERA_INDEX)\n",
    "if not cap.isOpened():\n",
    "    cap.release()\n",
    "    raise RuntimeError('Could not open camera {0}'.format(CAMERA_INDEX))\n",
    "\n",
    "w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
    "h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
    "# Some cameras report 0 FPS, which is not a usable VideoWriter frame rate.\n",
    "fps = np.round(cap.get(cv2.CAP_PROP_FPS)) or FALLBACK_FPS\n",
    "vid_writer = cv2.VideoWriter(PATH_SAVE_VIDEO, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))\n",
    "\n",
    "lstm_features = []\n",
    "\n",
    "try:\n",
    "    # no_grad: inference only, so autograd graphs are not built for every frame.\n",
    "    with mp_face_mesh.FaceMesh(\n",
    "            max_num_faces=1,\n",
    "            refine_landmarks=False,\n",
    "            min_detection_confidence=0.5,\n",
    "            min_tracking_confidence=0.5) as face_mesh, torch.no_grad():\n",
    "        while cap.isOpened():\n",
    "            t1 = time.perf_counter()\n",
    "            success, frame = cap.read()\n",
    "            if not success or frame is None:\n",
    "                break\n",
    "\n",
    "            # MediaPipe takes RGB input; its examples mark the image read-only while processing.\n",
    "            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
    "            frame_rgb.flags.writeable = False\n",
    "            results = face_mesh.process(frame_rgb)\n",
    "            frame_rgb.flags.writeable = True\n",
    "            frame_h, frame_w = frame_rgb.shape[:2]\n",
    "\n",
    "            for fl in results.multi_face_landmarks or []:\n",
    "                startX, startY, endX, endY = get_box(fl, frame_w, frame_h)\n",
    "                cur_face = frame_rgb[startY:endY, startX:endX]\n",
    "                if cur_face.size == 0:\n",
    "                    continue  # box collapsed (face at or beyond the frame edge)\n",
    "\n",
    "                cur_face = pth_processing(Image.fromarray(cur_face))\n",
    "                features = torch.nn.functional.relu(\n",
    "                    pth_backbone_model.extract_features(cur_face)).detach().numpy()\n",
    "\n",
    "                # Sliding window over the last LSTM_WINDOW frames; the first face seen\n",
    "                # fills the whole window with copies of its features.\n",
    "                if not lstm_features:\n",
    "                    lstm_features = [features] * LSTM_WINDOW\n",
    "                else:\n",
    "                    lstm_features = lstm_features[1:] + [features]\n",
    "\n",
    "                lstm_f = torch.unsqueeze(torch.from_numpy(np.vstack(lstm_features)), 0)\n",
    "                output = pth_LSTM_model(lstm_f).detach().numpy()\n",
    "\n",
    "                cl = np.argmax(output)\n",
    "                label = DICT_EMO[cl]\n",
    "                frame = display_EMO_PRED(frame, (startX, startY, endX, endY),\n",
    "                                         label + ' {0:.1%}'.format(output[0][cl]), line_width=3)\n",
    "\n",
    "            t2 = time.perf_counter()\n",
    "            frame = display_FPS(frame, 'FPS: {0:.1f}'.format(1 / (t2 - t1)), box_scale=.5)\n",
    "\n",
    "            vid_writer.write(frame)\n",
    "            cv2.imshow('Webcam', frame)\n",
    "            if cv2.waitKey(1) & 0xFF == ord('q'):\n",
    "                break\n",
    "finally:\n",
    "    # Release the camera and the output file even if the loop raises.\n",
    "    vid_writer.release()\n",
    "    cap.release()\n",
    "    cv2.destroyAllWindows()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}