3v324v23's picture
Add files
c9019cd
raw
history blame
3.93 kB
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from ..builder import NECKS
@NECKS.register_module()
class ChannelMapper(BaseModule):
    r"""Channel Mapper to reduce/increase channels of backbone features.

    Each input feature map is passed through its own ``ConvModule`` so that
    every scale ends up with ``out_channels`` channels. When ``num_outs`` is
    larger than the number of inputs, additional output levels are produced
    by a chain of 3x3, stride-2 ``ConvModule`` layers: the first one consumes
    the last *input* feature map, each following one consumes the output of
    the previous extra layer.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int, optional): kernel_size for reducing channels (used
            at each scale). Default: 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        num_outs (int, optional): Number of output feature maps. There
            would be extra_convs when num_outs larger than the length
            of in_channels. If None, it is set to ``len(in_channels)``.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: dict(type='Xavier', layer='Conv2d',
            distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = ChannelMapper(in_channels, 11, 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 num_outs=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(ChannelMapper, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        # Stays None unless more outputs than inputs are requested.
        self.extra_convs = None
        if num_outs is None:
            num_outs = len(in_channels)

        # One channel-mapping conv per input scale, built in input order so
        # that parameter names and initialization order are unchanged.
        self.convs = nn.ModuleList([
            ConvModule(
                in_channel,
                out_channels,
                kernel_size,
                padding=(kernel_size - 1) // 2,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for in_channel in in_channels
        ])

        num_extra = num_outs - len(in_channels)
        if num_extra > 0:
            self.extra_convs = nn.ModuleList()
            # The first extra conv reads the last raw input feature map;
            # every later one reads the previous extra conv's output.
            extra_in_channel = in_channels[-1]
            for _ in range(num_extra):
                self.extra_convs.append(
                    ConvModule(
                        extra_in_channel,
                        out_channels,
                        3,
                        stride=2,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
                extra_in_channel = out_channels

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence): One feature map per entry of ``in_channels``,
                in the same order. Its length must equal the number of
                per-scale convs.

        Returns:
            tuple: The channel-mapped feature maps, followed by the outputs
            of the extra convs (if any).
        """
        assert len(inputs) == len(self.convs)
        outs = [conv(feat) for conv, feat in zip(self.convs, inputs)]
        if self.extra_convs:
            # Mirror the construction: start the chain from the last raw
            # input, then feed each extra output into the next extra conv.
            extra_feat = inputs[-1]
            for extra_conv in self.extra_convs:
                extra_feat = extra_conv(extra_feat)
                outs.append(extra_feat)
        return tuple(outs)