Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from mmdet.core import bbox2result | |
from mmdet.models.builder import DETECTORS | |
from ...core.utils import flip_tensor | |
from .single_stage import SingleStageDetector | |
@DETECTORS.register_module()
class CenterNet(SingleStageDetector):
    """Implementation of CenterNet (Objects as Points).

    <https://arxiv.org/abs/1904.07850>.

    Construction and single-image inference are inherited from
    ``SingleStageDetector``; this class adds flip-pair test-time
    augmentation (see :meth:`aug_test`).
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(CenterNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained, init_cfg)

    def merge_aug_results(self, aug_results, with_nms):
        """Merge augmented detection bboxes and scores.

        Args:
            aug_results (list[list[tuple[Tensor, Tensor]]]): One entry per
                flip pair, each being the ``bbox_head.get_bboxes`` output for
                a single image. Only the first ``(det_bboxes, det_labels)``
                pair of every entry is used.
            with_nms (bool): If True, run ``bbox_head._bboxes_nms`` on the
                concatenated detections before returning them.

        Returns:
            tuple: (out_bboxes, out_labels)
        """
        # Each get_bboxes call in aug_test decodes exactly one image, so its
        # detections sit at index 0 of every per-pair result.
        bboxes = torch.cat([result[0][0] for result in aug_results],
                           dim=0).contiguous()
        labels = torch.cat([result[0][1]
                            for result in aug_results]).contiguous()
        if with_nms:
            out_bboxes, out_labels = self.bbox_head._bboxes_nms(
                bboxes, labels, self.bbox_head.test_cfg)
        else:
            out_bboxes, out_labels = bboxes, labels
        return out_bboxes, out_labels

    def aug_test(self, imgs, img_metas, rescale=True):
        """Augment testing of CenterNet with (original, flipped) image pairs.

        Unlike CornerNet, which decodes boxes for every augmented image and
        merges them, CenterNet averages the head's feature maps (center
        heatmap and wh) of each pair and decodes boxes once per pair.

        Args:
            imgs (list[Tensor]): Augmented images, ordered as consecutive
                (unflipped, flipped) pairs. Each entry presumably holds a
                single image (batch size 1) -- the pair is concatenated and
                split back with ``[0:1]`` / ``[1:2]``; TODO confirm.
            img_metas (list[list[dict]]): Meta information of each image, e.g.,
                image size, scaling factor, etc. Must provide ``'flip'`` and,
                for flipped images, ``'flip_direction'``.
            rescale (bool): If True, return boxes in original image space.
                Default: True.

        Note:
            ``imgs`` must consist of flipped image pairs: its length must be
            even and non-zero, and in every pair only the second image may
            be flipped.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.

        Raises:
            AssertionError: If ``imgs`` / ``img_metas`` do not form
                (unflipped, flipped) pairs.
        """
        num_imgs = len(imgs)
        # zip over even/odd indices would silently drop a trailing image.
        assert num_imgs > 0 and num_imgs % 2 == 0, (
            'aug test must have flipped image pair')
        aug_results = []
        for ind in range(0, num_imgs, 2):
            flip_ind = ind + 1
            # The averaging below un-flips the second image's outputs, so
            # every pair must be (unflipped, flipped) in that order.
            assert (not img_metas[ind][0]['flip']
                    and img_metas[flip_ind][0]['flip']), (
                        'aug test must have flipped image pair')
            flip_direction = img_metas[flip_ind][0]['flip_direction']
            img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
            x = self.extract_feat(img_pair)
            center_heatmap_preds, wh_preds, offset_preds = self.bbox_head(x)
            assert len(center_heatmap_preds) == len(wh_preds) == len(
                offset_preds) == 1
            # Feature map averaging: flip the flipped image's maps back so
            # they align with the unflipped image, then take the mean.
            center_heatmap_preds[0] = (
                center_heatmap_preds[0][0:1] +
                flip_tensor(center_heatmap_preds[0][1:2], flip_direction)) / 2
            wh_preds[0] = (wh_preds[0][0:1] +
                           flip_tensor(wh_preds[0][1:2], flip_direction)) / 2
            # Offsets are taken from the unflipped image only (not averaged);
            # decoding uses the unflipped image's metas for rescaling.
            bbox_list = self.bbox_head.get_bboxes(
                center_heatmap_preds,
                wh_preds, [offset_preds[0][0:1]],
                img_metas[ind],
                rescale=rescale,
                with_nms=False)
            aug_results.append(bbox_list)

        # NMS across pairs only runs when the head's test_cfg configures it.
        with_nms = self.bbox_head.test_cfg.get('nms_cfg', None) is not None
        bbox_list = [self.merge_aug_results(aug_results, with_nms)]
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results