# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv import ConfigDict
from mmcv.ops import nms

from ..builder import HEADS
from .guided_anchor_head import GuidedAnchorHead


@HEADS.register_module()
class GARPNHead(GuidedAnchorHead):
    """Guided-Anchor-based RPN head."""

    def __init__(self,
                 in_channels,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_loc',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        """Build the head.

        Args:
            in_channels (int): Number of channels of the input feature map.
            init_cfg (dict): Weight initialization config forwarded to the
                parent class.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``GuidedAnchorHead.__init__``.
        """
        # The leading ``1`` is the first positional argument of the parent
        # constructor (presumably ``num_classes`` -- confirm against
        # GuidedAnchorHead).
        super(GARPNHead, self).__init__(
            1, in_channels, init_cfg=init_cfg, **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        # Extra 3x3 conv applied before the parent's prediction layers.
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        super(GARPNHead, self)._init_layers()

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Feature map of one scale level.

        Returns:
            tuple: ``(cls_score, bbox_pred, shape_pred, loc_pred)`` as
            produced by ``GuidedAnchorHead.forward_single`` on the
            conv + ReLU transformed feature.
        """
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        (cls_score, bbox_pred, shape_pred,
         loc_pred) = super(GARPNHead, self).forward_single(x)
        return cls_score, bbox_pred, shape_pred, loc_pred

    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute the parent losses and rename them with RPN-specific keys.

        All arguments are forwarded to ``GuidedAnchorHead.loss``; ``None`` is
        passed in the positional slot following ``gt_bboxes`` (presumably
        ``gt_labels`` -- confirm against the parent signature).

        Returns:
            dict: Keys ``loss_rpn_cls``, ``loss_rpn_bbox``,
            ``loss_anchor_shape`` and ``loss_anchor_loc``, taken from the
            parent's ``loss_cls``, ``loss_bbox``, ``loss_shape`` and
            ``loss_loc`` entries respectively.
        """
        losses = super(GARPNHead, self).loss(
            cls_scores,
            bbox_preds,
            shape_preds,
            loc_preds,
            gt_bboxes,
            None,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'],
            loss_rpn_bbox=losses['loss_bbox'],
            loss_anchor_shape=losses['loss_shape'],
            loss_anchor_loc=losses['loss_loc'])

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           mlvl_masks,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Turn the per-level outputs of one image into proposals.

        Args:
            cls_scores (list[Tensor]): Per-level classification scores; each
                is permuted with ``(1, 2, 0)``, i.e. treated as 3-D.
            bbox_preds (list[Tensor]): Per-level box deltas, permuted the
                same way and reshaped to ``(-1, 4)``.
            mlvl_anchors (list[Tensor]): Per-level anchors, already filtered
                by the location mask beforehand.
            mlvl_masks (list[Tensor]): Per-level masks used to index the
                flattened scores and deltas.
            img_shape (tuple): Passed to the bbox coder as ``max_shape``.
            scale_factor: Unused in this method.
            cfg (mmcv.ConfigDict | None): Test config; ``self.test_cfg`` is
                used when None. The config is deep-copied before the
                deprecated keys are normalised, so the caller's object is
                not modified.
            rescale (bool): Unused in this method.

        Returns:
            Tensor: Proposals with 5 columns (4 box coordinates followed by
            the score, as returned by ``nms``). An empty ``(0, 5)`` tensor is
            returned when no level keeps any location.
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        # deprecate arguments warning
        if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
            warnings.warn(
                'In rpn_proposal or test_cfg, '
                'nms_thr has been moved to a dict named nms as '
                'iou_threshold, max_num has been renamed as max_per_img, '
                'name of original arguments and the way to specify '
                'iou_threshold of NMS will be deprecated.')
        if 'nms' not in cfg:
            cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
        if 'max_num' in cfg:
            if 'max_per_img' in cfg:
                assert cfg.max_num == cfg.max_per_img, \
                    'You set max_num and max_per_img at the same time, ' \
                    f'but get {cfg.max_num} ' \
                    f'and {cfg.max_per_img} respectively. ' \
                    'Please delete max_num which will be deprecated.'
            else:
                cfg.max_per_img = cfg.max_num
        if 'nms_thr' in cfg:
            assert cfg.nms.iou_threshold == cfg.nms_thr, \
                'You set iou_threshold in nms and ' \
                'nms_thr at the same time, but get ' \
                f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
                ' respectively. Please delete the ' \
                'nms_thr which will be deprecated.'

        assert cfg.nms.get('type', 'nms') == 'nms', \
            'GARPNHead only support naive nms.'

        mlvl_proposals = []
        for idx, rpn_cls_score in enumerate(cls_scores):
            rpn_bbox_pred = bbox_preds[idx]
            anchors = mlvl_anchors[idx]
            mask = mlvl_masks[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            # if no location is kept, end.
            if mask.sum() == 0:
                continue
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                # Take the single foreground column as a 1-D tensor: a
                # ``[:, :-1]`` slice would keep a trailing dimension of
                # size 1 and break the 1-D top-k / filtering below.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            # filter scores, bbox_pred w.r.t. mask.
            # anchors are filtered in get_anchors() beforehand.
            scores = scores[mask]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(
                -1, 4)[mask, :]
            if scores.dim() == 0:
                rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
                anchors = anchors.unsqueeze(0)
                scores = scores.unsqueeze(0)
            # filter anchors, bbox_pred, scores w.r.t. scores
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            # get proposals w.r.t. anchors and rpn_bbox_pred
            proposals = self.bbox_coder.decode(
                anchors, rpn_bbox_pred, max_shape=img_shape)
            # filter out too small bboxes
            if cfg.min_bbox_size >= 0:
                w = proposals[:, 2] - proposals[:, 0]
                h = proposals[:, 3] - proposals[:, 1]
                valid_mask = (w > cfg.min_bbox_size) & (
                    h > cfg.min_bbox_size)
                if not valid_mask.all():
                    proposals = proposals[valid_mask]
                    scores = scores[valid_mask]
            # NMS in current level
            proposals, _ = nms(proposals, scores, cfg.nms.iou_threshold)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)

        if not mlvl_proposals:
            # Every level was skipped (no kept location anywhere).
            # ``torch.cat`` raises on an empty list, so return an empty
            # proposal set instead of crashing.
            if len(cls_scores) > 0:
                return cls_scores[0].new_zeros((0, 5))
            return torch.zeros((0, 5))

        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.get('nms_across_levels', False):
            # NMS across multi levels
            proposals, _ = nms(proposals[:, :4], proposals[:, -1],
                               cfg.nms.iou_threshold)
            proposals = proposals[:cfg.max_per_img, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_per_img, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals