Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
from detectron2.layers import batched_nms | |
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads | |
from detectron2.modeling.roi_heads.roi_heads import Res5ROIHeads | |
from detectron2.structures import Instances | |
def merge_branch_instances(instances, num_branch, nms_thresh, topk_per_image):
    """
    Merge detection results from different branches of TridentNet.

    The results of all branches for the same image are concatenated, per-class
    non-maximum suppression (NMS) is applied to the boxes, and the unsuppressed
    detections are kept together with any other per-detection fields (e.g. masks).

    Args:
        instances (list[Instances]): A list of N * num_branch instances that store detection
            results for N images. The list is branch-major: the result of branch ``j`` for
            image ``i`` is at index ``i + N * j``.
        num_branch (int): Number of branches used for merging detection results for each image.
        nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
        topk_per_image (int): The number of top scoring detections to return. Set < 0 to return
            all detections.

    Returns:
        results (list[Instances]): A list of N instances, one for each image in the batch,
            that stores the topk most confident detections after merging results from multiple
            branches. If ``num_branch == 1`` the input list is returned unchanged (no extra NMS).
    """
    if num_branch == 1:
        return instances

    batch_size = len(instances) // num_branch
    results = []
    for i in range(batch_size):
        # Gather this image's results from every branch (branch-major layout).
        instance = Instances.cat([instances[i + batch_size * j] for j in range(num_branch)])
        # Apply per-class NMS: pred_classes is passed as the grouping index, so boxes are
        # only suppressed by boxes of the same predicted class.
        keep = batched_nms(
            instance.pred_boxes.tensor, instance.scores, instance.pred_classes, nms_thresh
        )
        # A negative topk means "keep everything"; slicing with a negative bound would
        # instead silently drop the last detection(s).
        if topk_per_image >= 0:
            keep = keep[:topk_per_image]
        results.append(instance[keep])
    return results
# Register the class so it can be looked up by name in ROI_HEADS_REGISTRY
# (e.g. when the ROI heads are built from the model config).
@ROI_HEADS_REGISTRY.register()
class TridentRes5ROIHeads(Res5ROIHeads):
    """
    The TridentNet ROIHeads in a typical "C4" R-CNN model.

    See :class:`Res5ROIHeads`. Targets are replicated once per branch before calling
    the parent head, and at inference time the per-branch results are merged with
    :func:`merge_branch_instances`.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)
        # Number of parallel TridentNet branches.
        self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
        # True when a single test branch is configured (TEST_BRANCH_IDX != -1); inference
        # then treats the inputs as coming from one branch only.
        self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`Res5ROIHeads.forward`.

        NOTE(review): replicating ``targets`` ``num_branch`` times assumes ``proposals``
        already holds one entry per (branch, image) pair in branch-major order —
        presumably produced upstream by the Trident backbone/RPN; confirm against caller.

        Returns:
            In training: ``(pred_instances, losses)`` from the parent head, unchanged.
            In inference: ``(merged_instances, {})`` where ``merged_instances`` has one
            entry per image after merging branches via :func:`merge_branch_instances`.
        """
        # Use 1 branch if using trident_fast during inference.
        num_branch = self.num_branch if self.training or not self.trident_fast else 1
        # Duplicate targets for all branches (list repetition keeps branch-major order).
        all_targets = targets * num_branch if targets is not None else None
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        # Release references that are no longer needed.
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        pred_instances = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return pred_instances, {}
# Register the class so it can be looked up by name in ROI_HEADS_REGISTRY
# (e.g. when the ROI heads are built from the model config).
@ROI_HEADS_REGISTRY.register()
class TridentStandardROIHeads(StandardROIHeads):
    """
    The `StandardROIHeads` for TridentNet.

    See :class:`StandardROIHeads`. Targets are replicated once per branch before calling
    the parent head, and at inference time the per-branch results are merged with
    :func:`merge_branch_instances`.
    """

    def __init__(self, cfg, input_shape):
        super().__init__(cfg, input_shape)
        # Number of parallel TridentNet branches.
        self.num_branch = cfg.MODEL.TRIDENT.NUM_BRANCH
        # True when a single test branch is configured (TEST_BRANCH_IDX != -1); inference
        # then treats the inputs as coming from one branch only.
        self.trident_fast = cfg.MODEL.TRIDENT.TEST_BRANCH_IDX != -1

    def forward(self, images, features, proposals, targets=None):
        """
        See :class:`StandardROIHeads.forward`.

        NOTE(review): replicating ``targets`` ``num_branch`` times assumes ``proposals``
        already holds one entry per (branch, image) pair in branch-major order —
        presumably produced upstream by the Trident backbone/RPN; confirm against caller.

        Returns:
            In training: ``(pred_instances, losses)`` from the parent head, unchanged.
            In inference: ``(merged_instances, {})`` where ``merged_instances`` has one
            entry per image after merging branches via :func:`merge_branch_instances`.
        """
        # Use 1 branch if using trident_fast during inference.
        num_branch = self.num_branch if self.training or not self.trident_fast else 1
        # Duplicate targets for all branches (list repetition keeps branch-major order).
        all_targets = targets * num_branch if targets is not None else None
        pred_instances, losses = super().forward(images, features, proposals, all_targets)
        # Release references that are no longer needed.
        del images, all_targets, targets

        if self.training:
            return pred_instances, losses

        pred_instances = merge_branch_instances(
            pred_instances,
            num_branch,
            self.box_predictor.test_nms_thresh,
            self.box_predictor.test_topk_per_image,
        )
        return pred_instances, {}