Spaces:
Runtime error
Runtime error
File size: 3,551 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule, Linear
from mmcv.runner import ModuleList, auto_fp16
from mmdet.models.builder import HEADS
from .fcn_mask_head import FCNMaskHead
@HEADS.register_module()
class CoarseMaskHead(FCNMaskHead):
    """Coarse mask head used in PointRend.

    Compared with standard ``FCNMaskHead``, ``CoarseMaskHead`` will downsample
    the input feature map instead of upsample it. The (optionally
    downsampled) feature map is flattened, passed through a stack of fc
    layers and finally projected to per-class mask logits of spatial size
    ``output_size``.

    Args:
        num_convs (int): Number of conv layers in the head. Default: 0.
        num_fcs (int): Number of fc layers in the head. Must be positive.
            Default: 2.
        fc_out_channels (int): Number of output channels of fc layer.
            Default: 1024.
        downsample_factor (int): The factor that feature map is downsampled
            by. Must be at least 1; with 1 no downsample conv is created.
            Default: 2.
        init_cfg (dict or list[dict], optional): Initialization config dict.
        *arg: Extra positional arguments forwarded to ``FCNMaskHead``.
        **kwarg: Extra keyword arguments forwarded to ``FCNMaskHead``.
    """

    def __init__(self,
                 num_convs=0,
                 num_fcs=2,
                 fc_out_channels=1024,
                 downsample_factor=2,
                 init_cfg=dict(
                     type='Xavier',
                     override=[
                         dict(name='fcs'),
                         dict(type='Constant', val=0.001, name='fc_logits')
                     ]),
                 *arg,
                 **kwarg):
        # The parent is built without an upsample layer and without its own
        # init_cfg; the init_cfg given here is attached right afterwards.
        super(CoarseMaskHead, self).__init__(
            *arg,
            num_convs=num_convs,
            upsample_cfg=dict(type=None),
            init_cfg=None,
            **kwarg)
        self.init_cfg = init_cfg

        self.num_fcs = num_fcs
        assert self.num_fcs > 0
        self.fc_out_channels = fc_out_channels
        self.downsample_factor = downsample_factor
        assert self.downsample_factor >= 1

        # remove conv_logit: logits are produced by ``fc_logits`` instead
        delattr(self, 'conv_logits')

        if downsample_factor > 1:
            # The downsample conv consumes the output of the conv stack when
            # there is one, otherwise the raw input feature map.
            downsample_in_channels = (
                self.conv_out_channels
                if self.num_convs > 0 else self.in_channels)
            # kernel_size == stride with no padding: non-overlapping windows,
            # so each spatial dim becomes ``size // downsample_factor``.
            self.downsample_conv = ConvModule(
                downsample_in_channels,
                self.conv_out_channels,
                kernel_size=downsample_factor,
                stride=downsample_factor,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
        else:
            self.downsample_conv = None

        # NOTE(review): assumes the parent stores ``roi_feat_size`` as an
        # (h, w) pair -- confirm against ``FCNMaskHead``.
        self.output_size = (self.roi_feat_size[0] // downsample_factor,
                            self.roi_feat_size[1] // downsample_factor)
        self.output_area = self.output_size[0] * self.output_size[1]

        # Channels of the feature map that gets flattened in ``forward``.
        # If neither the conv stack nor the downsample conv exists, the input
        # is flattened untouched and therefore still has ``in_channels``
        # channels; sizing the first fc with ``conv_out_channels`` in that
        # case would cause a shape mismatch whenever the two differ.
        if self.downsample_conv is not None or self.num_convs > 0:
            flatten_channels = self.conv_out_channels
        else:
            flatten_channels = self.in_channels
        last_layer_dim = flatten_channels * self.output_area

        # First fc maps the flattened feature to ``fc_out_channels``; every
        # following fc keeps that width.
        self.fcs = ModuleList()
        for _ in range(num_fcs):
            self.fcs.append(Linear(last_layer_dim, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels

        # One logit per class and per output location.
        output_channels = self.num_classes * self.output_area
        self.fc_logits = Linear(last_layer_dim, output_channels)

    def init_weights(self):
        """Initialize weights, skipping ``FCNMaskHead.init_weights``.

        ``super(FCNMaskHead, self)`` resolves to the class after
        ``FCNMaskHead`` in the MRO, so the parent's own implementation is
        bypassed on purpose.
        NOTE(review): presumably because the parent's version touches layers
        (e.g. ``conv_logits``) that this head removed -- confirm.
        """
        super(FCNMaskHead, self).init_weights()

    @auto_fp16()
    def forward(self, x):
        """Predict coarse masks from RoI features.

        Args:
            x (Tensor): Input feature map; the first dimension is kept as
                the batch dimension and all others are flattened before the
                fc layers. Presumably of shape (N, C, H, W) with (H, W)
                equal to ``roi_feat_size`` -- confirm against the caller.

        Returns:
            Tensor: Mask logits of shape
            ``(x.size(0), num_classes, *output_size)``.
        """
        for conv in self.convs:
            x = conv(x)

        if self.downsample_conv is not None:
            x = self.downsample_conv(x)

        # Collapse channel and spatial dims into a single feature vector.
        x = x.flatten(1)
        for fc in self.fcs:
            x = self.relu(fc(x))
        mask_pred = self.fc_logits(x).view(
            x.size(0), self.num_classes, *self.output_size)
        return mask_pred
|