# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import NECKS


@NECKS.register_module()
class ChannelMapper(BaseModule):
    r"""Neck that maps every backbone feature level to a common channel count.

    One ``ConvModule`` is created per entry of ``in_channels``; each maps its
    level to ``out_channels``. When ``num_outs`` exceeds the number of input
    levels, additional 3x3 stride-2 ``ConvModule`` layers are appended to
    produce the extra output levels.

    Args:
        in_channels (List[int]): Number of input channels per scale. Must be
            a ``list`` (checked with an assertion).
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int, optional): Kernel size of the per-scale mapping
            convolutions; padding is ``(kernel_size - 1) // 2``. Default: 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        num_outs (int, optional): Number of output feature maps. Extra convs
            are built when it is larger than ``len(in_channels)``.
            Default: None, which means ``len(in_channels)``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: dict(type='Xavier', layer='Conv2d',
            distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = ChannelMapper(in_channels, 11, 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 num_outs=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg)
        assert isinstance(in_channels, list)

        num_ins = len(in_channels)
        if num_outs is None:
            num_outs = num_ins

        # Layer options shared by every ConvModule created below.
        shared_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Stays None unless more outputs than inputs are requested.
        self.extra_convs = None

        # One channel-mapping conv per input level, in input order.
        same_padding = (kernel_size - 1) // 2
        self.convs = nn.ModuleList([
            ConvModule(
                channels,
                out_channels,
                kernel_size,
                padding=same_padding,
                **shared_cfg) for channels in in_channels
        ])

        if num_outs > num_ins:
            self.extra_convs = nn.ModuleList()
            for level in range(num_ins, num_outs):
                # The first extra conv is fed the last raw input in
                # ``forward``; later ones are fed the previous extra output.
                if level == num_ins:
                    src_channels = in_channels[-1]
                else:
                    src_channels = out_channels
                self.extra_convs.append(
                    ConvModule(
                        src_channels,
                        out_channels,
                        3,
                        stride=2,
                        padding=1,
                        **shared_cfg))

    def forward(self, inputs):
        """Map each input level and append any extra output levels.

        Args:
            inputs: Indexable sequence of feature maps, one per entry in
                ``self.convs`` (the length is asserted).

        Returns:
            tuple: The mapped feature maps in input order, followed by the
            outputs of ``self.extra_convs`` (if any).
        """
        assert len(inputs) == len(self.convs)
        outs = [self.convs[idx](feat) for idx, feat in enumerate(inputs)]

        if self.extra_convs:
            # The chain starts from the last *unmapped* input; every later
            # extra conv consumes the output of the one before it.
            feat = inputs[-1]
            for extra_conv in self.extra_convs:
                feat = extra_conv(feat)
                outs.append(feat)

        return tuple(outs)