# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.core import bbox2result
from mmdet.models.builder import DETECTORS
from ...core.utils import flip_tensor
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class CenterNet(SingleStageDetector):
    """CenterNet ("Objects as Points") single-stage detector.

    Adds flip-pair test-time augmentation on top of
    ``SingleStageDetector``: the head outputs of an image and its flipped
    copy are averaged before boxes are decoded.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # All arguments are forwarded positionally to the parent detector.
        super(CenterNet, self).__init__(backbone, neck, bbox_head,
                                        train_cfg, test_cfg, pretrained,
                                        init_cfg)

    def merge_aug_results(self, aug_results, with_nms):
        """Concatenate the detections of every augmentation pair.

        Args:
            aug_results (list[list[Tensor]]): One entry per augmentation
                pair; only the first element of each entry is read, and it
                is indexed as ``(det_bboxes, det_labels)``.
            with_nms (bool): If True, run the head's ``_bboxes_nms`` on the
                concatenated detections before returning them.

        Returns:
            tuple: ``(out_bboxes, out_labels)``.
        """
        # Each entry wraps a single (bboxes, labels) pair at index 0.
        detections = [single_result[0] for single_result in aug_results]
        bboxes = torch.cat(
            [dets[0] for dets in detections], dim=0).contiguous()
        labels = torch.cat([dets[1] for dets in detections]).contiguous()

        if not with_nms:
            return bboxes, labels

        out_bboxes, out_labels = self.bbox_head._bboxes_nms(
            bboxes, labels, self.bbox_head.test_cfg)
        return out_bboxes, out_labels

    def aug_test(self, imgs, img_metas, rescale=True):
        """Run test-time augmentation over (image, flipped image) pairs.

        Consecutive entries of ``imgs`` are treated as a pair. Both images
        go through the network in one batch; the center heatmap and the
        width/height predictions of the second image are flipped back and
        averaged with those of the first, and boxes are decoded from the
        averaged maps. Offsets are taken from the first image only.

        Args:
            imgs (list[Tensor]): Augmented images.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.
                Default: True.

        Note:
            ``imgs`` must include flipped image pairs.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        num_imgs = len(imgs)
        assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
            'aug test must have flipped image pair')

        aug_results = []
        # Walk the inputs two at a time; a trailing unpaired image is
        # ignored, exactly as when zipping even and odd indices.
        for ind in range(0, num_imgs - 1, 2):
            flip_ind = ind + 1
            flip_direction = img_metas[flip_ind][0]['flip_direction']

            img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
            x = self.extract_feat(img_pair)
            center_heatmap_preds, wh_preds, offset_preds = self.bbox_head(x)
            assert len(center_heatmap_preds) == len(wh_preds) == len(
                offset_preds) == 1

            # Feature map averaging: slice [0:1] is the first image of the
            # pair, slice [1:2] the second one, flipped back before summing.
            heatmap = center_heatmap_preds[0]
            center_heatmap_preds[0] = (
                heatmap[0:1] +
                flip_tensor(heatmap[1:2], flip_direction)) / 2
            wh = wh_preds[0]
            wh_preds[0] = (
                wh[0:1] + flip_tensor(wh[1:2], flip_direction)) / 2

            # NMS is postponed until the results of all pairs are merged.
            pair_result = self.bbox_head.get_bboxes(
                center_heatmap_preds,
                wh_preds, [offset_preds[0][0:1]],
                img_metas[ind],
                rescale=rescale,
                with_nms=False)
            aug_results.append(pair_result)

        # NMS is applied only when the test config provides ``nms_cfg``.
        nms_cfg = self.bbox_head.test_cfg.get('nms_cfg', None)
        with_nms = nms_cfg is not None

        det_bboxes, det_labels = self.merge_aug_results(
            aug_results, with_nms)
        return [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
        ]