Spaces:
ginipick
/
Running on Zero

File size: 5,288 Bytes
b213d84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

import numpy as np
from typing import List

from detectron2.config import CfgNode as CfgNode_
from detectron2.config import configurable
from detectron2.structures import Instances
from detectron2.structures.boxes import pairwise_iou
from detectron2.tracking.utils import LARGE_COST_VALUE, create_prediction_pairs

from .base_tracker import TRACKER_HEADS_REGISTRY
from .hungarian_tracker import BaseHungarianTracker


@TRACKER_HEADS_REGISTRY.register()
class VanillaHungarianBBoxIOUTracker(BaseHungarianTracker):
    """
    Hungarian-algorithm-based tracker that uses bounding-box IoU as the
    matching metric between current-frame and previous-frame predictions.
    """

    @configurable
    def __init__(
        self,
        *,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        track_iou_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of ids allowed to be tracked
            max_lost_frame_count: maximum number of frames an id can lose
                tracking; once this number is exceeded, the id is considered
                lost forever
            min_box_rel_dim: a percentage; bboxes smaller than this relative
                dimension are removed from tracking
            min_instance_period: an instance will be shown only after this
                number of periods since it first appeared in the video
            track_iou_threshold: IoU threshold; bbox pairs below this value
                are not matched
            **kwargs: accepted for config compatibility; not forwarded or used
        """
        super().__init__(
            video_height=video_height,
            video_width=video_width,
            max_num_instances=max_num_instances,
            max_lost_frame_count=max_lost_frame_count,
            min_box_rel_dim=min_box_rel_dim,
            min_instance_period=min_instance_period,
        )
        self._track_iou_threshold = track_iou_threshold

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Old style initialization using CfgNode

        Args:
            cfg: D2 CfgNode, config file; ``cfg.TRACKER_HEADS`` must contain
                VIDEO_HEIGHT and VIDEO_WIDTH, other keys fall back to defaults
        Return:
            dictionary storing arguments for __init__ method
        """
        assert "VIDEO_HEIGHT" in cfg.TRACKER_HEADS, "cfg.TRACKER_HEADS.VIDEO_HEIGHT is required"
        assert "VIDEO_WIDTH" in cfg.TRACKER_HEADS, "cfg.TRACKER_HEADS.VIDEO_WIDTH is required"
        video_height = cfg.TRACKER_HEADS.get("VIDEO_HEIGHT")
        video_width = cfg.TRACKER_HEADS.get("VIDEO_WIDTH")
        max_num_instances = cfg.TRACKER_HEADS.get("MAX_NUM_INSTANCES", 200)
        max_lost_frame_count = cfg.TRACKER_HEADS.get("MAX_LOST_FRAME_COUNT", 0)
        min_box_rel_dim = cfg.TRACKER_HEADS.get("MIN_BOX_REL_DIM", 0.02)
        min_instance_period = cfg.TRACKER_HEADS.get("MIN_INSTANCE_PERIOD", 1)
        track_iou_threshold = cfg.TRACKER_HEADS.get("TRACK_IOU_THRESHOLD", 0.5)
        return {
            "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker",  # noqa
            "video_height": video_height,
            "video_width": video_width,
            "max_num_instances": max_num_instances,
            "max_lost_frame_count": max_lost_frame_count,
            "min_box_rel_dim": min_box_rel_dim,
            "min_instance_period": min_instance_period,
            "track_iou_threshold": track_iou_threshold,
        }

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        """
        Build the cost matrix for the assignment problem
        (https://en.wikipedia.org/wiki/Assignment_problem).

        Rows index current-frame predictions and columns index previous-frame
        predictions. Pairs whose IoU passes ``self._track_iou_threshold`` get a
        cost of -1; every other entry keeps ``LARGE_COST_VALUE`` so a
        cost-minimizing assignment avoids matching them.

        Args:
            instances: D2 Instances, for current frame predictions
            prev_instances: D2 Instances, for previous frame predictions

        Return:
            the cost matrix in numpy array, of shape
            (len(instances), len(prev_instances))
        """
        assert instances is not None and prev_instances is not None
        # IoU of every (current, previous) bbox pair. Use the ``prev_instances``
        # argument (not ``self._prev_instances``) so that the pair indices
        # always agree with the matrix dimensions computed below.
        iou_all = pairwise_iou(
            boxes1=instances.pred_boxes,
            boxes2=prev_instances.pred_boxes,
        )
        bbox_pairs = create_prediction_pairs(
            instances, prev_instances, iou_all, threshold=self._track_iou_threshold
        )
        # assign large cost value to make sure pair below IoU threshold won't be matched
        cost_matrix = np.full((len(instances), len(prev_instances)), LARGE_COST_VALUE)
        return self.assign_cost_matrix_values(cost_matrix, bbox_pairs)

    def assign_cost_matrix_values(self, cost_matrix: np.ndarray, bbox_pairs: List) -> np.ndarray:
        """
        Based on IoU for each pair of bbox, assign the associated value in cost matrix

        Args:
            cost_matrix: np.ndarray, initialized 2D array with target dimensions;
                modified in place
            bbox_pairs: list of bbox pair dicts; each provides "idx" (row, current
                frame) and "prev_idx" (column, previous frame)
        Return:
            np.ndarray, the same cost_matrix object with assigned values
        """
        for pair in bbox_pairs:
            # assign -1 for IoU above threshold pairs; the assignment minimizes
            # total cost, so these pairs are preferred over LARGE_COST_VALUE ones
            cost_matrix[pair["idx"], pair["prev_idx"]] = -1
        return cost_matrix