import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops.merge_cells import GlobalPoolingCell, SumCell
from mmcv.runner import BaseModule, ModuleList

from ..builder import NECKS


@NECKS.register_module()
class NASFPN(BaseModule):
    """NAS-FPN.

    Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture
    for Object Detection <https://arxiv.org/abs/1904.07392>`_.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): The number of times the pyramid architecture will
            be stacked.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last
            level.
        add_extra_convs (bool): Whether to add conv layers on top of the
            original feature maps. Default: False.
        norm_cfg (dict, optional): Config dict for the normalization layers.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
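
    Example:
        A minimal usage sketch. The channel counts and input sizes below are
        assumed, matching a typical ResNet C3-C5 backbone; only the output
        shapes follow from the module itself.

        >>> import torch
        >>> in_channels = [512, 1024, 2048]
        >>> inputs = [
        ...     torch.rand(1, c, 32 // 2**i, 32 // 2**i)
        ...     for i, c in enumerate(in_channels)
        ... ]
        >>> self = NASFPN(in_channels, out_channels=256, num_outs=5,
        ...               stack_times=3)
        >>> outputs = self.forward(inputs)
        >>> for out in outputs:
        ...     print(tuple(out.shape))
        (1, 256, 32, 32)
        (1, 256, 16, 16)
        (1, 256, 8, 8)
        (1, 256, 4, 4)
        (1, 256, 2, 2)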
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super(NASFPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.stack_times = stack_times
        self.norm_cfg = norm_cfg

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is given explicitly, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # add lateral 1x1 convs that project each input scale to out_channels
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None)
            self.lateral_convs.append(l_conv)

        # add extra downsample layers (a 1x1 conv followed by stride-2
        # max pooling) until num_outs scales are available
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            extra_conv = ConvModule(
                out_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
            self.extra_downsamples.append(
                nn.Sequential(extra_conv, nn.MaxPool2d(2, 2)))
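
        # add NAS-FPN connections. Each cell name encodes its wiring:
        # 'gp_64_4' fuses the p6 and p4 features and emits the result at the
        # p4 resolution, 'sum_44_4' sums two p4-level inputs, and so on.
        # A SumCell adds its two inputs after resizing them to out_size,
        # while a GlobalPoolingCell weights one input by a global-pooling
        # attention map before adding (see mmcv.ops.merge_cells).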
        self.fpn_stages = ModuleList()
        for _ in range(self.stack_times):
            stage = nn.ModuleDict()
            # gp(p6, p4) -> p4_1
            stage['gp_64_4'] = GlobalPoolingCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p4_1, p4) -> p4_2
            stage['sum_44_4'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p4_2, p3) -> p3_out
            stage['sum_43_3'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p3_out, p4_2) -> p4_out
            stage['sum_34_4'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p5, gp(p4_out, p3_out)) -> p5_out
            stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False)
            stage['sum_55_5'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # sum(p7, gp(p5_out, p4_2)) -> p7_out
            stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False)
            stage['sum_77_7'] = SumCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            # gp(p7_out, p5_out) -> p6_out
            stage['gp_75_6'] = GlobalPoolingCell(
                in_channels=out_channels,
                out_channels=out_channels,
                out_norm_cfg=norm_cfg)
            self.fpn_stages.append(stage)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): Multi-level features from the backbone.

        Returns:
            tuple[Tensor]: Output feature maps, ordered from the highest
                resolution (p3) to the lowest (p7).
        """
        # build P3-P5 with the lateral 1x1 convs
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build P6-P7 on top of P5
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        p3, p4, p5, p6, p7 = feats

        for stage in self.fpn_stages:
            # gp(p6, p4) -> p4_1
            p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
            # sum(p4_1, p4) -> p4_2
            p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
            # sum(p4_2, p3) -> p3_out
            p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
            # sum(p3_out, p4_2) -> p4_out
            p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
            # sum(p5, gp(p4_out, p3_out)) -> p5_out
            p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
            p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
            # sum(p7, gp(p5_out, p4_2)) -> p7_out
            p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
            p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
            # gp(p7_out, p5_out) -> p6_out
            p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])

        return p3, p4, p5, p6, p7