# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import abstractmethod

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32

from mmdet.core import build_bbox_coder, multi_apply
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        stacked_convs (int): Number of stacking convs of the head.
        strides (tuple): Downsample factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        bbox_coder (dict): Config of bbox coder. Defaults to
            'DistancePointBBoxCoder'.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
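
    The example below is a minimal usage sketch; the channel sizes and
    feature-map shapes are illustrative only, and ``loss``/``get_targets``
    are abstract methods implemented by subclasses such as FCOSHead.

    Example:
        >>> self = AnchorFreeHead(num_classes=11, in_channels=7)
        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
        >>> cls_scores, bbox_preds = self.forward(feats)
        >>> assert len(cls_scores) == len(bbox_preds) == len(feats)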
""" # noqa: W605 | |
_version = 1 | |

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 bbox_coder=dict(type='DistancePointBBoxCoder'),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01))):
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_classes = num_classes
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.bbox_coder = build_bbox_coder(bbox_coder)

        self.prior_generator = MlvlPointGenerator(strides)

        # In order to keep a more general interface and stay consistent with
        # anchor_head, we treat each point as a single anchor (prior).
        self.num_base_priors = self.prior_generator.num_base_priors[0]

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False
        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _init_cls_convs(self):
        """Initialize classification conv layers of the head."""
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_reg_convs(self):
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_predictor(self):
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Hack some keys of the model state dict so that it can load
        checkpoints from previous versions."""
        version = local_metadata.get('version', None)
        if version is None:
            # the key is different in early versions,
            # for example, 'fcos_cls' becomes 'conv_cls' now
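            # e.g. a stored key like 'bbox_head.fcos_cls.weight' (the exact
            # prefix depends on the parent detector) would be remapped to
            # 'bbox_head.conv_cls.weight' by the loop below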
            bbox_head_keys = [
                k for k in state_dict.keys() if k.startswith(prefix)
            ]
            ori_predictor_keys = []
            new_predictor_keys = []
            # e.g. 'fcos_cls' or 'fcos_reg'
            for key in bbox_head_keys:
                ori_predictor_keys.append(key)
                key = key.split('.')
                conv_name = None
                if key[1].endswith('cls'):
                    conv_name = 'conv_cls'
                elif key[1].endswith('reg'):
                    conv_name = 'conv_reg'
                elif key[1].endswith('centerness'):
                    conv_name = 'conv_centerness'
                if conv_name is not None:
                    key[1] = conv_name
                    new_predictor_keys.append('.'.join(key))
                else:
                    # keys that do not belong to the predictor convs are
                    # kept unchanged
                    ori_predictor_keys.pop(-1)
            for i in range(len(new_predictor_keys)):
                state_dict[new_predictor_keys[i]] = state_dict.pop(
                    ori_predictor_keys[i])
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contains classification scores and bbox predictions.

                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * 4.
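
        Example:
            >>> # A small sketch of the expected per-level output shapes;
            >>> # the channel counts below are illustrative only.
            >>> self = AnchorFreeHead(num_classes=2, in_channels=7)
            >>> feats = [torch.rand(1, 7, s, s) for s in [8, 16, 32]]
            >>> cls_scores, bbox_preds = self.forward(feats)
            >>> assert cls_scores[0].shape == (1, 2, 8, 8)
            >>> assert bbox_preds[0].shape == (1, 4, 8, 8)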
""" | |
return multi_apply(self.forward_single, feats)[:2] | |

    def forward_single(self, x):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.

        Returns:
            tuple: Scores for each class, bbox predictions, and the features
                after the classification and regression conv layers; some
                models, e.g. FCOS, need these intermediate features.
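
        Example:
            >>> # Sketch for a single level; shapes are illustrative only.
            >>> self = AnchorFreeHead(num_classes=2, in_channels=7)
            >>> x = torch.rand(1, 7, 16, 16)
            >>> cls_score, bbox_pred, cls_feat, reg_feat = self.forward_single(x)
            >>> assert cls_feat.shape == (1, self.feat_channels, 16, 16)
            >>> assert reg_feat.shape == cls_feat.shape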
""" | |
cls_feat = x | |
reg_feat = x | |
for cls_layer in self.cls_convs: | |
cls_feat = cls_layer(cls_feat) | |
cls_score = self.conv_cls(cls_feat) | |
for reg_layer in self.reg_convs: | |
reg_feat = reg_layer(reg_feat) | |
bbox_pred = self.conv_reg(reg_feat) | |
return cls_score, bbox_pred, cls_feat, reg_feat | |

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.
        """
        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for
        points in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).
        """
        raise NotImplementedError

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points of a single scale level.

        This function will be deprecated soon.
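
        Note:
            The recommended replacement is, roughly (``level_idx`` selects
            the FPN level whose stride should be used)::

                priors = self.prior_generator.single_level_grid_priors(
                    featmap_size, level_idx, dtype=dtype, device=device)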
""" | |
warnings.warn( | |
'`_get_points_single` in `AnchorFreeHead` will be ' | |
'deprecated soon, we support a multi level point generator now' | |
'you can get points of a single level feature map ' | |
'with `self.prior_generator.single_level_grid_priors` ') | |
h, w = featmap_size | |
# First create Range with the default dtype, than convert to | |
# target `dtype` for onnx exporting. | |
x_range = torch.arange(w, device=device).to(dtype) | |
y_range = torch.arange(h, device=device).to(dtype) | |
y, x = torch.meshgrid(y_range, x_range) | |
if flatten: | |
y = y.flatten() | |
x = x.flatten() | |
return y, x | |

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.

        Returns:
            tuple: points of each image.
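
        Note:
            The recommended replacement is, roughly::

                mlvl_priors = self.prior_generator.grid_priors(
                    featmap_sizes, dtype=dtype, device=device)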
""" | |
warnings.warn( | |
'`get_points` in `AnchorFreeHead` will be ' | |
'deprecated soon, we support a multi level point generator now' | |
'you can get points of all levels ' | |
'with `self.prior_generator.grid_priors` ') | |
mlvl_points = [] | |
for i in range(len(featmap_sizes)): | |
mlvl_points.append( | |
self._get_points_single(featmap_sizes[i], self.strides[i], | |
dtype, device, flatten)) | |
return mlvl_points | |

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. Each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)