# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmdet.models.builder import HEADS
from mmdet.models.utils import build_linear_layer
from .bbox_head import BBoxHead


@HEADS.register_module()
class ConvFCBBoxHead(BBoxHead):
    r"""More general bbox head, with shared conv and fc layers and two optional
    separated branches.

    .. code-block:: none

                                        /-> cls convs -> cls fcs -> cls
            shared convs -> shared fcs
                                        \-> reg convs -> reg fcs -> reg
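
    Example (a sketch, not part of the original docstring; shapes assume
    the ``BBoxHead`` defaults ``in_channels=256``, ``roi_feat_size=7`` and
    ``num_classes=80``)::

        >>> import torch
        >>> head = ConvFCBBoxHead(num_shared_fcs=2)
        >>> x = torch.rand(8, 256, 7, 7)  # 8 pooled RoI feature maps
        >>> cls_score, bbox_pred = head(x)
        >>> cls_score.shape   # (N, num_classes + 1), incl. background
        torch.Size([8, 81])
        >>> bbox_pred.shape   # (N, 4 * num_classes), class-specific deltas
        torch.Size([8, 320])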
""" # noqa: W605 | |

    def __init__(self,
                 num_shared_convs=0,
                 num_shared_fcs=0,
                 num_cls_convs=0,
                 num_cls_fcs=0,
                 num_reg_convs=0,
                 num_reg_fcs=0,
                 conv_out_channels=256,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=None,
                 *args,
                 **kwargs):
        super(ConvFCBBoxHead, self).__init__(
            *args, init_cfg=init_cfg, **kwargs)
        assert (num_shared_convs + num_shared_fcs + num_cls_convs +
                num_cls_fcs + num_reg_convs + num_reg_fcs > 0)
        if num_cls_convs > 0 or num_reg_convs > 0:
            assert num_shared_fcs == 0
        if not self.with_cls:
            assert num_cls_convs == 0 and num_cls_fcs == 0
        if not self.with_reg:
            assert num_reg_convs == 0 and num_reg_fcs == 0
        self.num_shared_convs = num_shared_convs
        self.num_shared_fcs = num_shared_fcs
        self.num_cls_convs = num_cls_convs
        self.num_cls_fcs = num_cls_fcs
        self.num_reg_convs = num_reg_convs
        self.num_reg_fcs = num_reg_fcs
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # add shared convs and fcs
        self.shared_convs, self.shared_fcs, last_layer_dim = \
            self._add_conv_fc_branch(
                self.num_shared_convs, self.num_shared_fcs, self.in_channels,
                True)
        self.shared_out_channels = last_layer_dim

        # add cls specific branch
        self.cls_convs, self.cls_fcs, self.cls_last_dim = \
            self._add_conv_fc_branch(
                self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)

        # add reg specific branch
        self.reg_convs, self.reg_fcs, self.reg_last_dim = \
            self._add_conv_fc_branch(
                self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)

        if self.num_shared_fcs == 0 and not self.with_avg_pool:
            if self.num_cls_fcs == 0:
                self.cls_last_dim *= self.roi_feat_area
            if self.num_reg_fcs == 0:
                self.reg_last_dim *= self.roi_feat_area

        self.relu = nn.ReLU(inplace=True)
        # reconstruct fc_cls and fc_reg since input channels are changed
        if self.with_cls:
            if self.custom_cls_channels:
                cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
            else:
                cls_channels = self.num_classes + 1
            self.fc_cls = build_linear_layer(
                self.cls_predictor_cfg,
                in_features=self.cls_last_dim,
                out_features=cls_channels)
        if self.with_reg:
            out_dim_reg = (4 if self.reg_class_agnostic else 4 *
                           self.num_classes)
            self.fc_reg = build_linear_layer(
                self.reg_predictor_cfg,
                in_features=self.reg_last_dim,
                out_features=out_dim_reg)

        if init_cfg is None:
            # When init_cfg is None, it has already been set to
            # [[dict(type='Normal', std=0.01, override=dict(name='fc_cls'))],
            #  [dict(type='Normal', std=0.001, override=dict(name='fc_reg'))]]
            # by `super(ConvFCBBoxHead, self).__init__()`, so we only need to
            # append additional configuration for `shared_fcs`, `cls_fcs`
            # and `reg_fcs`.
            self.init_cfg += [
                dict(
                    type='Xavier',
                    distribution='uniform',
                    override=[
                        dict(name='shared_fcs'),
                        dict(name='cls_fcs'),
                        dict(name='reg_fcs')
                    ])
            ]
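
    # Dimension bookkeeping happens in the helper below. As a worked example
    # (a sketch assuming the BBoxHead defaults in_channels=256,
    # roi_feat_size=7 and with_avg_pool=False): with num_shared_fcs=2 the
    # flattened RoI feature has 256 * 7 * 7 = 12544 channels, so shared_fcs
    # becomes Linear(12544, 1024) -> Linear(1024, 1024) and fc_cls / fc_reg
    # then receive 1024-d inputs.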
    def _add_conv_fc_branch(self,
                            num_branch_convs,
                            num_branch_fcs,
                            in_channels,
                            is_shared=False):
        """Add shared or separated branch.

        convs -> avg pool (optional) -> fcs
        """
        last_layer_dim = in_channels
        # add branch specific conv layers
        branch_convs = nn.ModuleList()
        if num_branch_convs > 0:
            for i in range(num_branch_convs):
                conv_in_channels = (
                    last_layer_dim if i == 0 else self.conv_out_channels)
                branch_convs.append(
                    ConvModule(
                        conv_in_channels,
                        self.conv_out_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))
            last_layer_dim = self.conv_out_channels
        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            # for shared branch, only consider self.with_avg_pool
            # for separated branches, also consider self.num_shared_fcs
            if (is_shared
                    or self.num_shared_fcs == 0) and not self.with_avg_pool:
                last_layer_dim *= self.roi_feat_area
            for i in range(num_branch_fcs):
                fc_in_channels = (
                    last_layer_dim if i == 0 else self.fc_out_channels)
                branch_fcs.append(
                    nn.Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels
        return branch_convs, branch_fcs, last_layer_dim

    def forward(self, x):
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)

            x = x.flatten(1)

            for fc in self.shared_fcs:
                x = self.relu(fc(x))
        # separate branches
        x_cls = x
        x_reg = x

        for conv in self.cls_convs:
            x_cls = conv(x_cls)
        if x_cls.dim() > 2:
            if self.with_avg_pool:
                x_cls = self.avg_pool(x_cls)
            x_cls = x_cls.flatten(1)
        for fc in self.cls_fcs:
            x_cls = self.relu(fc(x_cls))

        for conv in self.reg_convs:
            x_reg = conv(x_reg)
        if x_reg.dim() > 2:
            if self.with_avg_pool:
                x_reg = self.avg_pool(x_reg)
            x_reg = x_reg.flatten(1)
        for fc in self.reg_fcs:
            x_reg = self.relu(fc(x_reg))

        cls_score = self.fc_cls(x_cls) if self.with_cls else None
        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
        return cls_score, bbox_pred
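

# The two heads below are thin presets over ConvFCBBoxHead. Once registered,
# a head is usually selected by name from a config, e.g. (a sketch, assuming
# the standard mmdet config system):
#
#     bbox_head=dict(
#         type='Shared2FCBBoxHead',
#         in_channels=256,
#         fc_out_channels=1024,
#         num_classes=80)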
@HEADS.register_module()
class Shared2FCBBoxHead(ConvFCBBoxHead):

    def __init__(self, fc_out_channels=1024, *args, **kwargs):
        super(Shared2FCBBoxHead, self).__init__(
            num_shared_convs=0,
            num_shared_fcs=2,
            num_cls_convs=0,
            num_cls_fcs=0,
            num_reg_convs=0,
            num_reg_fcs=0,
            fc_out_channels=fc_out_channels,
            *args,
            **kwargs)


@HEADS.register_module()
class Shared4Conv1FCBBoxHead(ConvFCBBoxHead):

    def __init__(self, fc_out_channels=1024, *args, **kwargs):
        super(Shared4Conv1FCBBoxHead, self).__init__(
            num_shared_convs=4,
            num_shared_fcs=1,
            num_cls_convs=0,
            num_cls_fcs=0,
            num_reg_convs=0,
            num_reg_fcs=0,
            fc_out_channels=fc_out_channels,
            *args,
            **kwargs)
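

# A minimal smoke test, not part of the original module; it assumes the
# BBoxHead defaults (in_channels=256, roi_feat_size=7, num_classes=80)
# and an environment with torch, mmcv and mmdet installed.
if __name__ == '__main__':
    import torch

    head = Shared2FCBBoxHead()
    feats = torch.rand(8, 256, 7, 7)  # 8 pooled RoI feature maps
    cls_score, bbox_pred = head(feats)
    # 81 = num_classes + 1 (background); 320 = 4 * num_classes
    assert cls_score.shape == (8, 81)
    assert bbox_pred.shape == (8, 320)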