Spaces:
Running
on
L40S
Running
on
L40S
File size: 2,208 Bytes
b213d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
from detectron2.config import configurable
from detectron2.utils.registry import Registry
from ..config.config import CfgNode as CfgNode_
from ..structures import Instances
# Module-level registry named "TRACKER_HEADS". build_tracker_head() resolves
# the tracker class to instantiate by looking up the name stored in
# cfg.TRACKER_HEADS.TRACKER_NAME in this registry.
TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
# Attach a human-readable description to the registry object itself.
TRACKER_HEADS_REGISTRY.__doc__ = """
Registry for tracking classes.
"""
class BaseTracker:
    """
    Common parent class for all trackers.

    The constructor only sets up the bookkeeping state listed below.
    `from_config` and `update` both raise NotImplementedError in this class,
    so a subclass has to override them to obtain a working tracker.
    """

    @configurable
    def __init__(self, **kwargs):
        # Counter used to assign new ids.
        self._id_count = 0
        # (D2) instances of the previous frame; starts out as None.
        self._prev_instances = None
        # Indices in _prev_instances that were found matching.
        self._matched_idx = set()
        # Identities in _prev_instances that were found matching.
        self._matched_ID = set()
        # Indices in _prev_instances that were not found matching.
        self._untracked_prev_idx = set()

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Config hook of the tracker; not implemented on the base class.

        Args:
            cfg: D2 CfgNode with the tracker configuration
        Raises:
            NotImplementedError: always, when called on BaseTracker itself
        """
        raise NotImplementedError("Calling BaseTracker::from_config")

    def update(self, predictions: Instances) -> Instances:
        """
        Assign IDs to the predictions of the current frame.

        Args:
            predictions: D2 Instances for predictions of the current frame
        Return:
            D2 Instances for predictions of the current frame with ID assigned

        _prev_instances and instances will have the following fields:
            .pred_boxes      (shape=[N, 4])
            .scores          (shape=[N,])
            .pred_classes    (shape=[N,])
            .pred_keypoints  (shape=[N, M, 3], Optional)
            .pred_masks      (shape=List[2D_MASK], Optional)
                             2D_MASK: shape=[H, W]
            .ID              (shape=[N,])

        N: # of detected bboxes
        H and W: height and width of 2D mask

        Raises:
            NotImplementedError: always, when called on BaseTracker itself
        """
        raise NotImplementedError("Calling BaseTracker::update")
def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Instantiate the tracker head selected by `cfg.TRACKER_HEADS.TRACKER_NAME`.

    The name is looked up in TRACKER_HEADS_REGISTRY and the resulting class
    is called with `cfg` as its only argument.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object
    """
    # Resolve the class by name, then construct it straight from the config.
    return TRACKER_HEADS_REGISTRY.get(cfg.TRACKER_HEADS.TRACKER_NAME)(cfg)
|