# Scraped page header (not part of the module): Spaces: Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np | |
import torch | |
from mmdet.core import bbox2result, bbox2roi | |
from ..builder import HEADS, build_head, build_roi_extractor | |
from .standard_roi_head import StandardRoIHead | |
class GridRoIHead(StandardRoIHead):
    """Grid roi head for Grid R-CNN.

    https://arxiv.org/abs/1811.12030

    Extends ``StandardRoIHead`` with a grid head that is run on the
    positive proposals during training and on the detected boxes during
    testing.

    Args:
        grid_roi_extractor: Config passed to ``build_roi_extractor`` to
            create a dedicated RoI extractor for the grid head. If None,
            ``self.bbox_roi_extractor`` is reused.
        grid_head: Config passed to ``build_head``. Must not be None.
        **kwargs: Forwarded unchanged to ``StandardRoIHead.__init__``.
    """

    def __init__(self, grid_roi_extractor, grid_head, **kwargs):
        assert grid_head is not None, 'grid_head must be specified.'
        super(GridRoIHead, self).__init__(**kwargs)
        if grid_roi_extractor is not None:
            self.grid_roi_extractor = build_roi_extractor(grid_roi_extractor)
            self.share_roi_extractor = False
        else:
            # No dedicated extractor configured: share the bbox one.
            self.share_roi_extractor = True
            self.grid_roi_extractor = self.bbox_roi_extractor
        self.grid_head = build_head(grid_head)

    def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
        """Random jitter positive proposals for training.

        For every sampling result the positive boxes are perturbed: the
        centre is shifted by a random fraction of the box width/height and
        the width/height are rescaled by ``1 + offset``, where each offset
        is drawn uniformly from ``[-amplitude, amplitude)``. The jittered
        boxes are clipped to the image and written back to
        ``sampling_result.pos_bboxes`` (the input objects are modified).

        Args:
            sampling_results: Iterable of objects exposing ``pos_bboxes``
                in (x1, y1, x2, y2) layout.
            img_metas: Per-image dicts providing ``'img_shape'``.
            amplitude (float): Half-width of the uniform offset range.

        Returns:
            The same ``sampling_results`` object that was passed in.
        """
        for sampling_result, img_meta in zip(sampling_results, img_metas):
            bboxes = sampling_result.pos_bboxes
            random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
                -amplitude, amplitude)
            # before jittering
            cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
            wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
            # after jittering
            new_cxcy = cxcy + wh * random_offsets[:, :2]
            new_wh = wh * (1 + random_offsets[:, 2:])
            # xywh to xyxy
            new_x1y1 = new_cxcy - new_wh / 2
            new_x2y2 = new_cxcy + new_wh / 2
            new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
            # clip bboxes: columns 0 and 2 are x, columns 1 and 3 are y
            max_shape = img_meta['img_shape']
            if max_shape is not None:
                new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
                new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)

            sampling_result.pos_bboxes = new_bboxes
        return sampling_results

    def forward_dummy(self, x, proposals):
        """Dummy forward function.

        Runs the bbox head (if present), the grid head and the mask head
        (if present) and collects their raw outputs. Only the first 100
        rois are fed to the grid and mask heads.

        Returns:
            tuple: ``cls_score`` and ``bbox_pred`` when ``with_bbox`` is
            set, then the grid head output, then ``mask_pred`` when
            ``with_mask`` is set.
        """
        # bbox head
        outs = ()
        rois = bbox2roi([proposals])
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            outs = outs + (bbox_results['cls_score'],
                           bbox_results['bbox_pred'])

        # grid head
        grid_rois = rois[:100]
        grid_feats = self.grid_roi_extractor(
            x[:self.grid_roi_extractor.num_inputs], grid_rois)
        if self.with_shared_head:
            grid_feats = self.shared_head(grid_feats)
        grid_pred = self.grid_head(grid_feats)
        outs = outs + (grid_pred, )

        # mask head
        if self.with_mask:
            mask_rois = rois[:100]
            mask_results = self._mask_forward(x, mask_rois)
            outs = outs + (mask_results['mask_pred'], )
        return outs

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        """Run forward function and calculate loss for box head in training.

        The parent implementation is run first; afterwards the positive
        proposals are jittered, a random subset of them is passed through
        the grid head and the resulting grid loss is merged into
        ``bbox_results['loss_bbox']``.

        Returns:
            dict: The parent's ``bbox_results``. When there is at least
            one positive roi, ``bbox_results['loss_bbox']`` additionally
            contains the grid head losses.
        """
        bbox_results = super(GridRoIHead,
                             self)._bbox_forward_train(x, sampling_results,
                                                       gt_bboxes, gt_labels,
                                                       img_metas)

        # Grid head forward and loss
        sampling_results = self._random_jitter(sampling_results, img_metas)
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])

        # GN in head does not support zero shape input
        if pos_rois.shape[0] == 0:
            return bbox_results

        grid_feats = self.grid_roi_extractor(
            x[:self.grid_roi_extractor.num_inputs], pos_rois)
        if self.with_shared_head:
            grid_feats = self.shared_head(grid_feats)

        # Accelerate training: keep at most ``max_num_grid`` random rois.
        # Slicing past the end is a no-op, so no explicit min() is needed.
        max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
        sample_idx = torch.randperm(
            grid_feats.shape[0])[:max_sample_num_grid]
        grid_feats = grid_feats[sample_idx]
        grid_pred = self.grid_head(grid_feats)

        # Targets are built for all positives, then subsampled with the
        # same indices so predictions and targets stay aligned.
        grid_targets = self.grid_head.get_targets(sampling_results,
                                                  self.train_cfg)
        grid_targets = grid_targets[sample_idx]

        loss_grid = self.grid_head.loss(grid_pred, grid_targets)
        bbox_results['loss_bbox'].update(loss_grid)
        return bbox_results

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation.

        Boxes are first detected by the bbox head without rescaling, then
        refined per image by the grid head using its ``'fused'``
        prediction.

        Args:
            x: Multi-level features; the first
                ``grid_roi_extractor.num_inputs`` levels are used.
            proposal_list: Proposals handed to ``simple_test_bboxes``.
            img_metas: Per-image meta dicts; ``'scale_factor'`` is read
                when ``rescale`` is True.
            proposals: Not used by this method; kept in the signature.
            rescale (bool): If True, divide the refined box coordinates
                by the image's ``scale_factor``.

        Returns:
            list: One entry per image. Each entry is the per-class bbox
            result, or a ``(bbox_result, segm_result)`` pair when
            ``with_mask`` is set.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'

        det_bboxes, det_labels = self.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg, rescale=False)

        def empty_result():
            # One empty (0, 5) array per class for an image without boxes.
            return [
                np.zeros((0, 5), dtype=np.float32)
                for _ in range(self.bbox_head.num_classes)
            ]

        # pack rois into bboxes
        grid_rois = bbox2roi([det_bbox[:, :4] for det_bbox in det_bboxes])
        if grid_rois.shape[0] != 0:
            grid_feats = self.grid_roi_extractor(
                x[:self.grid_roi_extractor.num_inputs], grid_rois)
            # Apply the shared head exactly as in training and
            # forward_dummy, so the grid head sees the same kind of
            # features at test time.
            if self.with_shared_head:
                grid_feats = self.shared_head(grid_feats)
            # NOTE(review): this flag is set on the head and never reset
            # here -- confirm that is intended if training resumes later.
            self.grid_head.test_mode = True
            grid_pred = self.grid_head(grid_feats)
            # split batch grid head prediction back to each image
            num_roi_per_img = tuple(len(det_bbox) for det_bbox in det_bboxes)
            fused_per_img = grid_pred['fused'].split(num_roi_per_img, 0)

            # apply bbox post-processing to each image individually
            bbox_results = []
            for i, img_det_bboxes in enumerate(det_bboxes):
                if img_det_bboxes.shape[0] == 0:
                    bbox_results.append(empty_result())
                    continue
                det_bbox = self.grid_head.get_bboxes(
                    img_det_bboxes, fused_per_img[i], [img_metas[i]])
                if rescale:
                    det_bbox[:, :4] /= img_metas[i]['scale_factor']
                bbox_results.append(
                    bbox2result(det_bbox, det_labels[i],
                                self.bbox_head.num_classes))
        else:
            bbox_results = [empty_result() for _ in range(len(det_bboxes))]

        if not self.with_mask:
            return bbox_results

        # NOTE(review): ``det_bboxes`` are the boxes obtained above with
        # rescale=False (not the grid-refined ones) -- confirm that
        # ``simple_test_mask`` expects this when ``rescale`` is True.
        segm_results = self.simple_test_mask(
            x, img_metas, det_bboxes, det_labels, rescale=rescale)
        return list(zip(bbox_results, segm_results))