Spaces:
Runtime error
Runtime error
File size: 6,223 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from inspect import signature
import mmcv
import torch
from mmcv.image import tensor2imgs
from mmdet.core import bbox_mapping
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector
@DETECTORS.register_module()
class RPN(BaseDetector):
    """Implementation of Region Proposal Network.

    The detector is composed of a backbone, an optional neck and an RPN head.
    Its outputs are class-agnostic region proposals rather than final
    detections.
    """

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None,
                 init_cfg=None):
        """Build the backbone, the optional neck and the RPN head.

        Args:
            backbone (dict): Config passed to ``build_backbone``.
            neck (dict | None): Config passed to ``build_neck``. No neck is
                built when ``None``.
            rpn_head (dict): Config passed to ``build_head``. It is updated
                in place with the ``train_cfg`` / ``test_cfg`` entries.
            train_cfg (dict | None): Training config; its ``rpn`` entry is
                forwarded to the head. May be ``None``.
            test_cfg (dict | None): Testing config; its ``rpn`` entry is
                forwarded to the head. May be ``None``.
            pretrained (str, optional): Deprecated. When given it is stored
                on the backbone config and a warning is emitted.
            init_cfg (dict, optional): Forwarded to the base class.
        """
        super(RPN, self).__init__(init_cfg)
        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck) if neck is not None else None
        # Treat both configs symmetrically: a missing config simply means
        # the head receives ``None`` for that phase instead of crashing on
        # attribute access.
        rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
        rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
        rpn_head.update(train_cfg=rpn_train_cfg)
        rpn_head.update(test_cfg=rpn_test_cfg)
        self.rpn_head = build_head(rpn_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Extract features.

        Args:
            img (torch.Tensor): Image tensor with shape (n, c, h ,w).

        Returns:
            list[torch.Tensor]: Multi-level features that may have
                different resolutions.
        """
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Dummy forward function.

        Runs feature extraction followed by the RPN head and returns the raw
        head outputs without any loss computation or post-processing.
        """
        x = self.extract_feat(img)
        rpn_outs = self.rpn_head(x)
        return rpn_outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None):
        """Compute the RPN training losses for a batch.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # ``train_cfg`` may legitimately be None (see ``__init__``), so look
        # the rpn section up defensively before checking the debug flag.
        rpn_train_cfg = getattr(self.train_cfg, 'rpn', None)
        if (isinstance(rpn_train_cfg, dict)
                and rpn_train_cfg.get('debug', False)):
            # Hand the de-normalised input images to the head for debugging.
            self.rpn_head.debug_imgs = tensor2imgs(img)

        x = self.extract_feat(img)
        # The fourth positional argument is passed as None because no label
        # information is supplied to the head here.
        losses = self.rpn_head.forward_train(x, img_metas, gt_bboxes, None,
                                             gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Batched input images.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[np.ndarray]: proposals. During ONNX export the proposals
                are returned as tensors instead of numpy arrays.
        """
        x = self.extract_feat(img)
        in_onnx_export = torch.onnx.is_in_onnx_export()
        # get origin input shape to onnx dynamic input shape
        if in_onnx_export:
            img_shape = torch._shape_as_tensor(img)[2:]
            img_metas[0]['img_shape_for_onnx'] = img_shape
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        if rescale:
            for proposals, meta in zip(proposal_list, img_metas):
                # Divide the box coordinates (first four columns) in place by
                # the per-image scale factor.
                proposals[:, :4] /= proposals.new_tensor(meta['scale_factor'])
        if in_onnx_export:
            # Keep tensors so the exported graph stays traceable.
            return proposal_list

        return [proposal.cpu().numpy() for proposal in proposal_list]

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            imgs (list[torch.Tensor]): List of multiple images
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[np.ndarray]: proposals
        """
        proposal_list = self.rpn_head.aug_test_rpn(
            self.extract_feats(imgs), img_metas)
        if not rescale:
            # NOTE(review): presumably ``aug_test_rpn`` returns proposals in
            # original-image coordinates, so without ``rescale`` they are
            # mapped onto the first augmentation's shape/scale/flip. Confirm
            # against the head implementation.
            for proposals, img_meta in zip(proposal_list, img_metas[0]):
                img_shape = img_meta['img_shape']
                scale_factor = img_meta['scale_factor']
                flip = img_meta['flip']
                flip_direction = img_meta['flip_direction']
                proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
                                                scale_factor, flip,
                                                flip_direction)
        return [proposal.cpu().numpy() for proposal in proposal_list]

    def show_result(self, data, result, top_k=20, **kwargs):
        """Show RPN proposals on the image.

        Args:
            data (str or np.ndarray): Image filename or loaded image.
            result (Tensor or tuple): The results to draw over `img`
                bbox_result or (bbox_result, segm_result).
            top_k (int): Plot the first k bboxes only
               if set positive. Default: 20
            **kwargs: Extra options. Only the keys that
                ``mmcv.imshow_bboxes`` accepts are forwarded; the rest are
                dropped silently. ``colors`` is always forced to
                ``'green'``.

        Returns:
            np.ndarray: The image with bboxes drawn on it.
        """
        kwargs['colors'] = 'green'
        # Callers may pass options meant for other detectors' show_result;
        # keep only what ``mmcv.imshow_bboxes`` understands.
        accepted = signature(mmcv.imshow_bboxes).parameters
        draw_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        return mmcv.imshow_bboxes(data, result, top_k=top_k, **draw_kwargs)
|