Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from mmcv.cnn import ConvModule | |
from mmdet.models.builder import HEADS | |
from .fcn_mask_head import FCNMaskHead | |
class HTCMaskHead(FCNMaskHead): | |
def __init__(self, with_conv_res=True, *args, **kwargs): | |
super(HTCMaskHead, self).__init__(*args, **kwargs) | |
self.with_conv_res = with_conv_res | |
if self.with_conv_res: | |
self.conv_res = ConvModule( | |
self.conv_out_channels, | |
self.conv_out_channels, | |
1, | |
conv_cfg=self.conv_cfg, | |
norm_cfg=self.norm_cfg) | |
def forward(self, x, res_feat=None, return_logits=True, return_feat=True): | |
if res_feat is not None: | |
assert self.with_conv_res | |
res_feat = self.conv_res(res_feat) | |
x = x + res_feat | |
for conv in self.convs: | |
x = conv(x) | |
res_feat = x | |
outs = [] | |
if return_logits: | |
x = self.upsample(x) | |
if self.upsample_method == 'deconv': | |
x = self.relu(x) | |
mask_pred = self.conv_logits(x) | |
outs.append(mask_pred) | |
if return_feat: | |
outs.append(res_feat) | |
return outs if len(outs) > 1 else outs[0] | |