Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from mmcv.runner import BaseModule | |
from ..builder import NECKS | |
class Transition(BaseModule):
    """Base class for transitions between feature levels in FPG.

    Subclasses build their layers in ``__init__`` and override
    :meth:`forward`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        init_cfg (dict or list[dict], optional): Initialization config dict,
            forwarded to ``BaseModule``. Default: None.
    """

    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        """Transform ``x``; must be implemented by subclasses."""
        raise NotImplementedError


class UpInterpolationConv(Transition):
    """A transition used for up-sampling.

    Up-samples the input by interpolation, then refines the feature with a
    convolution layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Up-sampling factor. Default: 2.
        mode (str): Interpolation mode passed to ``F.interpolate``.
            Default: 'nearest'.
        align_corners (bool, optional): Whether to align corners when
            interpolating. Default: None.
        kernel_size (int): Kernel size for the conv. Default: 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to ``ConvModule``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=2,
                 mode='nearest',
                 align_corners=None,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        # (kernel_size - 1) // 2 padding keeps the spatial size for odd
        # kernels, assuming stride 1 (not overridden through kwargs).
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, x):
        """Interpolate ``x`` by ``scale_factor``, then apply the conv."""
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)
        x = self.conv(x)
        return x


class LastConv(Transition):
    """Transition that refines the output of the last stage.

    Only the final element of the sequence given to :meth:`forward` is used;
    it is passed through a single ``ConvModule``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_inputs (int): Number of inputs of the FPN features, i.e. the
            expected length of the sequence given to ``forward``.
        kernel_size (int): Kernel size for the conv. Default: 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to ``ConvModule``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_inputs,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.num_inputs = num_inputs
        half_kernel = (kernel_size - 1) // 2
        self.conv_out = ConvModule(
            in_channels, out_channels, kernel_size, padding=half_kernel,
            **kwargs)

    def forward(self, inputs):
        """Check the number of inputs, then convolve the last one."""
        n_given = len(inputs)
        assert n_given == self.num_inputs
        last_feat = inputs[-1]
        return self.conv_out(last_feat)


