Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn | |
from mmcv.cnn import ConvModule | |
from mmcv.runner import BaseModule, ModuleList | |
from mmdet.models.backbones.resnet import Bottleneck | |
from mmdet.models.builder import HEADS | |
from .bbox_head import BBoxHead | |
class BasicResBlock(BaseModule):
    """Two-convolution residual block with a projected shortcut.

    The block differs slightly from the ResNet ``BasicBlock``: the first
    convolution (``conv1``) is 3x3 and keeps the channel count, the second
    one (``conv2``) uses kernel size 1 (3 in ResNet ``BasicBlock``) and maps
    to ``out_channels``. The shortcut is always passed through a 1x1
    convolution (``conv_identity``) so that it matches ``out_channels``.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels of the output feature map.
        conv_cfg (dict): The config dict for convolution layers.
            Default: None
        norm_cfg (dict): The config dict for normalization layers.
            Default: dict(type='BN')
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super().__init__(init_cfg)

        # Layer configs shared by every ConvModule of the block.
        layer_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        # Main path: 3x3 conv (same channels) followed by a 1x1 conv that
        # changes the channel count. conv2 is built with act_cfg=None; the
        # block's own ReLU is applied after the residual sum in forward().
        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            kernel_size=3,
            padding=1,
            bias=False,
            **layer_cfg)
        self.conv2 = ConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            act_cfg=None,
            **layer_cfg)

        # Shortcut path: 1x1 projection to out_channels, also built with
        # act_cfg=None.
        self.conv_identity = ConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            act_cfg=None,
            **layer_cfg)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)) + conv_identity(x))``."""
        main_path = self.conv2(self.conv1(x))
        shortcut = self.conv_identity(x)
        return self.relu(main_path + shortcut)
@HEADS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):
    r"""Bbox head used in Double-Head R-CNN.

    .. code-block:: none

                                          /-> cls
                      /-> shared convs ->
                                          \-> reg
        roi features
                                          /-> cls
                      \-> shared fc    ->
                                          \-> reg

    In this implementation the conv branch only feeds the regression layer
    (``fc_reg``) and the fc branch only feeds the classification layer
    (``fc_cls``); see :meth:`forward`.

    Args:
        num_convs (int): Number of ``Bottleneck`` blocks in the conv branch.
            Must be greater than 0. Default: 0
        num_fcs (int): Number of linear layers in the fc branch.
            Must be greater than 0. Default: 0
        conv_out_channels (int): Channels produced by the residual block and
            kept by every block of the conv branch. Default: 1024
        fc_out_channels (int): Output features of every linear layer of the
            fc branch. Default: 1024
        conv_cfg (dict): The config dict for the convolution layers of the
            conv branch. Default: None
        norm_cfg (dict): The config dict for the normalization layers of the
            conv branch. Default: dict(type='BN')
        init_cfg (dict or list[dict], optional): Initialization config dict.
        **kwargs: Forwarded to ``BBoxHead``. ``with_avg_pool`` defaults to
            True here and must stay True.
    """  # noqa: W605

    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 conv_out_channels=1024,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=dict(
                     type='Normal',
                     override=[
                         dict(type='Normal', name='fc_cls', std=0.01),
                         dict(type='Normal', name='fc_reg', std=0.001),
                         dict(
                             type='Xavier',
                             name='fc_branch',
                             distribution='uniform')
                     ]),
                 **kwargs):
        # The conv branch output is average-pooled before fc_reg, so the
        # parent head must be built with average pooling enabled.
        kwargs.setdefault('with_avg_pool', True)
        super(DoubleConvFCBBoxHead, self).__init__(init_cfg=init_cfg, **kwargs)
        assert self.with_avg_pool, \
            'DoubleConvFCBBoxHead requires with_avg_pool=True'
        assert num_convs > 0, 'num_convs must be greater than 0'
        assert num_fcs > 0, 'num_fcs must be greater than 0'
        self.num_convs = num_convs
        self.num_fcs = num_fcs
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # increase the channel of input features
        # NOTE(review): the residual block is built with BasicResBlock's
        # default conv/norm config rather than self.conv_cfg/self.norm_cfg;
        # confirm this is intended before changing it (it would alter the
        # module structure and therefore checkpoint compatibility).
        self.res_block = BasicResBlock(self.in_channels,
                                       self.conv_out_channels)

        # add conv heads
        self.conv_branch = self._add_conv_branch()
        # add fc heads
        self.fc_branch = self._add_fc_branch()

        # Regression consumes the conv branch, classification the fc branch.
        out_dim_reg = 4 if self.reg_class_agnostic else 4 * self.num_classes
        self.fc_reg = nn.Linear(self.conv_out_channels, out_dim_reg)

        # One extra output besides the num_classes classes.
        self.fc_cls = nn.Linear(self.fc_out_channels, self.num_classes + 1)
        self.relu = nn.ReLU(inplace=True)

    def _add_conv_branch(self):
        """Add the conv branch which consists of a sequential of conv layers.

        Returns:
            ModuleList: ``num_convs`` ``Bottleneck`` blocks, each taking
            ``conv_out_channels`` input channels.
        """
        branch_convs = ModuleList()
        for _ in range(self.num_convs):
            branch_convs.append(
                Bottleneck(
                    inplanes=self.conv_out_channels,
                    planes=self.conv_out_channels // 4,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        return branch_convs

    def _add_fc_branch(self):
        """Add the fc branch which consists of a sequential of fc layers.

        Returns:
            ModuleList: ``num_fcs`` linear layers. The first one takes the
            flattened RoI feature (``in_channels * roi_feat_area``), the
            following ones take ``fc_out_channels`` features.
        """
        branch_fcs = ModuleList()
        fc_in_channels = self.in_channels * self.roi_feat_area
        for _ in range(self.num_fcs):
            branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))
            # Every layer after the first consumes the previous fc output.
            fc_in_channels = self.fc_out_channels
        return branch_fcs

    def forward(self, x_cls, x_reg):
        """Compute classification scores and bbox predictions.

        Args:
            x_cls (Tensor): RoI features for the fc (classification) branch.
                They are flattened per sample, so each sample must hold
                ``in_channels * roi_feat_area`` values.
            x_reg (Tensor): RoI features for the conv (regression) branch.

        Returns:
            tuple[Tensor, Tensor]: ``(cls_score, bbox_pred)``, the outputs
            of ``fc_cls`` and ``fc_reg`` respectively.
        """
        # conv head
        x_conv = self.res_block(x_reg)

        for conv in self.conv_branch:
            x_conv = conv(x_conv)

        # with_avg_pool is asserted in __init__; avg_pool comes from the
        # parent class and is presumably what collapses the spatial dims
        # so the flattened feature fits fc_reg's conv_out_channels inputs.
        if self.with_avg_pool:
            x_conv = self.avg_pool(x_conv)

        x_conv = x_conv.view(x_conv.size(0), -1)
        bbox_pred = self.fc_reg(x_conv)

        # fc head
        x_fc = x_cls.view(x_cls.size(0), -1)
        for fc in self.fc_branch:
            x_fc = self.relu(fc(x_fc))

        cls_score = self.fc_cls(x_fc)

        return cls_score, bbox_pred