import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.utils.checkpoint import checkpoint

from ..builder import NECKS

@NECKS.register_module()
class HRFPN(BaseModule):
    """HRFPN (High Resolution Feature Pyramids).

    paper: `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    Args:
        in_channels (list): Number of channels for each input branch.
        out_channels (int): Output channels of the feature pyramids.
        num_outs (int): Number of output stages.
        pooling_type (str): Pooling used to generate the feature pyramids,
            either 'MAX' or 'AVG'.
        conv_cfg (dict): Dictionary to construct and config conv layer.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
        stride (int): Stride of the 3x3 convolutional layers.
        init_cfg (dict or list[dict], optional): Initialization config dict.
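
    Example:
        A minimal usage sketch; the channel counts and spatial sizes below
        are illustrative (HRNet-W18-like), not taken from any real config.

        >>> import torch
        >>> in_channels = [18, 36, 72, 144]
        >>> self = HRFPN(in_channels, out_channels=256, num_outs=5)
        >>> inputs = [
        ...     torch.rand(1, c, 64 // 2**i, 64 // 2**i)
        ...     for i, c in enumerate(in_channels)
        ... ]
        >>> outputs = self.forward(inputs)
        >>> for out in outputs:
        ...     print(tuple(out.shape))
        (1, 256, 64, 64)
        (1, 256, 32, 32)
        (1, 256, 16, 16)
        (1, 256, 8, 8)
        (1, 256, 4, 4)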
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs=5,
                 pooling_type='AVG',
                 conv_cfg=None,
                 norm_cfg=None,
                 with_cp=False,
                 stride=1,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super(HRFPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # Fuse the concatenated branches into ``out_channels`` maps with a
        # 1x1 convolution.
        self.reduction_conv = ConvModule(
            sum(in_channels),
            out_channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            act_cfg=None)

        # One 3x3 convolution per output level to refine the pooled features.
        self.fpn_convs = nn.ModuleList()
        for i in range(self.num_outs):
            self.fpn_convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    act_cfg=None))

        if pooling_type == 'MAX':
            self.pooling = F.max_pool2d
        else:
            self.pooling = F.avg_pool2d

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == self.num_ins
        # Upsample every branch to the resolution of the highest-resolution
        # branch, then concatenate along the channel dimension.
        outs = [inputs[0]]
        for i in range(1, self.num_ins):
            outs.append(
                F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
        out = torch.cat(outs, dim=1)
        if out.requires_grad and self.with_cp:
            out = checkpoint(self.reduction_conv, out)
        else:
            out = self.reduction_conv(out)
        # Build the pyramid by pooling the fused map at increasing strides.
        outs = [out]
        for i in range(1, self.num_outs):
            outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
        outputs = []
        for i in range(self.num_outs):
            if outs[i].requires_grad and self.with_cp:
                tmp_out = checkpoint(self.fpn_convs[i], outs[i])
            else:
                tmp_out = self.fpn_convs[i](outs[i])
            outputs.append(tmp_out)
        return tuple(outputs)
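# A minimal config-style sketch of how this neck is typically instantiated
# through mmdet's ``build_neck`` (the channel values are illustrative,
# matching an HRNet-W18 backbone, not a requirement of this module):
#
#     neck=dict(
#         type='HRFPN',
#         in_channels=[18, 36, 72, 144],
#         out_channels=256,
#         num_outs=5)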