Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn import Conv2d, Linear, MaxPool2d | |
from mmcv.runner import BaseModule, force_fp32 | |
from torch.nn.modules.utils import _pair | |
from mmdet.models.builder import HEADS, build_loss | |
class MaskIoUHead(BaseModule):
    """Mask IoU Head.

    This head predicts the IoU between predicted masks and their
    corresponding gt masks.

    Args:
        num_convs (int): Number of 3x3 conv layers. The last one uses
            stride 2, all others stride 1.
        num_fcs (int): Number of fully connected layers placed before the
            final ``fc_mask_iou`` predictor.
        roi_feat_size (int | tuple[int]): Size used to derive the input
            width of the first fc layer. Expanded with ``_pair``.
        in_channels (int): Channels of ``mask_feat``. The first conv takes
            ``in_channels + 1`` because the pooled mask prediction is
            concatenated as one extra channel.
        conv_out_channels (int): Output channels of every conv layer.
        fc_out_channels (int): Output width of every fc layer.
        num_classes (int): Output width of ``fc_mask_iou`` (one IoU per
            class).
        loss_iou (dict): Config handed to ``build_loss``.
        init_cfg (dict | list[dict]): Initialization config forwarded to
            ``BaseModule.__init__``.
    """

    # NOTE(review): this file imports ``HEADS`` and ``force_fp32`` but neither
    # is applied to this class, while ``fp16_enabled`` is set in ``__init__``.
    # Confirm whether ``@HEADS.register_module()`` / ``@force_fp32(...)``
    # decorators were intended here.

    def __init__(self,
                 num_convs=4,
                 num_fcs=2,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_out_channels=256,
                 fc_out_channels=1024,
                 num_classes=80,
                 loss_iou=dict(type='MSELoss', loss_weight=0.5),
                 init_cfg=[
                     dict(type='Kaiming', override=dict(name='convs')),
                     dict(type='Caffe2Xavier', override=dict(name='fcs')),
                     dict(
                         type='Normal',
                         std=0.01,
                         override=dict(name='fc_mask_iou'))
                 ]):
        super(MaskIoUHead, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.fc_out_channels = fc_out_channels
        self.num_classes = num_classes
        self.fp16_enabled = False

        self.convs = nn.ModuleList()
        for i in range(num_convs):
            if i == 0:
                # concatenation of mask feature and mask prediction
                in_channels = self.in_channels + 1
            else:
                in_channels = self.conv_out_channels
            # only the last conv downsamples
            stride = 2 if i == num_convs - 1 else 1
            self.convs.append(
                Conv2d(
                    in_channels,
                    self.conv_out_channels,
                    3,
                    stride=stride,
                    padding=1))

        roi_feat_size = _pair(roi_feat_size)
        # NOTE(review): assuming mmcv ``Conv2d`` follows ``torch.nn.Conv2d``
        # shape rules, a 3x3 / stride-2 / pad-1 conv outputs ceil(n / 2),
        # so this floor division only matches for even ``roi_feat_size``.
        # Confirm odd sizes are never configured.
        pooled_area = (roi_feat_size[0] // 2) * (roi_feat_size[1] // 2)
        self.fcs = nn.ModuleList()
        for i in range(num_fcs):
            # the first fc consumes the flattened conv output
            in_channels = ((self.conv_out_channels * pooled_area)
                           if i == 0 else self.fc_out_channels)
            self.fcs.append(Linear(in_channels, self.fc_out_channels))

        self.fc_mask_iou = Linear(self.fc_out_channels, self.num_classes)
        self.relu = nn.ReLU()
        self.max_pool = MaxPool2d(2, 2)
        self.loss_iou = build_loss(loss_iou)

    def forward(self, mask_feat, mask_pred):
        """Predict the mask IoU from mask features and mask predictions.

        Args:
            mask_feat (Tensor): Mask features; concatenated with the pooled
                mask prediction along dim 1, so its remaining dims must
                match the pooled prediction.
            mask_pred (Tensor): Mask predictions. ``sigmoid`` is applied
                here, then a channel dim is inserted at position 1 and the
                result is 2x2 max-pooled. Presumably of shape (N, H, W)
                with H, W twice the spatial size of ``mask_feat`` -- TODO
                confirm against the caller.

        Returns:
            Tensor: Output of ``fc_mask_iou``, i.e. ``num_classes`` values
            per input row.
        """
        mask_pred = mask_pred.sigmoid()
        mask_pred_pooled = self.max_pool(mask_pred.unsqueeze(1))

        x = torch.cat((mask_feat, mask_pred_pooled), 1)
        for conv in self.convs:
            x = self.relu(conv(x))
        x = x.flatten(1)
        for fc in self.fcs:
            x = self.relu(fc(x))
        mask_iou = self.fc_mask_iou(x)
        return mask_iou

    def loss(self, mask_iou_pred, mask_iou_targets):
        """Compute the mask IoU loss on entries with a positive target.

        Args:
            mask_iou_pred (Tensor): Predicted mask IoU, indexable by the
                boolean mask ``mask_iou_targets > 0``.
            mask_iou_targets (Tensor): Target mask IoU.

        Returns:
            dict: ``{'loss_mask_iou': Tensor}``.
        """
        pos_inds = mask_iou_targets > 0
        if pos_inds.sum() > 0:
            loss_mask_iou = self.loss_iou(mask_iou_pred[pos_inds],
                                          mask_iou_targets[pos_inds])
        else:
            # no positive target: return a zero that still depends on
            # ``mask_iou_pred``
            loss_mask_iou = mask_iou_pred.sum() * 0
        return dict(loss_mask_iou=loss_mask_iou)

    def get_targets(self, sampling_results, gt_masks, mask_pred, mask_targets,
                    rcnn_train_cfg):
        """Compute target of mask IoU.

        Mask IoU target is the IoU of the predicted mask (inside a bbox) and
        the gt mask of the corresponding instance (the whole instance).
        The intersection area is computed inside the bbox, and the gt mask
        area is computed in two steps: first we compute the gt area inside
        the bbox, then divide it by the area ratio of the gt area inside the
        bbox and the gt area of the whole instance.

        Args:
            sampling_results (list[:obj:`SamplingResult`]): sampling results.
            gt_masks (BitmapMask | PolygonMask): Gt masks (the whole instance)
                of each image, with the same shape of the input image.
            mask_pred (Tensor): Predicted masks of each positive proposal,
                shape (num_pos, h, w).
            mask_targets (Tensor): Gt mask of each positive proposal,
                binary map of the shape (num_pos, h, w).
            rcnn_train_cfg (dict): Training config for R-CNN part. Its
                ``mask_thr_binary`` attribute is used to binarize
                ``mask_pred``.

        Returns:
            Tensor: mask iou target (length == num positive).
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]

        # compute the area ratio of gt areas inside the proposals and
        # the whole instance, one tensor per image
        area_ratios = torch.cat([
            self._get_area_ratio(proposals, gt_inds, masks)
            for proposals, gt_inds, masks in zip(
                pos_proposals, pos_assigned_gt_inds, gt_masks)
        ])
        assert mask_targets.size(0) == area_ratios.size(0)

        mask_pred = (mask_pred > rcnn_train_cfg.mask_thr_binary).float()
        mask_pred_areas = mask_pred.sum((-1, -2))

        # mask_pred and mask_targets are binary maps
        overlap_areas = (mask_pred * mask_targets).sum((-1, -2))

        # compute the mask area of the whole instance; the small constant
        # avoids dividing by a zero ratio
        gt_full_areas = mask_targets.sum((-1, -2)) / (area_ratios + 1e-7)

        mask_iou_targets = overlap_areas / (
            mask_pred_areas + gt_full_areas - overlap_areas)
        return mask_iou_targets

    def _get_area_ratio(self, pos_proposals, pos_assigned_gt_inds, gt_masks):
        """Compute area ratio of the gt mask inside the proposal and the gt
        mask of the corresponding instance.

        Args:
            pos_proposals (Tensor): Positive proposals of one image; each
                row is cast to int32 and passed to ``gt_mask.crop``.
            pos_assigned_gt_inds (Tensor): For each proposal, the index of
                its gt instance in ``gt_masks``.
            gt_masks: Gt masks of one image; must support ``.areas``,
                indexing and ``.crop(bbox)``.

        Returns:
            Tensor: Float tensor with one ratio per proposal, on the device
            of ``pos_proposals``. Empty (shape ``(0, )``) when there are no
            proposals.
        """
        num_pos = pos_proposals.size(0)
        if num_pos == 0:
            return pos_proposals.new_zeros((0, ))

        area_ratios = []
        proposals_np = pos_proposals.cpu().numpy()
        pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
        # compute mask areas of gt instances (batch processing for speedup)
        gt_instance_mask_area = gt_masks.areas
        for i in range(num_pos):
            gt_ind = pos_assigned_gt_inds[i]
            gt_mask = gt_masks[gt_ind]

            # crop the gt mask inside the proposal
            bbox = proposals_np[i, :].astype(np.int32)
            gt_mask_in_proposal = gt_mask.crop(bbox)

            # the small constant guards against a zero-area instance
            ratio = gt_mask_in_proposal.areas[0] / (
                gt_instance_mask_area[gt_ind] + 1e-7)
            area_ratios.append(ratio)
        area_ratios = torch.from_numpy(np.stack(area_ratios)).float().to(
            pos_proposals.device)
        return area_ratios

    def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels):
        """Get the mask scores.

        mask_score = bbox_score * mask_iou

        Args:
            mask_iou_pred (Tensor): Predicted mask IoU, indexed as
                ``[detection, label]``.
            det_bboxes (Tensor): Detections; the last column is used as the
                bbox score.
            det_labels (Tensor): Label of each detection.

        Returns:
            list[ndarray]: ``num_classes`` arrays, the i-th holding the mask
            scores of the detections labelled ``i``.
        """
        inds = range(det_labels.size(0))
        # pick the IoU predicted for each detection's own label
        mask_scores = mask_iou_pred[inds, det_labels] * det_bboxes[inds, -1]
        mask_scores = mask_scores.cpu().numpy()
        det_labels = det_labels.cpu().numpy()
        return [mask_scores[det_labels == i] for i in range(self.num_classes)]