# NOTE(review): ``NECKS`` is imported at the top of this file but is not used
# here; upstream MMDetection presumably registers this class with
# ``@NECKS.register_module()``. Confirm FPG is registered elsewhere before
# relying on ``type='FPG'`` in configs.
class FPG(BaseModule):
    """FPG.

    Implementation of `Feature Pyramid Grids (FPG)
    <https://arxiv.org/abs/2004.03580>`_.
    This implementation only gives the basic structure stated in the paper.
    But users can implement different type of transitions to fully explore
    the potential power of the structure of FPG.

    Every transition config below is a dict whose ``type`` key selects a
    class from ``transition_types``; the remaining items are passed to that
    class as keyword arguments.

    Args:
        in_channels (list[int]): Number of input channels of each backbone
            level.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): The number of times the pyramid architecture will
            be stacked.
        paths (list[str]): Specify the path order of each stack level.
            Each element in the list should be either 'bu' (bottom-up) or
            'td' (top-down).
        inter_channels (int | list[int], optional): Number of intermediate
            channels, shared by all levels (int) or given per level (list of
            length ``num_outs``). Default: None, meaning ``out_channels``.
        same_down_trans (dict, optional): Same-stage transition from level
            ``i + 1`` to level ``i``; used in top-down ('td') paths.
        same_up_trans (dict, optional): Same-stage transition from level
            ``i - 1`` to level ``i``; used in bottom-up ('bu') paths.
        across_lateral_trans (dict, optional): Across-stage transition from
            the same level of the previous stage.
        across_down_trans (dict, optional): Across-stage transition from
            level ``i + 1`` of the previous stage.
        across_up_trans (dict, optional): Across-stage transition from
            level ``i - 1`` of the previous stage.
        across_skip_trans (dict, optional): Across-stage skip connection
            from the same level of the initial (stage-0) features.
        output_trans (dict): Transition that turns the features of one level
            from all stages into the final output of that level.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (inclusive)
            used to build the feature pyramid. Default: -1, which means the
            last level.
        add_extra_convs (bool): How levels beyond the used backbone levels
            are produced: if True, by stride-2 3x3 convs; otherwise by
            stride-2 max pooling. Default: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Only used here to set ``self.with_bias``. Default: None.
        skip_inds (list[list[int]], optional): For each output level, the
            stack indices at which that level is passed through unchanged.
            Must be given when ``across_skip_trans`` is not None. Default:
            None, which (only when ``across_skip_trans`` is None) means that
            no level is skipped.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    # ``build_trans`` calls ``cls(in_channels, out_channels, **cfg)``.
    # 'identity' makes the default ``across_skip_trans=dict(type='identity')``
    # buildable; ``nn.Identity`` ignores its constructor arguments.
    transition_types = {
        'conv': ConvModule,
        'interpolation_conv': UpInterpolationConv,
        'last_conv': LastConv,
        'identity': nn.Identity,
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 paths,
                 inter_channels=None,
                 same_down_trans=None,
                 same_up_trans=dict(
                     type='conv', kernel_size=3, stride=2, padding=1),
                 across_lateral_trans=dict(type='conv', kernel_size=1),
                 across_down_trans=dict(type='conv', kernel_size=3),
                 across_up_trans=None,
                 across_skip_trans=dict(type='identity'),
                 output_trans=dict(type='last_conv', kernel_size=3),
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 skip_inds=None,
                 init_cfg=[
                     dict(type='Caffe2Xavier', layer='Conv2d'),
                     dict(
                         type='Constant',
                         layer=[
                             '_BatchNorm', '_InstanceNorm', 'GroupNorm',
                             'LayerNorm'
                         ],
                         val=1.0)
                 ]):
        super(FPG, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        if inter_channels is None:
            self.inter_channels = [out_channels for _ in range(num_outs)]
        elif isinstance(inter_channels, int):
            self.inter_channels = [inter_channels for _ in range(num_outs)]
        else:
            assert isinstance(inter_channels, list)
            assert len(inter_channels) == num_outs
            self.inter_channels = inter_channels
        self.stack_times = stack_times
        self.paths = paths
        assert isinstance(paths, list) and len(paths) == stack_times
        for d in paths:
            assert d in ('bu', 'td')

        self.same_down_trans = same_down_trans
        self.same_up_trans = same_up_trans
        self.across_lateral_trans = across_lateral_trans
        self.across_down_trans = across_down_trans
        self.across_up_trans = across_up_trans
        self.output_trans = output_trans
        self.across_skip_trans = across_skip_trans

        self.with_bias = norm_cfg is None

        # skip_inds must be specified whenever an across-skip transition is
        # configured. Without one, a missing skip_inds means that no level is
        # skipped at any stack.
        if skip_inds is None:
            assert self.across_skip_trans is None, (
                'skip_inds must be specified if across_skip_trans is not None')
            skip_inds = [() for _ in range(num_outs)]
        self.skip_inds = skip_inds
        assert len(self.skip_inds[0]) <= self.stack_times

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # build lateral 1x1 convs to reduce channels
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = nn.Conv2d(self.in_channels[i],
                               self.inter_channels[i - self.start_level], 1)
            self.lateral_convs.append(l_conv)

        # Levels beyond the used backbone levels, each derived from the
        # previous level with a stride of 2.
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            if self.add_extra_convs:
                fpn_idx = self.backbone_end_level - self.start_level + i
                extra_conv = nn.Conv2d(
                    self.inter_channels[fpn_idx - 1],
                    self.inter_channels[fpn_idx],
                    3,
                    stride=2,
                    padding=1)
                self.extra_downsamples.append(extra_conv)
            else:
                self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))

        self.fpn_transitions = nn.ModuleList()  # one entry per stack
        for s in range(self.stack_times):
            stage_trans = nn.ModuleList()  # one entry per feature level
            for i in range(self.num_outs):
                if s in self.skip_inds[i]:
                    # Level i is passed through unchanged at stack s, so it
                    # needs no transitions.
                    stage_trans.append(nn.ModuleDict())
                else:
                    stage_trans.append(self._build_level_trans(i))
            self.fpn_transitions.append(stage_trans)

        self.output_transition = nn.ModuleList()  # output levels
        for i in range(self.num_outs):
            trans = self.build_trans(
                self.output_trans,
                self.inter_channels[i],
                self.out_channels,
                num_inputs=self.stack_times + 1)
            self.output_transition.append(trans)

        self.relu = nn.ReLU(inplace=True)

    def _build_level_trans(self, i):
        """Build the transitions that produce level ``i`` in one stack.

        Returns:
            nn.ModuleDict: Keys 'same_up', 'same_down', 'across_lateral',
            'across_down', 'across_up' and 'across_skip' (created in this
            order). A value is None when the transition is not configured or
            its source level does not exist (level ``i - 1`` for ``i == 0``,
            level ``i + 1`` for the last level).
        """
        chs = self.inter_channels
        is_first = i == 0
        is_last = i == self.num_outs - 1
        trans = nn.ModuleDict()

        # Same stage, level i - 1 -> i (used in bottom-up paths).
        if is_first or self.same_up_trans is None:
            trans['same_up'] = None
        else:
            trans['same_up'] = self.build_trans(self.same_up_trans,
                                                chs[i - 1], chs[i])

        # Same stage, level i + 1 -> i (used in top-down paths).
        if is_last or self.same_down_trans is None:
            trans['same_down'] = None
        else:
            trans['same_down'] = self.build_trans(self.same_down_trans,
                                                  chs[i + 1], chs[i])

        # Previous stage, same level.
        if self.across_lateral_trans is None:
            trans['across_lateral'] = None
        else:
            trans['across_lateral'] = self.build_trans(
                self.across_lateral_trans, chs[i], chs[i])

        # Previous stage, level i + 1.
        if is_last or self.across_down_trans is None:
            trans['across_down'] = None
        else:
            trans['across_down'] = self.build_trans(self.across_down_trans,
                                                    chs[i + 1], chs[i])

        # Previous stage, level i - 1.
        if is_first or self.across_up_trans is None:
            trans['across_up'] = None
        else:
            trans['across_up'] = self.build_trans(self.across_up_trans,
                                                  chs[i - 1], chs[i])

        # Stage-0 features of the same level, so input and output both have
        # chs[i] channels.
        if self.across_skip_trans is None:
            trans['across_skip'] = None
        else:
            trans['across_skip'] = self.build_trans(self.across_skip_trans,
                                                    chs[i], chs[i])
        return trans

    def build_trans(self, cfg, in_channels, out_channels, **extra_args):
        """Instantiate a transition from its config dict.

        Args:
            cfg (dict): Transition config; ``cfg['type']`` selects the class
                from ``transition_types`` and the remaining items are passed
                as keyword arguments. ``cfg`` itself is not modified.
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            **extra_args: Extra keyword arguments for the transition class.

        Returns:
            nn.Module: The built transition.
        """
        cfg_ = cfg.copy()
        trans_type = cfg_.pop('type')
        trans_cls = self.transition_types[trans_type]
        return trans_cls(in_channels, out_channels, **cfg_, **extra_args)

    def fuse(self, fuse_dict):
        """Sum the non-None values of ``fuse_dict`` in insertion order.

        Returns None if every value is None.
        """
        out = None
        for item in fuse_dict.values():
            if item is None:
                continue
            out = item if out is None else out + item
        return out

    def forward(self, inputs):
        """Run the feature pyramid grid.

        Args:
            inputs (list | tuple): Backbone feature maps, one per entry of
                ``in_channels``.

        Returns:
            list: ``num_outs`` outputs, one per level, each produced by
            ``output_transition`` from that level's features of every
            stage (stage 0 included).
        """
        assert len(inputs) == len(self.in_channels)

        # Stage 0: project the used backbone levels to ``inter_channels``
        # and append the extra levels, each derived from the previous one.
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        outs = [feats]
        for i in range(self.stack_times):
            current_outs = outs[-1]
            next_outs = []
            direction = self.paths[i]
            for j in range(self.num_outs):
                # 'td' visits levels from the last index down to 0 and 'bu'
                # from 0 up, so ``next_outs[-1]`` is the neighbouring level
                # already computed in this stage.
                lvl = self.num_outs - j - 1 if direction == 'td' else j
                # Skipping is keyed by level, matching how
                # ``fpn_transitions`` was built in ``__init__``.
                if i in self.skip_inds[lvl]:
                    next_outs.append(current_outs[lvl])
                    continue

                level_trans = self.fpn_transitions[i][lvl]
                if direction == 'td':
                    same_trans = level_trans['same_down']
                else:
                    same_trans = level_trans['same_up']
                across_lateral_trans = level_trans['across_lateral']
                across_down_trans = level_trans['across_down']
                across_up_trans = level_trans['across_up']
                across_skip_trans = level_trans['across_skip']

                # ``fuse`` sums the terms in this insertion order.
                to_fuse = dict(
                    same=None, lateral=None, across_up=None, across_down=None)
                # Neighbouring level of the current stage.
                if same_trans is not None:
                    to_fuse['same'] = same_trans(next_outs[-1])
                # Same level of the previous stage.
                if across_lateral_trans is not None:
                    to_fuse['lateral'] = across_lateral_trans(
                        current_outs[lvl])
                # Level ``lvl - 1`` of the previous stage.
                if lvl > 0 and across_up_trans is not None:
                    to_fuse['across_up'] = across_up_trans(
                        current_outs[lvl - 1])
                # Level ``lvl + 1`` of the previous stage.
                if (lvl < self.num_outs - 1
                        and across_down_trans is not None):
                    to_fuse['across_down'] = across_down_trans(
                        current_outs[lvl + 1])
                # Same level of the stage-0 features.
                if across_skip_trans is not None:
                    to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
                next_outs.append(self.fuse(to_fuse))

            # Store this stage's outputs ordered by level index.
            if direction == 'td':
                outs.append(next_outs[::-1])
            else:
                outs.append(next_outs)

        # Each level's output transition sees that level from every stage.
        final_outs = []
        for i in range(self.num_outs):
            lvl_out_list = [stage_outs[i] for stage_outs in outs]
            final_outs.append(self.output_transition[i](lvl_out_list))
        return final_outs