File size: 5,004 Bytes
c9019cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import constant_init, xavier_init
from mmcv.runner import BaseModule, ModuleList
from ..builder import NECKS, build_backbone
from .fpn import FPN
class ASPP(BaseModule):
    """Atrous Spatial Pyramid Pooling (ASPP).

    Implementation of the ASPP module used in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    One ``nn.Conv2d`` is created per entry of ``dilations``. Every branch
    except the last is applied to the input itself; the last branch is
    applied to the globally average-pooled input and its result is expanded
    to the size of the preceding branch. All branch outputs are passed
    through an in-place ReLU and concatenated along dim 1.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of channels produced by each branch, so
            the concatenated output has ``len(dilations) * out_channels``
            channels.
        dilations (tuple[int]): Dilations of the branches. The last entry
            must be 1. Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: dict(type='Kaiming', layer='Conv2d')
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilations=(1, 3, 6, 1),
                 init_cfg=dict(type='Kaiming', layer='Conv2d')):
        super().__init__(init_cfg)
        # The last branch is fed the pooled tensor in forward(); it is
        # required to be an undilated convolution.
        assert dilations[-1] == 1
        # Dilated branches use a 3x3 kernel padded by the dilation rate;
        # undilated branches use a plain 1x1 kernel without padding.
        self.aspp = nn.ModuleList([
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3 if rate > 1 else 1,
                stride=1,
                dilation=rate,
                padding=rate if rate > 1 else 0,
                bias=True) for rate in dilations
        ])
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results.

        Args:
            x (Tensor): Input feature map with ``in_channels`` channels.

        Returns:
            Tensor: Branch outputs concatenated along dim 1.
        """
        pooled = self.gap(x)
        last_idx = len(self.aspp) - 1
        branches = []
        for idx, conv in enumerate(self.aspp):
            # Only the final branch sees the globally pooled features.
            source = pooled if idx == last_idx else x
            branches.append(F.relu_(conv(source)))
        # Broadcast the pooled branch to the size of the preceding branch so
        # that it can be concatenated with the others.
        branches[-1] = branches[-1].expand_as(branches[-2])
        return torch.cat(branches, dim=1)
@NECKS.register_module()
class RFP(FPN):
    """Recursive Feature Pyramid (RFP).

    Implementation of RFP from `DetectoRS
    <https://arxiv.org/pdf/2006.02334.pdf>`_. Unlike a standard FPN, the
    input of RFP is the original backbone input image followed by the
    multi-level features.

    Args:
        rfp_steps (int): Number of unrolled steps of RFP.
        rfp_backbone (dict): Configuration of the backbone for RFP. One
            backbone is built from it for each step after the first.
        aspp_out_channels (int): Number of output channels of ASPP module.
        aspp_dilations (tuple[int]): Dilation rates of four branches.
            Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be left as None. Default: None
        **kwargs: Remaining keyword arguments, forwarded to ``FPN``.
    """

    def __init__(self,
                 rfp_steps,
                 rfp_backbone,
                 aspp_out_channels,
                 aspp_dilations=(1, 3, 6, 1),
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, (
            'To prevent abnormal initialization '
            'behavior, init_cfg is not allowed to be set')
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.rfp_steps = rfp_steps
        # Be careful: the mmcv ModuleList is used on purpose, because
        # pretrained weights cannot be loaded when nn.ModuleList is used.
        self.rfp_modules = ModuleList()
        # The first step reuses the features passed to forward(), so only
        # ``rfp_steps - 1`` extra backbones are built.
        for _ in range(rfp_steps - 1):
            self.rfp_modules.append(build_backbone(rfp_backbone))
        self.rfp_aspp = ASPP(
            self.out_channels, aspp_out_channels, aspp_dilations)
        # 1x1 conv producing a single-channel map that gates the fusion of
        # consecutive steps in forward().
        self.rfp_weight = nn.Conv2d(
            self.out_channels, 1, kernel_size=1, stride=1, padding=0,
            bias=True)

    def init_weights(self):
        """Initialize the FPN convs, the RFP backbones and the fusion conv."""
        # super().init_weights() is deliberately not used: it may alter the
        # default initialization of the modules in self.rfp_modules that
        # have missing keys in the pretrained checkpoint.
        for conv_group in (self.lateral_convs, self.fpn_convs):
            for module in conv_group.modules():
                if isinstance(module, nn.Conv2d):
                    xavier_init(module, distribution='uniform')
        for step in range(self.rfp_steps - 1):
            self.rfp_modules[step].init_weights()
        constant_init(self.rfp_weight, 0)

    def forward(self, inputs):
        """Run the FPN once and refine it with the unrolled RFP steps.

        Args:
            inputs (Sequence): The input image followed by
                ``len(self.in_channels)`` feature maps.

        Returns:
            The plain FPN output when no extra step is run
            (``rfp_steps <= 1``), otherwise a list with the fused feature
            maps of the last step.
        """
        inputs = list(inputs)
        assert len(inputs) == len(self.in_channels) + 1  # +1 for input image
        img, feats = inputs[0], tuple(inputs[1:])
        # Initial FPN forward
        x = super().forward(feats)
        for step in range(self.rfp_steps - 1):
            # The first level is fed back as is; every other level goes
            # through the shared ASPP module first.
            rfp_feats = [x[0]] + [self.rfp_aspp(feat) for feat in x[1:]]
            x_idx = self.rfp_modules[step].rfp_forward(img, rfp_feats)
            # FPN forward on the features of the recursive backbone
            x_idx = super().forward(x_idx)
            fused = []
            for level, refined in enumerate(x_idx):
                # Per-position gate in (0, 1) blending the new features
                # with those of the previous step.
                gate = torch.sigmoid(self.rfp_weight(refined))
                fused.append(gate * refined + (1 - gate) * x[level])
            x = fused
        return x
|