Spaces:
Runtime error
Runtime error
File size: 3,777 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import NonLocal2d
from mmcv.runner import BaseModule
from ..builder import NECKS
@NECKS.register_module()
class BFP(BaseModule):
    """Balanced Feature Pyramids (BFP) neck.

    The module works in three stages: every input level is resized to the
    resolution of one chosen level and the results are averaged into a single
    map; that map is optionally refined; finally the map is resized back to
    each level's own resolution and added to the corresponding input as a
    residual. It is used in Libra R-CNN (CVPR 2019), see the paper
    `Libra R-CNN: Towards Balanced Learning for Object Detection
    <https://arxiv.org/abs/1904.02701>`_ for details.

    Args:
        in_channels (int): Number of input channels (feature maps of all
            levels should have the same channels).
        num_levels (int): Number of input feature levels.
        refine_level (int): Index of the level whose resolution is used for
            integration and refinement. Must satisfy
            ``0 <= refine_level < num_levels``. Defaults to 2.
        refine_type (str, optional): Type of the refine op, one of
            ``None``, ``'conv'`` or ``'non_local'``. With ``None`` the
            averaged map is used as is. Defaults to None.
        conv_cfg (dict, optional): The config dict for convolution layers,
            forwarded to the refine module. Defaults to None.
        norm_cfg (dict, optional): The config dict for normalization layers,
            forwarded to the refine module. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config dict,
            passed to the base class constructor. Defaults to uniform Xavier
            initialization of ``Conv2d`` layers.
    """

    def __init__(self,
                 in_channels,
                 num_levels,
                 refine_level=2,
                 refine_type=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg)
        assert refine_type in (None, 'conv', 'non_local')

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.refine_level = refine_level
        self.refine_type = refine_type
        assert 0 <= self.refine_level < self.num_levels

        # Both refine variants take the same conv/norm configuration.
        layer_cfg = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)
        if self.refine_type == 'conv':
            # 3x3 convolution that keeps the channel count.
            self.refine = ConvModule(
                self.in_channels, self.in_channels, 3, padding=1, **layer_cfg)
        elif self.refine_type == 'non_local':
            self.refine = NonLocal2d(
                self.in_channels, reduction=1, use_scale=False, **layer_cfg)
        # With refine_type=None no ``refine`` attribute is created; forward()
        # skips the refinement step in that case.

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): One feature map per level, exactly
                ``num_levels`` of them. The size from dimension 2 onwards is
                treated as the spatial size (presumably N, C, H, W layout --
                confirm against the caller).

        Returns:
            tuple[Tensor]: One tensor per input level, each being the input
            of that level plus the resized integrated feature.
        """
        assert len(inputs) == self.num_levels
        pivot = self.refine_level
        level_ids = range(self.num_levels)

        # step 1: bring every level to the pivot resolution and average.
        # Levels with an index below the pivot go through adaptive max
        # pooling, the others through nearest-neighbour interpolation.
        fused_size = inputs[pivot].size()[2:]
        resized = [
            F.adaptive_max_pool2d(inputs[lvl], output_size=fused_size)
            if lvl < pivot else
            F.interpolate(inputs[lvl], size=fused_size, mode='nearest')
            for lvl in level_ids
        ]
        fused = sum(resized) / len(resized)

        # step 2: optional refinement of the integrated feature
        if self.refine_type is not None:
            fused = self.refine(fused)

        # step 3: send the integrated feature back to each level's own
        # resolution (mirror of step 1) and add it as a residual
        results = []
        for lvl in level_ids:
            level_size = inputs[lvl].size()[2:]
            if lvl >= pivot:
                residual = F.adaptive_max_pool2d(fused, output_size=level_size)
            else:
                residual = F.interpolate(fused, size=level_size, mode='nearest')
            results.append(residual + inputs[lvl])

        return tuple(results)
|