{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "c:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\groundingdino\\models\\GroundingDINO\\ms_deform_attn.py:31: UserWarning: Failed to load custom C++ ops. Running on CPU mode Only!\n", " warnings.warn(\"Failed to load custom C++ ops. Running on CPU mode Only!\")\n" ] } ], "source": [ "import os\n", "import cv2\n", "from SegTracker import SegTracker\n", "from model_args import aot_args,sam_args,segtracker_args\n", "from PIL import Image\n", "from aot_tracker import _palette\n", "import numpy as np\n", "import torch\n", "import imageio\n", "import matplotlib.pyplot as plt\n", "from scipy.ndimage import binary_dilation\n", "import gc\n", "def save_prediction(pred_mask,output_dir,file_name):\n", " save_mask = Image.fromarray(pred_mask.astype(np.uint8))\n", " save_mask = save_mask.convert(mode='P')\n", " save_mask.putpalette(_palette)\n", " save_mask.save(os.path.join(output_dir,file_name))\n", "def colorize_mask(pred_mask):\n", " save_mask = Image.fromarray(pred_mask.astype(np.uint8))\n", " save_mask = save_mask.convert(mode='P')\n", " save_mask.putpalette(_palette)\n", " save_mask = save_mask.convert(mode='RGB')\n", " return np.array(save_mask)\n", "def draw_mask(img, mask, alpha=0.7, id_countour=False):\n", " img_mask = np.zeros_like(img)\n", " img_mask = img\n", " if id_countour:\n", " # very slow ~ 1s per image\n", " obj_ids = np.unique(mask)\n", " obj_ids = obj_ids[obj_ids!=0]\n", "\n", " for id in obj_ids:\n", " # Overlay color on binary mask\n", " if id <= 255:\n", " color = _palette[id*3:id*3+3]\n", " else:\n", " color = [0,0,0]\n", " foreground = img * (1-alpha) + np.ones_like(img) * alpha * np.array(color)\n", " binary_mask = (mask == id)\n", "\n", " # Compose image\n", " img_mask[binary_mask] = foreground[binary_mask]\n", "\n", " countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask\n", " img_mask[countours, :] = 0\n", " else:\n", " binary_mask = (mask!=0)\n", " countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask\n", " foreground = img*(1-alpha)+colorize_mask(mask)*alpha\n", " img_mask[binary_mask] = foreground[binary_mask]\n", " img_mask[countours,:] = 0\n", " \n", " return img_mask.astype(img.dtype)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Set parameters for input and output" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "video_name = 'cars'\n", "io_args = {\n", " 'input_video': f'./assets/{video_name}.mp4',\n", " 'output_mask_dir': f'./assets/{video_name}_masks', # save pred masks\n", " 'output_video': f'./assets/{video_name}_seg.mp4', # mask+frame vizualization, mp4 or avi, else the same as input video\n", " 'output_gif': f'./assets/{video_name}_seg.gif', # mask visualization\n", "}" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Tuning Grounding-DINO and SAM on the First Frame for Good Initialization" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "ename": "AssertionError", "evalue": "Torch not compiled with CUDA enabled", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[3], line 27\u001b[0m\n\u001b[0;32m 25\u001b[0m cap \u001b[39m=\u001b[39m cv2\u001b[39m.\u001b[39mVideoCapture(io_args[\u001b[39m'\u001b[39m\u001b[39minput_video\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[0;32m 26\u001b[0m frame_idx \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m---> 27\u001b[0m segtracker \u001b[39m=\u001b[39m SegTracker(segtracker_args,sam_args,aot_args)\n\u001b[0;32m 28\u001b[0m segtracker\u001b[39m.\u001b[39mrestart_tracker()\n\u001b[0;32m 29\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mamp\u001b[39m.\u001b[39mautocast():\n", "File \u001b[1;32md:\\05 Dr\\Segmentation\\Segment-and-Track-Anything\\SegTracker.py:19\u001b[0m, in \u001b[0;36mSegTracker.__init__\u001b[1;34m(self, segtracker_args, sam_args, aot_args)\u001b[0m\n\u001b[0;32m 15\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m,segtracker_args, sam_args, aot_args) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 16\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 17\u001b[0m \u001b[39m Initialize SAM and AOT.\u001b[39;00m\n\u001b[0;32m 18\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m---> 19\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msam \u001b[39m=\u001b[39m Segmentor(sam_args)\n\u001b[0;32m 20\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtracker \u001b[39m=\u001b[39m get_aot(aot_args)\n\u001b[0;32m 21\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdetector \u001b[39m=\u001b[39m Detector(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39msam\u001b[39m.\u001b[39mdevice)\n", "File \u001b[1;32md:\\05 Dr\\Segmentation\\Segment-and-Track-Anything\\tool\\segmentor.py:16\u001b[0m, in \u001b[0;36mSegmentor.__init__\u001b[1;34m(self, sam_args)\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice \u001b[39m=\u001b[39m sam_args[\u001b[39m\"\u001b[39m\u001b[39mgpu_id\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[0;32m 15\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msam \u001b[39m=\u001b[39m sam_model_registry[sam_args[\u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m]](checkpoint\u001b[39m=\u001b[39msam_args[\u001b[39m\"\u001b[39m\u001b[39msam_checkpoint\u001b[39m\u001b[39m\"\u001b[39m])\n\u001b[1;32m---> 16\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msam\u001b[39m.\u001b[39;49mto(device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdevice)\n\u001b[0;32m 17\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39meverything_generator \u001b[39m=\u001b[39m SamAutomaticMaskGenerator(model\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39msam, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39msam_args[\u001b[39m'\u001b[39m\u001b[39mgenerator_args\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[0;32m 18\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39minteractive_predictor \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39meverything_generator\u001b[39m.\u001b[39mpredictor\n", "File \u001b[1;32mc:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1145\u001b[0m, in \u001b[0;36mModule.to\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1141\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39mto(device, dtype \u001b[39mif\u001b[39;00m t\u001b[39m.\u001b[39mis_floating_point() \u001b[39mor\u001b[39;00m t\u001b[39m.\u001b[39mis_complex() \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[0;32m 1142\u001b[0m non_blocking, memory_format\u001b[39m=\u001b[39mconvert_to_format)\n\u001b[0;32m 1143\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39mto(device, dtype \u001b[39mif\u001b[39;00m t\u001b[39m.\u001b[39mis_floating_point() \u001b[39mor\u001b[39;00m t\u001b[39m.\u001b[39mis_complex() \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m, non_blocking)\n\u001b[1;32m-> 1145\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_apply(convert)\n", "File \u001b[1;32mc:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:797\u001b[0m, in \u001b[0;36mModule._apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 795\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_apply\u001b[39m(\u001b[39mself\u001b[39m, fn):\n\u001b[0;32m 796\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mchildren():\n\u001b[1;32m--> 797\u001b[0m module\u001b[39m.\u001b[39;49m_apply(fn)\n\u001b[0;32m 799\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[0;32m 800\u001b[0m \u001b[39mif\u001b[39;00m torch\u001b[39m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[0;32m 801\u001b[0m \u001b[39m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[0;32m 802\u001b[0m \u001b[39m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 807\u001b[0m \u001b[39m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[0;32m 808\u001b[0m \u001b[39m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", "File \u001b[1;32mc:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:797\u001b[0m, in \u001b[0;36mModule._apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 795\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_apply\u001b[39m(\u001b[39mself\u001b[39m, fn):\n\u001b[0;32m 796\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mchildren():\n\u001b[1;32m--> 797\u001b[0m module\u001b[39m.\u001b[39;49m_apply(fn)\n\u001b[0;32m 799\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[0;32m 800\u001b[0m \u001b[39mif\u001b[39;00m torch\u001b[39m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[0;32m 801\u001b[0m \u001b[39m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[0;32m 802\u001b[0m \u001b[39m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 807\u001b[0m \u001b[39m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[0;32m 808\u001b[0m \u001b[39m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", "File \u001b[1;32mc:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:797\u001b[0m, in \u001b[0;36mModule._apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 795\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_apply\u001b[39m(\u001b[39mself\u001b[39m, fn):\n\u001b[0;32m 796\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mchildren():\n\u001b[1;32m--> 797\u001b[0m module\u001b[39m.\u001b[39;49m_apply(fn)\n\u001b[0;32m 799\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[0;32m 800\u001b[0m \u001b[39mif\u001b[39;00m torch\u001b[39m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[0;32m 801\u001b[0m \u001b[39m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[0;32m 802\u001b[0m \u001b[39m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 807\u001b[0m \u001b[39m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[0;32m 808\u001b[0m \u001b[39m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", "File \u001b[1;32mc:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:820\u001b[0m, in \u001b[0;36mModule._apply\u001b[1;34m(self, fn)\u001b[0m\n\u001b[0;32m 816\u001b[0m \u001b[39m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[0;32m 817\u001b[0m \u001b[39m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[0;32m 818\u001b[0m \u001b[39m# `with torch.no_grad():`\u001b[39;00m\n\u001b[0;32m 819\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m--> 820\u001b[0m param_applied \u001b[39m=\u001b[39m fn(param)\n\u001b[0;32m 821\u001b[0m should_use_set_data \u001b[39m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[0;32m 822\u001b[0m \u001b[39mif\u001b[39;00m should_use_set_data:\n", "File \u001b[1;32mc:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\nn\\modules\\module.py:1143\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[1;34m(t)\u001b[0m\n\u001b[0;32m 1140\u001b[0m \u001b[39mif\u001b[39;00m convert_to_format \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m t\u001b[39m.\u001b[39mdim() \u001b[39min\u001b[39;00m (\u001b[39m4\u001b[39m, \u001b[39m5\u001b[39m):\n\u001b[0;32m 1141\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39mto(device, dtype \u001b[39mif\u001b[39;00m t\u001b[39m.\u001b[39mis_floating_point() \u001b[39mor\u001b[39;00m t\u001b[39m.\u001b[39mis_complex() \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[0;32m 1142\u001b[0m non_blocking, memory_format\u001b[39m=\u001b[39mconvert_to_format)\n\u001b[1;32m-> 1143\u001b[0m \u001b[39mreturn\u001b[39;00m t\u001b[39m.\u001b[39;49mto(device, dtype \u001b[39mif\u001b[39;49;00m t\u001b[39m.\u001b[39;49mis_floating_point() \u001b[39mor\u001b[39;49;00m t\u001b[39m.\u001b[39;49mis_complex() \u001b[39melse\u001b[39;49;00m \u001b[39mNone\u001b[39;49;00m, non_blocking)\n", "File \u001b[1;32mc:\\Users\\Dubai Computers\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\cuda\\__init__.py:239\u001b[0m, in \u001b[0;36m_lazy_init\u001b[1;34m()\u001b[0m\n\u001b[0;32m 235\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[0;32m 236\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mCannot re-initialize CUDA in forked subprocess. To use CUDA with \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 237\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mmultiprocessing, you must use the \u001b[39m\u001b[39m'\u001b[39m\u001b[39mspawn\u001b[39m\u001b[39m'\u001b[39m\u001b[39m start method\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 238\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mhasattr\u001b[39m(torch\u001b[39m.\u001b[39m_C, \u001b[39m'\u001b[39m\u001b[39m_cuda_getDeviceCount\u001b[39m\u001b[39m'\u001b[39m):\n\u001b[1;32m--> 239\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAssertionError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mTorch not compiled with CUDA enabled\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m 240\u001b[0m \u001b[39mif\u001b[39;00m _cudart \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m 241\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAssertionError\u001b[39;00m(\n\u001b[0;32m 242\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mlibcudart functions unavailable. It looks like you have a broken build?\u001b[39m\u001b[39m\"\u001b[39m)\n", "\u001b[1;31mAssertionError\u001b[0m: Torch not compiled with CUDA enabled" ] } ], "source": [ "# choose good parameters in sam_args based on the first frame segmentation result\n", "# other arguments can be modified in model_args.py\n", "# note the object number limit is 255 by default, which requires < 10GB GPU memory with amp\n", "sam_args['generator_args'] = {\n", " 'points_per_side': 30,\n", " 'pred_iou_thresh': 0.8,\n", " 'stability_score_thresh': 0.9,\n", " 'crop_n_layers': 1,\n", " 'crop_n_points_downscale_factor': 2,\n", " 'min_mask_region_area': 200,\n", " }\n", "\n", "# Set Text args\n", "'''\n", "parameter:\n", " grounding_caption: Text prompt to detect objects in key-frames\n", " box_threshold: threshold for box \n", " text_threshold: threshold for label(text)\n", " box_size_threshold: If the size ratio between the box and the frame is larger than the box_size_threshold, the box will be ignored. This is used to filter out large boxes.\n", " reset_image: reset the image embeddings for SAM\n", "'''\n", "grounding_caption = \"car.suv\"\n", "box_threshold, text_threshold, box_size_threshold, reset_image = 0.35, 0.5, 0.5, True\n", "\n", "cap = cv2.VideoCapture(io_args['input_video'])\n", "frame_idx = 0\n", "segtracker = SegTracker(segtracker_args,sam_args,aot_args)\n", "segtracker.restart_tracker()\n", "with torch.cuda.amp.autocast():\n", " while cap.isOpened():\n", " ret, frame = cap.read()\n", " frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n", " pred_mask, annotated_frame = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold, box_size_threshold)\n", " torch.cuda.empty_cache()\n", " obj_ids = np.unique(pred_mask)\n", " obj_ids = obj_ids[obj_ids!=0]\n", " print(\"processed frame {}, obj_num {}\".format(frame_idx,len(obj_ids)),end='\\n')\n", " break\n", " cap.release()\n", " init_res = draw_mask(annotated_frame, pred_mask,id_countour=False)\n", " plt.figure(figsize=(10,10))\n", " plt.axis('off')\n", " plt.imshow(init_res)\n", " plt.show()\n", " plt.figure(figsize=(10,10))\n", " plt.axis('off')\n", " plt.imshow(colorize_mask(pred_mask))\n", " plt.show()\n", "\n", " del segtracker\n", " torch.cuda.empty_cache()\n", " gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Generate Results for the Whole Video" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# For every sam_gap frames, we use SAM to find new objects and add them for tracking\n", "# larger sam_gap is faster but may not spot new objects in time\n", "segtracker_args = {\n", " 'sam_gap': 49, # the interval to run sam to segment new objects\n", " 'min_area': 200, # minimal mask area to add a new mask as a new object\n", " 'max_obj_num': 255, # maximal object number to track in a video\n", " 'min_new_obj_iou': 0.8, # the area of a new object in the background should > 80% \n", "}\n", "\n", "# source video to segment\n", "cap = cv2.VideoCapture(io_args['input_video'])\n", "fps = cap.get(cv2.CAP_PROP_FPS)\n", "# output masks\n", "output_dir = io_args['output_mask_dir']\n", "if not os.path.exists(output_dir):\n", " os.makedirs(output_dir)\n", "pred_list = []\n", "masked_pred_list = []\n", "\n", "torch.cuda.empty_cache()\n", "gc.collect()\n", "sam_gap = segtracker_args['sam_gap']\n", "frame_idx = 0\n", "segtracker = SegTracker(segtracker_args, sam_args, aot_args)\n", "segtracker.restart_tracker()\n", "\n", "with torch.cuda.amp.autocast():\n", " while cap.isOpened():\n", " ret, frame = cap.read()\n", " if not ret:\n", " break\n", " frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n", " if frame_idx == 0:\n", " pred_mask, _ = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold, box_size_threshold, reset_image)\n", " # pred_mask = cv2.imread('./debug/first_frame_mask.png', 0)\n", " torch.cuda.empty_cache()\n", " gc.collect()\n", " segtracker.add_reference(frame, pred_mask)\n", " elif (frame_idx % sam_gap) == 0:\n", " seg_mask, _ = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold, box_size_threshold, reset_image)\n", " save_prediction(seg_mask, './debug/seg_result', str(frame_idx)+'.png')\n", " torch.cuda.empty_cache()\n", " gc.collect()\n", " track_mask = segtracker.track(frame)\n", " save_prediction(track_mask, './debug/aot_result', str(frame_idx)+'.png')\n", " # find new objects, and update tracker with new objects\n", " new_obj_mask = segtracker.find_new_objs(track_mask, seg_mask)\n", " if np.sum(new_obj_mask > 0) > frame.shape[0] * frame.shape[1] * 0.4:\n", " new_obj_mask = np.zeros_like(new_obj_mask)\n", " save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png')\n", " pred_mask = track_mask + new_obj_mask\n", " # segtracker.restart_tracker()\n", " segtracker.add_reference(frame, pred_mask)\n", " else:\n", " pred_mask = segtracker.track(frame,update_memory=True)\n", " torch.cuda.empty_cache()\n", " gc.collect()\n", " \n", " save_prediction(pred_mask,output_dir,str(frame_idx)+'.png')\n", " # masked_frame = draw_mask(frame,pred_mask)\n", " # masked_pred_list.append(masked_frame)\n", " # plt.imshow(masked_frame)\n", " # plt.show() \n", " \n", " pred_list.append(pred_mask)\n", " \n", " \n", " print(\"processed frame {}, obj_num {}\".format(frame_idx,segtracker.get_obj_num()),end='\\r')\n", " frame_idx += 1\n", " cap.release()\n", " print('\\nfinished')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Save results for visualization" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# draw pred mask on frame and save as a video\n", "cap = cv2.VideoCapture(io_args['input_video'])\n", "fps = cap.get(cv2.CAP_PROP_FPS)\n", "width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n", "height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n", "num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", "\n", "if io_args['input_video'][-3:]=='mp4':\n", " fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n", "elif io_args['input_video'][-3:] == 'avi':\n", " fourcc = cv2.VideoWriter_fourcc(*\"MJPG\")\n", " # fourcc = cv2.VideoWriter_fourcc(*\"XVID\")\n", "else:\n", " fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))\n", "out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))\n", "\n", "frame_idx = 0\n", "while cap.isOpened():\n", " ret, frame = cap.read()\n", " if not ret:\n", " break\n", " frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n", " pred_mask = pred_list[frame_idx]\n", " masked_frame = draw_mask(frame,pred_mask)\n", " # masked_frame = masked_pred_list[frame_idx]\n", " masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR)\n", " out.write(masked_frame)\n", " print('frame {} writed'.format(frame_idx),end='\\r')\n", " frame_idx += 1\n", "out.release()\n", "cap.release()\n", "print(\"\\n{} saved\".format(io_args['output_video']))\n", "print('\\nfinished')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# save colorized masks as a gif\n", "imageio.mimsave(io_args['output_gif'],pred_list,fps=fps)\n", "print(\"{} saved\".format(io_args['output_gif']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "21" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# manually release memory (after cuda out of memory)\n", "del segtracker\n", "torch.cuda.empty_cache()\n", "gc.collect()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.5 64-bit ('ldm': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "536611da043600e50719c9460971b5220bad26cd4a87e5994bfd4c9e9e5e7fb0" } } }, "nbformat": 4, "nbformat_minor": 2 }