# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmdet.core import bbox2result, bbox2roi
from ..builder import HEADS, build_head, build_roi_extractor
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class GridRoIHead(StandardRoIHead):
    """Grid roi head for Grid R-CNN.

    https://arxiv.org/abs/1811.12030

    Extends the standard RoI head with a grid head that refines the boxes
    produced by the bbox head.

    Args:
        grid_roi_extractor (dict | None): Config passed to
            ``build_roi_extractor`` for the grid branch. When ``None`` the
            bbox RoI extractor is reused for the grid branch.
        grid_head (dict): Config passed to ``build_head``. Must not be
            ``None``.
        **kwargs: Forwarded unchanged to ``StandardRoIHead.__init__``.
    """

    def __init__(self, grid_roi_extractor, grid_head, **kwargs):
        assert grid_head is not None
        super(GridRoIHead, self).__init__(**kwargs)
        if grid_roi_extractor is not None:
            self.grid_roi_extractor = build_roi_extractor(grid_roi_extractor)
            self.share_roi_extractor = False
        else:
            # No dedicated extractor configured: share the bbox branch's one.
            self.share_roi_extractor = True
            self.grid_roi_extractor = self.bbox_roi_extractor
        self.grid_head = build_head(grid_head)

    def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
        """Random jitter positive proposals for training.

        Each positive box has its center shifted by a fraction of its
        width/height and its width/height rescaled, with every fraction
        drawn uniformly from ``[-amplitude, amplitude]``.

        Args:
            sampling_results (list): Per-image sampling results; each item's
                ``pos_bboxes`` attribute is read as an ``(n, 4)`` tensor in
                ``(x1, y1, x2, y2)`` order and overwritten with the jittered
                boxes.
            img_metas (list[dict]): Per-image meta info. ``'img_shape'`` is
                read as ``(height, width, ...)``; if it is ``None`` the
                jittered boxes are not clipped.
            amplitude (float): Half-width of the uniform offset range.

        Returns:
            list: The same ``sampling_results`` list, modified in place.
        """
        for sampling_result, img_meta in zip(sampling_results, img_metas):
            bboxes = sampling_result.pos_bboxes
            random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
                -amplitude, amplitude)
            # before jittering
            cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
            wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
            # after jittering: first two offsets move the center, last two
            # rescale the size
            new_cxcy = cxcy + wh * random_offsets[:, :2]
            new_wh = wh * (1 + random_offsets[:, 2:])
            # xywh to xyxy
            new_x1y1 = (new_cxcy - new_wh / 2)
            new_x2y2 = (new_cxcy + new_wh / 2)
            new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
            # clip bboxes
            max_shape = img_meta['img_shape']
            if max_shape is not None:
                # Columns 0 and 2 are x coordinates, 1 and 3 are y.
                new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
                new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)

            sampling_result.pos_bboxes = new_bboxes
        return sampling_results

    def _empty_bbox_result(self):
        """Build one image's detection result containing no boxes.

        Returns:
            list[np.ndarray]: One ``(0, 5)`` float32 array per class, with
            the class count taken from ``self.bbox_head.num_classes``.
        """
        return [
            np.zeros((0, 5), dtype=np.float32)
            for _ in range(self.bbox_head.num_classes)
        ]

    def forward_dummy(self, x, proposals):
        """Dummy forward function.

        Args:
            x (Sequence): Feature maps; sliced per level before being given
                to the grid RoI extractor.
            proposals: Proposals of a single image, wrapped in a list and
                converted with ``bbox2roi``.

        Returns:
            tuple: Outputs of the enabled heads, in order: bbox head
            ``cls_score`` and ``bbox_pred`` (if a bbox head exists), the
            grid head prediction, and the mask head ``mask_pred`` (if a mask
            head exists).
        """
        # bbox head
        outs = ()
        rois = bbox2roi([proposals])
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            outs = outs + (bbox_results['cls_score'],
                           bbox_results['bbox_pred'])

        # grid head (only the first 100 rois are used)
        grid_rois = rois[:100]
        grid_feats = self.grid_roi_extractor(
            x[:self.grid_roi_extractor.num_inputs], grid_rois)
        if self.with_shared_head:
            grid_feats = self.shared_head(grid_feats)
        grid_pred = self.grid_head(grid_feats)
        outs = outs + (grid_pred, )

        # mask head
        if self.with_mask:
            mask_rois = rois[:100]
            mask_results = self._mask_forward(x, mask_rois)
            outs = outs + (mask_results['mask_pred'], )
        return outs

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run forward function and calculate loss for box head in training.

        Runs the parent bbox training step first, then jitters the positive
        proposals, runs the grid head on a random subset of them and merges
        the grid loss into ``bbox_results['loss_bbox']``.

        Args:
            x (Sequence): Feature maps.
            sampling_results (list): Per-image sampling results. Their
                ``pos_bboxes`` are overwritten by ``_random_jitter``.
            gt_bboxes: Forwarded to the parent implementation.
            gt_labels: Forwarded to the parent implementation.
            img_metas (list[dict]): Per-image meta info.

        Returns:
            dict: The parent's bbox results. The grid loss is added to its
            ``'loss_bbox'`` entry unless there are no positive rois.
        """
        bbox_results = super(GridRoIHead,
                             self)._bbox_forward_train(x, sampling_results,
                                                       gt_bboxes, gt_labels,
                                                       img_metas)

        # Grid head forward and loss
        sampling_results = self._random_jitter(sampling_results, img_metas)
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])

        # GN in head does not support zero shape input
        if pos_rois.shape[0] == 0:
            return bbox_results

        grid_feats = self.grid_roi_extractor(
            x[:self.grid_roi_extractor.num_inputs], pos_rois)
        if self.with_shared_head:
            grid_feats = self.shared_head(grid_feats)
        # Accelerate training: keep at most ``max_num_grid`` random samples.
        # Slicing past the end of the permutation is harmless, so no explicit
        # ``min`` against the number of rois is needed.
        max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
        sample_idx = torch.randperm(
            grid_feats.shape[0])[:max_sample_num_grid]
        grid_feats = grid_feats[sample_idx]

        grid_pred = self.grid_head(grid_feats)

        grid_targets = self.grid_head.get_targets(sampling_results,
                                                  self.train_cfg)
        # Targets must follow the same subsampling as the features.
        grid_targets = grid_targets[sample_idx]

        loss_grid = self.grid_head.loss(grid_pred, grid_targets)

        bbox_results['loss_bbox'].update(loss_grid)
        return bbox_results

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation.

        Args:
            x (Sequence): Feature maps.
            proposal_list (list): Per-image proposals, forwarded to
                ``simple_test_bboxes``.
            img_metas (list[dict]): Per-image meta info; ``'scale_factor'``
                is read when ``rescale`` is True.
            proposals: Unused; kept for interface compatibility.
            rescale (bool): If True, divide the refined box coordinates by
                the image's ``'scale_factor'``.

        Returns:
            list: One entry per image. Each entry is the per-class bbox
            result, or a ``(bbox_result, segm_result)`` tuple when a mask
            head is present.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        # Boxes are kept in the network input scale here (rescale=False);
        # rescaling happens after grid refinement below.
        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=False)
        # pack rois into bboxes
        grid_rois = bbox2roi([det_bbox[:, :4] for det_bbox in det_bboxes])
        if grid_rois.shape[0] != 0:
            grid_feats = self.grid_roi_extractor(
                x[:self.grid_roi_extractor.num_inputs], grid_rois)
            # Apply the shared head exactly as the training path does, so the
            # grid head sees the same kind of features at test time.
            if self.with_shared_head:
                grid_feats = self.shared_head(grid_feats)
            # NOTE(review): this flag stays set on the grid head after the
            # call; its effect is defined by the grid head implementation.
            self.grid_head.test_mode = True
            grid_pred = self.grid_head(grid_feats)
            # split batch grid head prediction back to each image
            num_roi_per_img = tuple(len(det_bbox) for det_bbox in det_bboxes)
            grid_pred = {
                k: v.split(num_roi_per_img, 0)
                for k, v in grid_pred.items()
            }

            # apply bbox post-processing to each image individually
            bbox_results = []
            for i, det_bbox_i in enumerate(det_bboxes):
                if det_bbox_i.shape[0] == 0:
                    bbox_results.append(self._empty_bbox_result())
                else:
                    det_bbox = self.grid_head.get_bboxes(
                        det_bbox_i, grid_pred['fused'][i], [img_metas[i]])
                    if rescale:
                        det_bbox[:, :4] /= img_metas[i]['scale_factor']
                    bbox_results.append(
                        bbox2result(det_bbox, det_labels[i],
                                    self.bbox_head.num_classes))
        else:
            # No detection in the whole batch: skip the grid head entirely.
            bbox_results = [
                self._empty_bbox_result() for _ in range(len(det_bboxes))
            ]

        if not self.with_mask:
            return bbox_results
        else:
            segm_results = self.simple_test_mask(
                x, img_metas, det_bboxes, det_labels, rescale=rescale)
            return list(zip(bbox_results, segm_results))