Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
import mmcv | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from mmcv.runner import BaseModule, auto_fp16, force_fp32 | |
from mmdet.core import InstanceData, mask_matrix_nms, multi_apply | |
from mmdet.core.utils import center_of_mass, generate_coordinate | |
from mmdet.models.builder import HEADS | |
from mmdet.utils.misc import floordiv | |
from .solo_head import SOLOHead | |
class MaskFeatModule(BaseModule):
    """SOLOv2 mask feature map branch used in `SOLOv2: Dynamic and Fast
    Instance Segmentation. <https://arxiv.org/pdf/2003.10152>`_

    Input level ``start_level + k`` goes through its own sub-network
    ``convs_all_levels[k]``: a single 3x3 conv for level 0, otherwise
    ``level`` repetitions of (3x3 conv, 2x bilinear upsample). The per-level
    outputs are summed and projected by a 1x1 conv to ``out_channels``.
    The deepest used level (``end_level``, unless it is level 0) gets a
    2-channel coordinate map concatenated before its first conv.
    The summation presumes the upsampled levels end up the same spatial
    size (e.g. FPN levels halving in resolution) -- not checked here.

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels of the mask feature
            map branch.
        start_level (int): The first input feature level (index into
            ``feats``) used to predict the mask feature map.
        end_level (int): The last input feature level (inclusive) used to
            predict the mask feature map.
        out_channels (int): Number of output channels of the mask feature
            map branch. This is the channel count of the mask
            feature map that to be dynamically convolved with the predicted
            kernel.
        mask_stride (int): Downsample factor of the mask feature map output.
            Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 feat_channels,
                 start_level,
                 end_level,
                 out_channels,
                 mask_stride=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=[dict(type='Normal', layer='Conv2d', std=0.01)]):
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.start_level = start_level
        self.end_level = end_level
        self.mask_stride = mask_stride
        assert start_level >= 0 and end_level >= start_level
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()
        # Read by the mmcv fp16 decorators; flipped by `wrap_fp16_model`.
        self.fp16_enabled = False

    def _uses_coord(self, level):
        """Whether input ``level`` gets the 2 coordinate channels.

        Must agree between `_init_layers` (channel count of the first conv)
        and `forward` (actual concatenation). Level 0 is always built with
        plain ``in_channels``.
        """
        return level == self.end_level and level != 0

    def _init_layers(self):
        """Build one conv/upsample tower per used level plus the 1x1 head.

        Module names (``conv{j}``, ``upsample{j}``) and creation order are
        part of the checkpoint layout and must not change.
        """
        self.convs_all_levels = nn.ModuleList()
        for i in range(self.start_level, self.end_level + 1):
            convs_per_level = nn.Sequential()
            if i == 0:
                convs_per_level.add_module(
                    f'conv{i}',
                    ConvModule(
                        self.in_channels,
                        self.feat_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        inplace=False))
                self.convs_all_levels.append(convs_per_level)
                continue

            for j in range(i):
                if j == 0:
                    # Only the first conv sees the raw input (plus the
                    # coordinate channels on the deepest level).
                    chn = self.in_channels + 2 if self._uses_coord(i) \
                        else self.in_channels
                else:
                    chn = self.feat_channels
                convs_per_level.add_module(
                    f'conv{j}',
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        inplace=False))
                convs_per_level.add_module(
                    f'upsample{j}',
                    nn.Upsample(
                        scale_factor=2, mode='bilinear', align_corners=False))

            self.convs_all_levels.append(convs_per_level)

        self.conv_pred = ConvModule(
            self.feat_channels,
            self.out_channels,
            1,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

    @auto_fp16()
    def forward(self, feats):
        """Fuse the selected levels of ``feats`` into one mask feature map.

        Args:
            feats (Sequence[Tensor]): Multi-level input features; levels
                ``start_level`` .. ``end_level`` are used.

        Returns:
            Tensor: The unified mask feature map (``out_channels`` channels).
        """
        inputs = feats[self.start_level:self.end_level + 1]
        assert len(inputs) == (self.end_level - self.start_level + 1)

        feature_add_all_level = None
        for i, input_p in enumerate(inputs):
            # Same rule as `_init_layers`, so a single-level configuration
            # with start_level == end_level > 0 also gets its coordinates.
            if self._uses_coord(self.start_level + i):
                coord_feat = generate_coordinate(input_p.size(),
                                                 input_p.device)
                input_p = torch.cat([input_p, coord_feat], 1)
            level_feat = self.convs_all_levels[i](input_p)
            if feature_add_all_level is None:
                feature_add_all_level = level_feat
            else:
                # fix runtime error of "+=" inplace operation in PyTorch 1.10
                feature_add_all_level = feature_add_all_level + level_feat

        feature_pred = self.conv_pred(feature_add_all_level)
        return feature_pred
@HEADS.register_module()
class SOLOV2Head(SOLOHead):
    """SOLOv2 mask head used in `SOLOv2: Dynamic and Fast Instance
    Segmentation. <https://arxiv.org/pdf/2003.10152>`_

    Args:
        mask_feature_head (dict): Config of SOLOv2MaskFeatHead. Its
            ``in_channels`` is forced to this head's ``in_channels``.
        dynamic_conv_size (int): Dynamic Conv kernel size. Default: 1.
        dcn_cfg (dict): Dcn conv configurations in kernel_convs and cls_conv.
            default: None.
        dcn_apply_to_all_conv (bool): Whether to use dcn in every layer of
            kernel_convs and cls_convs, or only the last layer. It shall be set
            `True` for the normal version of SOLOv2 and `False` for the
            light-weight version. default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 *args,
                 mask_feature_head,
                 dynamic_conv_size=1,
                 dcn_cfg=None,
                 dcn_apply_to_all_conv=True,
                 init_cfg=[
                     dict(type='Normal', layer='Conv2d', std=0.01),
                     dict(
                         type='Normal',
                         std=0.01,
                         bias_prob=0.01,
                         override=dict(name='conv_cls'))
                 ],
                 **kwargs):
        assert dcn_cfg is None or isinstance(dcn_cfg, dict)
        # These attributes are set before `super().__init__` because the
        # parent constructor calls `self._init_layers()`, which reads them.
        self.dcn_cfg = dcn_cfg
        self.with_dcn = dcn_cfg is not None
        self.dcn_apply_to_all_conv = dcn_apply_to_all_conv
        self.dynamic_conv_size = dynamic_conv_size
        mask_out_channels = mask_feature_head.get('out_channels')
        # One flattened (mask_out_channels, k, k) conv kernel per grid cell.
        self.kernel_out_channels = \
            mask_out_channels * self.dynamic_conv_size * self.dynamic_conv_size

        super().__init__(*args, init_cfg=init_cfg, **kwargs)

        # update the in_channels of mask_feature_head
        mask_in_channels = mask_feature_head.get('in_channels', None)
        if mask_in_channels is not None and \
                mask_in_channels != self.in_channels:
            warnings.warn('The `in_channels` of SOLOv2MaskFeatHead and '
                          'SOLOv2Head should be same, changing '
                          'mask_feature_head.in_channels to '
                          f'{self.in_channels}')
        # Build from a shallow copy (works for dict and ConfigDict) so the
        # caller's config is not mutated.
        mask_feature_head = dict(
            mask_feature_head, in_channels=self.in_channels)
        self.mask_feature_head = MaskFeatModule(**mask_feature_head)
        self.mask_stride = self.mask_feature_head.mask_stride
        # Read by the mmcv fp16 decorators; flipped by `wrap_fp16_model`.
        self.fp16_enabled = False

    def _init_layers(self):
        """Build the kernel and category conv towers and their outputs."""
        self.cls_convs = nn.ModuleList()
        self.kernel_convs = nn.ModuleList()
        conv_cfg = None
        for i in range(self.stacked_convs):
            if self.with_dcn:
                if self.dcn_apply_to_all_conv:
                    conv_cfg = self.dcn_cfg
                elif i == self.stacked_convs - 1:
                    # light head
                    conv_cfg = self.dcn_cfg

            # The kernel branch input carries the 2 coordinate channels
            # concatenated in `forward`; the category branch drops them.
            chn = self.in_channels + 2 if i == 0 else self.feat_channels
            self.kernel_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None))

            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.norm_cfg is None))

        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)

        self.conv_kernel = nn.Conv2d(
            self.feat_channels, self.kernel_out_channels, 3, padding=1)

    @auto_fp16()
    def forward(self, feats):
        """Predict per-grid kernels, per-grid class scores and mask features.

        Args:
            feats (Sequence[Tensor]): Multi-level features, one per level
                (``self.num_levels`` of them).

        Returns:
            tuple: ``(mlvl_kernel_preds, mlvl_cls_preds, mask_feats)``, where
            each level's kernel/class prediction is sampled on a
            ``num_grids[i]`` grid.
        """
        assert len(feats) == self.num_levels
        mask_feats = self.mask_feature_head(feats)
        feats = self.resize_feats(feats)
        mlvl_kernel_preds = []
        mlvl_cls_preds = []
        for i in range(self.num_levels):
            ins_kernel_feat = feats[i]
            # ins branch
            # concat coord
            coord_feat = generate_coordinate(ins_kernel_feat.size(),
                                             ins_kernel_feat.device)
            ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)

            # kernel branch: resample onto the S x S grid of this level
            kernel_feat = F.interpolate(
                ins_kernel_feat,
                size=self.num_grids[i],
                mode='bilinear',
                align_corners=False)

            # The category branch uses the same resampled features without
            # the 2 trailing coordinate channels.
            cate_feat = kernel_feat[:, :-2, :, :]

            kernel_feat = kernel_feat.contiguous()
            for kernel_conv in self.kernel_convs:
                kernel_feat = kernel_conv(kernel_feat)
            kernel_pred = self.conv_kernel(kernel_feat)

            # cate branch
            cate_feat = cate_feat.contiguous()
            for cls_conv in self.cls_convs:
                cate_feat = cls_conv(cate_feat)
            cate_pred = self.conv_cls(cate_feat)

            mlvl_kernel_preds.append(kernel_pred)
            mlvl_cls_preds.append(cate_pred)

        return mlvl_kernel_preds, mlvl_cls_preds, mask_feats

    @staticmethod
    def _coord_to_grid(coord, size, num_grid):
        """Return the (truncated) index of the grid cell containing
        ``coord``, for an axis of length ``size`` split into ``num_grid``
        cells. May fall outside ``[0, num_grid)``; callers clamp."""
        return int(
            floordiv(coord / size, (1. / num_grid), rounding_mode='trunc'))

    def _get_targets_single(self,
                            gt_bboxes,
                            gt_labels,
                            gt_masks,
                            featmap_size=None):
        """Compute targets for predictions of single image.

        Args:
            gt_bboxes (Tensor): Ground truth bbox of each instance,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth label of each instance,
                shape (num_gts,).
            gt_masks (Tensor): Ground truth mask of each instance,
                shape (num_gts, h, w).
            featmap_size (:obj:`torch.Size`): Size of the unified mask
                feature map used to generate instance segmentation
                masks by dynamic convolution, i.e. (feat_h, feat_w).
                Despite the ``None`` default it is required.

        Returns:
            Tuple: Usually returns a tuple containing targets for predictions.

                - mlvl_pos_mask_targets (list[Tensor]): Each element represent
                  the binary mask targets for positive points in this
                  level, has shape (num_pos, out_h, out_w).
                - mlvl_labels (list[Tensor]): Each element is
                  classification labels for all
                  points in this level, has shape
                  (num_grid, num_grid).
                - mlvl_pos_masks (list[Tensor]): Each element is
                  a `BoolTensor` to represent whether the
                  corresponding point in single level
                  is positive, has shape (num_grid **2).
                - mlvl_pos_indexes (list[list]): Each element
                  in the list contains the positive index in
                  corresponding level, has shape (num_pos).
        """
        device = gt_labels.device
        gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                              (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
        # Size of the space the gt boxes/masks are expressed in, as implied
        # by the mask feature map and its stride.
        upsampled_size = (featmap_size[0] * self.mask_stride,
                          featmap_size[1] * self.mask_stride)

        mlvl_pos_mask_targets = []
        mlvl_pos_indexes = []
        mlvl_labels = []
        mlvl_pos_masks = []
        for (lower_bound, upper_bound), num_grid \
                in zip(self.scale_ranges, self.num_grids):
            mask_target = []
            # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
            pos_index = []
            labels = torch.zeros([num_grid, num_grid],
                                 dtype=torch.int64,
                                 device=device) + self.num_classes
            pos_mask = torch.zeros([num_grid**2],
                                   dtype=torch.bool,
                                   device=device)

            # Instances whose sqrt(box area) falls into this level's range.
            gt_inds = ((gt_areas >= lower_bound) &
                       (gt_areas <= upper_bound)).nonzero().flatten()
            if len(gt_inds) == 0:
                mlvl_pos_mask_targets.append(
                    torch.zeros([0, featmap_size[0], featmap_size[1]],
                                dtype=torch.uint8,
                                device=device))
                mlvl_labels.append(labels)
                mlvl_pos_masks.append(pos_mask)
                mlvl_pos_indexes.append([])
                continue

            hit_gt_bboxes = gt_bboxes[gt_inds]
            hit_gt_labels = gt_labels[gt_inds]
            hit_gt_masks = gt_masks[gt_inds, ...]

            pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
                                  hit_gt_bboxes[:, 0]) * self.pos_scale
            pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
                                  hit_gt_bboxes[:, 1]) * self.pos_scale

            # Make sure hit_gt_masks has a value
            valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0

            for gt_mask, gt_label, pos_h_range, pos_w_range, \
                valid_mask_flag in \
                    zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
                        pos_w_ranges, valid_mask_flags):
                if not valid_mask_flag:
                    continue
                center_h, center_w = center_of_mass(gt_mask)

                coord_w = self._coord_to_grid(center_w, upsampled_size[1],
                                              num_grid)
                coord_h = self._coord_to_grid(center_h, upsampled_size[0],
                                              num_grid)

                # left, top, right, down
                top_box = max(
                    0,
                    self._coord_to_grid(center_h - pos_h_range,
                                        upsampled_size[0], num_grid))
                down_box = min(
                    num_grid - 1,
                    self._coord_to_grid(center_h + pos_h_range,
                                        upsampled_size[0], num_grid))
                left_box = max(
                    0,
                    self._coord_to_grid(center_w - pos_w_range,
                                        upsampled_size[1], num_grid))
                right_box = min(
                    num_grid - 1,
                    self._coord_to_grid(center_w + pos_w_range,
                                        upsampled_size[1], num_grid))

                # Positive cells: the scaled box region, limited to at most
                # one cell around the center-of-mass cell.
                top = max(top_box, coord_h - 1)
                down = min(down_box, coord_h + 1)
                left = max(coord_w - 1, left_box)
                right = min(right_box, coord_w + 1)

                labels[top:(down + 1), left:(right + 1)] = gt_label
                # ins
                gt_mask = np.uint8(gt_mask.cpu().numpy())
                # Follow the original SOLO implementation: downscale with
                # mmcv (OpenCV) rather than F.interpolate, whose results
                # differ.
                gt_mask = mmcv.imrescale(gt_mask, scale=1. / self.mask_stride)
                gt_mask = torch.from_numpy(gt_mask).to(device=device)

                # Every positive cell of this instance has the same target;
                # build the zero-padded mask once (torch.stack below copies).
                this_mask_target = torch.zeros(
                    [featmap_size[0], featmap_size[1]],
                    dtype=torch.uint8,
                    device=device)
                this_mask_target[:gt_mask.shape[0], :gt_mask.shape[1]] = \
                    gt_mask

                for i in range(top, down + 1):
                    for j in range(left, right + 1):
                        index = int(i * num_grid + j)
                        mask_target.append(this_mask_target)
                        pos_mask[index] = True
                        pos_index.append(index)

            if len(mask_target) == 0:
                mask_target = torch.zeros(
                    [0, featmap_size[0], featmap_size[1]],
                    dtype=torch.uint8,
                    device=device)
            else:
                mask_target = torch.stack(mask_target, 0)
            mlvl_pos_mask_targets.append(mask_target)
            mlvl_labels.append(labels)
            mlvl_pos_masks.append(pos_mask)
            mlvl_pos_indexes.append(pos_index)
        return (mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks,
                mlvl_pos_indexes)

    @force_fp32(apply_to=('mlvl_kernel_preds', 'mlvl_cls_preds', 'mask_feats'))
    def loss(self,
             mlvl_kernel_preds,
             mlvl_cls_preds,
             mask_feats,
             gt_labels,
             gt_masks,
             img_metas,
             gt_bboxes=None,
             **kwargs):
        """Calculate the loss of total batch.

        Args:
            mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel
                prediction. The kernel is used to generate instance
                segmentation masks by dynamic convolution. Each element in the
                list has shape
                (batch_size, kernel_out_channels, num_grids, num_grids).
            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            mask_feats (Tensor): Unified mask feature map used to generate
                instance segmentation masks by dynamic convolution. Has shape
                (batch_size, mask_out_channels, h, w).
            gt_labels (list[Tensor]): Labels of multiple images.
            gt_masks (list[Tensor]): Ground truth masks of multiple images.
                Each has shape (num_instances, h, w).
            img_metas (list[dict]): Meta information of multiple images.
            gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
                images. Required in practice: it is iterated per image
                when building targets. Default: None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_size = mask_feats.size()[-2:]

        pos_mask_targets, labels, pos_masks, pos_indexes = multi_apply(
            self._get_targets_single,
            gt_bboxes,
            gt_labels,
            gt_masks,
            featmap_size=featmap_size)

        # Regroup the per-image targets level by level.
        mlvl_mask_targets = [
            torch.cat(lvl_mask_targets, 0)
            for lvl_mask_targets in zip(*pos_mask_targets)
        ]

        # Gather the predicted kernels at each image's positive grid cells.
        mlvl_pos_kernel_preds = []
        for lvl_kernel_preds, lvl_pos_indexes in zip(mlvl_kernel_preds,
                                                     zip(*pos_indexes)):
            lvl_pos_kernel_preds = []
            for img_lvl_kernel_preds, img_lvl_pos_indexes in zip(
                    lvl_kernel_preds, lvl_pos_indexes):
                img_lvl_pos_kernel_preds = img_lvl_kernel_preds.view(
                    img_lvl_kernel_preds.shape[0], -1)[:, img_lvl_pos_indexes]
                lvl_pos_kernel_preds.append(img_lvl_pos_kernel_preds)
            mlvl_pos_kernel_preds.append(lvl_pos_kernel_preds)

        # make multilevel mlvl_mask_pred: convolve each image's mask
        # features with its gathered kernels (one output map per kernel).
        mlvl_mask_preds = []
        for lvl_pos_kernel_preds in mlvl_pos_kernel_preds:
            lvl_mask_preds = []
            for img_id, img_lvl_pos_kernel_pred in enumerate(
                    lvl_pos_kernel_preds):
                if img_lvl_pos_kernel_pred.size()[-1] == 0:
                    continue
                img_mask_feats = mask_feats[[img_id]]
                h, w = img_mask_feats.shape[-2:]
                num_kernel = img_lvl_pos_kernel_pred.shape[1]
                img_lvl_mask_pred = F.conv2d(
                    img_mask_feats,
                    img_lvl_pos_kernel_pred.permute(1, 0).view(
                        num_kernel, -1, self.dynamic_conv_size,
                        self.dynamic_conv_size),
                    stride=1).view(-1, h, w)
                lvl_mask_preds.append(img_lvl_mask_pred)
            if len(lvl_mask_preds) == 0:
                lvl_mask_preds = None
            else:
                lvl_mask_preds = torch.cat(lvl_mask_preds, 0)
            mlvl_mask_preds.append(lvl_mask_preds)

        # dice loss, averaged over all positive grid cells of the batch
        num_pos = 0
        for img_pos_masks in pos_masks:
            for lvl_img_pos_masks in img_pos_masks:
                num_pos += lvl_img_pos_masks.count_nonzero()

        loss_mask = []
        for lvl_mask_preds, lvl_mask_targets in zip(mlvl_mask_preds,
                                                    mlvl_mask_targets):
            if lvl_mask_preds is None:
                continue
            loss_mask.append(
                self.loss_mask(
                    lvl_mask_preds,
                    lvl_mask_targets,
                    reduction_override='none'))
        if num_pos > 0:
            loss_mask = torch.cat(loss_mask).sum() / num_pos
        else:
            # Keep the graph connected so every parameter gets a gradient.
            loss_mask = mask_feats.sum() * 0

        # cate
        flatten_labels = [
            torch.cat(
                [img_lvl_labels.flatten() for img_lvl_labels in lvl_labels])
            for lvl_labels in zip(*labels)
        ]
        flatten_labels = torch.cat(flatten_labels)

        flatten_cls_preds = [
            lvl_cls_preds.permute(0, 2, 3, 1).reshape(-1,
                                                      self.cls_out_channels)
            for lvl_cls_preds in mlvl_cls_preds
        ]
        flatten_cls_preds = torch.cat(flatten_cls_preds)

        loss_cls = self.loss_cls(
            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
        return dict(loss_mask=loss_mask, loss_cls=loss_cls)

    @force_fp32(
        apply_to=('mlvl_kernel_preds', 'mlvl_cls_scores', 'mask_feats'))
    def get_results(self, mlvl_kernel_preds, mlvl_cls_scores, mask_feats,
                    img_metas, **kwargs):
        """Get multi-image mask results.

        Class scores are sigmoid-activated and zeroed wherever they are not
        the maximum of the 2x2 window covering the cell and its upper/left
        neighbours. The input lists are not modified.

        Args:
            mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel
                prediction. The kernel is used to generate instance
                segmentation masks by dynamic convolution. Each element in the
                list has shape
                (batch_size, kernel_out_channels, num_grids, num_grids).
            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
                in the list has shape
                (batch_size, num_classes, num_grids, num_grids).
            mask_feats (Tensor): Unified mask feature map used to generate
                instance segmentation masks by dynamic convolution. Has shape
                (batch_size, mask_out_channels, h, w).
            img_metas (list[dict]): Meta information of all images.

        Returns:
            list[:obj:`InstanceData`]: Processed results of multiple
            images. Each :obj:`InstanceData` usually contains
            following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """
        num_levels = len(mlvl_cls_scores)
        assert len(mlvl_kernel_preds) == len(mlvl_cls_scores)

        # Build a new list instead of overwriting the caller's entries.
        nms_cls_scores = []
        for cls_scores in mlvl_cls_scores:
            cls_scores = cls_scores.sigmoid()
            local_max = F.max_pool2d(cls_scores, 2, stride=1, padding=1)
            keep_mask = local_max[:, :, :-1, :-1] == cls_scores
            cls_scores = cls_scores * keep_mask
            nms_cls_scores.append(cls_scores.permute(0, 2, 3, 1))

        result_list = []
        for img_id in range(len(img_metas)):
            img_cls_pred = [
                nms_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
                for lvl in range(num_levels)
            ]
            img_mask_feats = mask_feats[[img_id]]
            img_kernel_pred = [
                mlvl_kernel_preds[lvl][img_id].permute(1, 2, 0).view(
                    -1, self.kernel_out_channels) for lvl in range(num_levels)
            ]
            img_cls_pred = torch.cat(img_cls_pred, dim=0)
            img_kernel_pred = torch.cat(img_kernel_pred, dim=0)
            result = self._get_results_single(
                img_kernel_pred,
                img_cls_pred,
                img_mask_feats,
                img_meta=img_metas[img_id])
            result_list.append(result)
        return result_list

    def _get_results_single(self,
                            kernel_preds,
                            cls_scores,
                            mask_feats,
                            img_meta,
                            cfg=None):
        """Get processed mask related results of single image.

        Args:
            kernel_preds (Tensor): Dynamic kernel prediction of all points
                in single image, has shape
                (num_points, kernel_out_channels).
            cls_scores (Tensor): Classification score of all points
                in single image, has shape (num_points, num_classes).
            mask_feats (Tensor): Unified mask feature map of this single
                image, has shape (1, mask_out_channels, feat_h, feat_w).
            img_meta (dict): Meta information of corresponding image.
            cfg (dict, optional): Config used in test phase.
                Default: None.

        Returns:
            :obj:`InstanceData`: Processed results of single image.
             it usually contains following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,).
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).
        """

        def empty_results(results, cls_scores):
            """Fill ``results`` with zero-length scores, masks and labels."""
            results.scores = cls_scores.new_ones(0)
            results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
            # Integer dtype, matching the labels of the non-empty path.
            results.labels = cls_scores.new_zeros(0, dtype=torch.long)
            return results

        cfg = self.test_cfg if cfg is None else cfg
        assert len(kernel_preds) == len(cls_scores)
        results = InstanceData(img_meta)

        featmap_size = mask_feats.size()[-2:]

        img_shape = results.img_shape
        ori_shape = results.ori_shape

        # overall info
        h, w, _ = img_shape
        upsampled_size = (featmap_size[0] * self.mask_stride,
                          featmap_size[1] * self.mask_stride)

        # process.
        score_mask = (cls_scores > cfg.score_thr)
        cls_scores = cls_scores[score_mask]
        if len(cls_scores) == 0:
            return empty_results(results, cls_scores)

        # cate_labels & kernel_preds
        inds = score_mask.nonzero()
        cls_labels = inds[:, 1]
        kernel_preds = kernel_preds[inds[:, 0]]

        # Per-point stride of the level each surviving point came from;
        # levels are laid out consecutively with num_grid**2 points each.
        lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
        strides = kernel_preds.new_ones(lvl_interval[-1])

        strides[:lvl_interval[0]] *= self.strides[0]
        for lvl in range(1, self.num_levels):
            strides[lvl_interval[lvl -
                                 1]:lvl_interval[lvl]] *= self.strides[lvl]
        strides = strides[inds[:, 0]]

        # mask encoding.
        kernel_preds = kernel_preds.view(
            kernel_preds.size(0), -1, self.dynamic_conv_size,
            self.dynamic_conv_size)
        mask_preds = F.conv2d(
            mask_feats, kernel_preds, stride=1).squeeze(0).sigmoid()
        # mask.
        masks = mask_preds > cfg.mask_thr
        sum_masks = masks.sum((1, 2)).float()
        # Drop masks whose binarized area does not exceed their level stride.
        keep = sum_masks > strides
        if keep.sum() == 0:
            return empty_results(results, cls_scores)
        masks = masks[keep]
        mask_preds = mask_preds[keep]
        sum_masks = sum_masks[keep]
        cls_scores = cls_scores[keep]
        cls_labels = cls_labels[keep]

        # maskness: mean soft mask value inside the binarized mask.
        mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
        cls_scores *= mask_scores

        scores, labels, _, keep_inds = mask_matrix_nms(
            masks,
            cls_labels,
            cls_scores,
            mask_area=sum_masks,
            nms_pre=cfg.nms_pre,
            max_num=cfg.max_per_img,
            kernel=cfg.kernel,
            sigma=cfg.sigma,
            filter_thr=cfg.filter_thr)
        # Matrix NMS may filter out every mask; F.interpolate rejects the
        # resulting zero-channel input, so return early.
        if len(keep_inds) == 0:
            return empty_results(results, cls_scores)
        mask_preds = mask_preds[keep_inds]
        # Upsample to the padded input size, crop the valid image region,
        # then resize to the original image shape.
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(0),
            size=upsampled_size,
            mode='bilinear',
            align_corners=False)[:, :, :h, :w]
        mask_preds = F.interpolate(
            mask_preds,
            size=ori_shape[:2],
            mode='bilinear',
            align_corners=False).squeeze(0)
        masks = mask_preds > cfg.mask_thr

        results.masks = masks
        results.labels = labels
        results.scores = scores

        return results