# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
                                         build_transformer_layer_sequence)
from mmcv.runner import force_fp32

from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
from mmdet.models.utils import preprocess_panoptic_gt
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead


@HEADS.register_module()
class MaskFormerHead(AnchorFreeHead):
    """Implements the MaskFormer head.

    See `Per-Pixel Classification is Not All You Need for Semantic
    Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.

    Args:
        in_channels (list[int]): Number of channels in each input feature
            map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for outputs.
        num_things_classes (int): Number of things classes. Defaults to 80.
        num_stuff_classes (int): Number of stuff classes. Defaults to 53.
        num_queries (int): Number of queries in the Transformer decoder.
            Defaults to 100.
        pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
            decoder. Defaults to None.
        enforce_decoder_input_project (bool, optional): Whether to add a layer
            to change the embed_dim of the transformer encoder in the pixel
            decoder to the embed_dim of the transformer decoder.
            Defaults to False.
        transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder. Defaults to None.
        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer decoder position encoding. Defaults to None.
        loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
            loss. Defaults to `CrossEntropyLoss`.
        loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
            Defaults to `FocalLoss`.
        loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
            Defaults to `DiceLoss`.
        train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of the
            MaskFormer head.
        test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of the
            MaskFormer head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 feat_channels,
                 out_channels,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_queries=100,
                 pixel_decoder=None,
                 enforce_decoder_input_project=False,
                 transformer_decoder=None,
                 positional_encoding=None,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=[1.0] * 133 + [0.1]),
                 loss_mask=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=20.0),
                 loss_dice=dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     naive_dice=True,
                     loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        # Deliberately skip AnchorFreeHead.__init__ and call its parent's
        # __init__, since this head does not use the anchor-free conv layers.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.num_queries = num_queries

        pixel_decoder.update(
            in_channels=in_channels,
            feat_channels=feat_channels,
            out_channels=out_channels)
        self.pixel_decoder = build_plugin_layer(pixel_decoder)[1]
        self.transformer_decoder = build_transformer_layer_sequence(
            transformer_decoder)
        self.decoder_embed_dims = self.transformer_decoder.embed_dims
        pixel_decoder_type = pixel_decoder.get('type')
        if pixel_decoder_type == 'PixelDecoder' and (
                self.decoder_embed_dims != in_channels[-1]
                or enforce_decoder_input_project):
            self.decoder_input_proj = Conv2d(
                in_channels[-1], self.decoder_embed_dims, kernel_size=1)
        else:
            self.decoder_input_proj = nn.Identity()
        self.decoder_pe = build_positional_encoding(positional_encoding)
        self.query_embed = nn.Embedding(self.num_queries, out_channels)

        self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
        self.mask_embed = nn.Sequential(
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
            nn.Linear(feat_channels, out_channels))

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = build_assigner(train_cfg.get('assigner', None))
            self.sampler = build_sampler(
                train_cfg.get('sampler', None), context=self)

        self.class_weight = loss_cls.get('class_weight', None)
        self.loss_cls = build_loss(loss_cls)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)
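
    # A minimal, hypothetical construction sketch (for orientation only; the
    # shipped values live in configs/maskformer/, and the pixel_decoder /
    # transformer_decoder configs below are placeholders, not the real
    # defaults):
    #
    #   head = MaskFormerHead(
    #       in_channels=[256, 512, 1024, 2048],  # e.g. ResNet-50 C2-C5
    #       feat_channels=256,
    #       out_channels=256,
    #       num_queries=100,
    #       pixel_decoder=dict(type='PixelDecoder', ...),
    #       transformer_decoder=dict(type='DetrTransformerDecoder', ...),
    #       positional_encoding=dict(
    #           type='SinePositionalEncoding', num_feats=128, normalize=True))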

    def init_weights(self):
        if isinstance(self.decoder_input_proj, Conv2d):
            caffe2_xavier_init(self.decoder_input_proj, bias=0)

        self.pixel_decoder.init_weights()

        for p in self.transformer_decoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def preprocess_gt(self, gt_labels_list, gt_masks_list, gt_semantic_segs,
                      img_metas):
        """Preprocess the ground truth for all images.

        Args:
            gt_labels_list (list[Tensor]): Each is the ground truth labels of
                each bbox, with shape (num_gts, ).
            gt_masks_list (list[BitmapMasks]): Each is the ground truth masks
                of the instances of an image, with shape (num_gts, h, w).
            gt_semantic_segs (list[Tensor] | None): Ground truth of semantic
                segmentation for each image. Pixel values in
                [0, num_things_classes - 1] mean things,
                [num_things_classes, num_classes - 1] mean stuff,
                and 255 means VOID. It is None when training instance
                segmentation.
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple: a tuple containing the following targets.

                - labels (list[Tensor]): Ground truth class indices\
                    for all images. Each with shape (n, ), where n is the\
                    number of stuff classes plus the number of instances\
                    in the image.
                - masks (list[Tensor]): Ground truth masks for each\
                    image, each with shape (n, h, w).
        """
        num_things_list = [self.num_things_classes] * len(gt_labels_list)
        num_stuff_list = [self.num_stuff_classes] * len(gt_labels_list)
        if gt_semantic_segs is None:
            gt_semantic_segs = [None] * len(gt_labels_list)

        targets = multi_apply(preprocess_panoptic_gt, gt_labels_list,
                              gt_masks_list, gt_semantic_segs,
                              num_things_list, num_stuff_list, img_metas)
        labels, masks = targets
        return labels, masks

    def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list,
                    gt_masks_list, img_metas):
        """Compute classification and mask targets for all images for a
        decoder layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for all
                images. Each with shape (n, ), where n is the number of stuff
                classes plus the number of instances in the image.
            gt_masks_list (list[Tensor]): Ground truth masks for each image,
                each with shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[list[Tensor]]: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels of all images.\
                    Each with shape (num_queries, ).
                - label_weights_list (list[Tensor]): Label weights\
                    of all images. Each with shape (num_queries, ).
                - mask_targets_list (list[Tensor]): Mask targets of\
                    all images. Each with shape (num_queries, h, w).
                - mask_weights_list (list[Tensor]): Mask weights of\
                    all images. Each with shape (num_queries, ).
                - num_total_pos (int): Number of positive samples in\
                    all images.
                - num_total_neg (int): Number of negative samples in\
                    all images.
        """
        (labels_list, label_weights_list, mask_targets_list,
         mask_weights_list, pos_inds_list,
         neg_inds_list) = multi_apply(self._get_target_single,
                                      cls_scores_list, mask_preds_list,
                                      gt_labels_list, gt_masks_list,
                                      img_metas)

        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        return (labels_list, label_weights_list, mask_targets_list,
                mask_weights_list, num_total_pos, num_total_neg)

    def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
                           img_metas):
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits from a single decoder layer for
                one image. Shape (num_queries, h, w).
            gt_labels (Tensor): Ground truth class indices for one image, with
                shape (n, ), where n is the number of stuff classes plus the
                number of instances in the image.
            gt_masks (Tensor): Ground truth masks for one image, with shape
                (n, h, w).
            img_metas (dict): Image meta information.

        Returns:
            tuple[Tensor]: a tuple containing the following for one image.

                - labels (Tensor): Labels of each image.
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image.
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image.
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image.
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
        """
        target_shape = mask_pred.shape[-2:]
        if gt_masks.shape[0] > 0:
            gt_masks_downsampled = F.interpolate(
                gt_masks.unsqueeze(1).float(), target_shape,
                mode='nearest').squeeze(1).long()
        else:
            gt_masks_downsampled = gt_masks

        # assign and sample
        assign_result = self.assigner.assign(cls_score, mask_pred, gt_labels,
                                             gt_masks_downsampled, img_metas)
        sampling_result = self.sampler.sample(assign_result, mask_pred,
                                              gt_masks)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target: unmatched queries are labeled as background
        # (self.num_classes)
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones(self.num_queries)

        # mask target
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds)

    @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds'))
    def loss(self, all_cls_scores, all_mask_preds, gt_labels_list,
             gt_masks_list, img_metas):
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers, with shape (num_decoder, batch_size, num_queries,
                cls_out_channels). Note that `cls_out_channels` includes the
                background class.
            all_mask_preds (Tensor): Mask logits for all decoder layers, with
                shape (num_decoder, batch_size, num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image, with shape (n, ), where n is the number of stuff
                classes plus the number of instances in the image.
            gt_masks_list (list[Tensor]): Ground truth masks for each image,
                with shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_dec_layers = len(all_cls_scores)
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]
        losses_cls, losses_mask, losses_dice = multi_apply(
            self.loss_single, all_cls_scores, all_mask_preds,
            all_gt_labels_list, all_gt_masks_list, img_metas_list)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_mask'] = losses_mask[-1]
        loss_dict['loss_dice'] = losses_dice[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_mask_i, loss_dice_i in zip(
                losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
            loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i
            num_dec_layer += 1
        return loss_dict
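
    # For illustration: with a 6-layer transformer decoder, the dict returned
    # by `loss` holds `loss_cls` / `loss_mask` / `loss_dice` for the last
    # layer plus auxiliary entries `d0.loss_cls`, ..., `d4.loss_dice` for the
    # five preceding layers (the count depends on the decoder config).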

    def loss_single(self, cls_scores, mask_preds, gt_labels_list,
                    gt_masks_list, img_metas):
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note that `cls_out_channels` includes the
                background class.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, h, w).
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image, each with shape (n, ), where n is the number of stuff
                classes plus the number of instances in the image.
            gt_masks_list (list[Tensor]): Ground truth masks for each image,
                each with shape (n, h, w).
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single decoder\
                layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]

        (labels_list, label_weights_list, mask_targets_list,
         mask_weights_list, num_total_pos,
         num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
                                           gt_labels_list, gt_masks_list,
                                           img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]
        target_shape = mask_targets.shape[-2:]

        if mask_targets.shape[0] == 0:
            # zero match: keep the graph connected with zero-valued losses
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        # upsample to shape of target
        # shape (num_total_gts, h, w)
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(1),
            target_shape,
            mode='bilinear',
            align_corners=False).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_preds, mask_targets, avg_factor=num_total_masks)

        # mask loss
        # FocalLoss expects inputs of shape (n, num_classes), so treat each
        # pixel as a one-class sample.
        h, w = mask_preds.shape[-2:]
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1)
        mask_preds = mask_preds.reshape(-1, 1)
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w)
        mask_targets = mask_targets.reshape(-1)
        # The target is inverted (1 - mask_targets) because the sigmoid
        # FocalLoss marks a sample positive when its label equals the channel
        # index; with a single prediction channel (index 0), foreground
        # pixels (value 1) must therefore map to label 0.
        loss_mask = self.loss_mask(
            mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)

        return loss_cls, loss_mask, loss_dice

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (list[Tensor]): Features from the upstream network, each
                is a 4D-tensor.
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple: a tuple containing two elements.

                - all_cls_scores (Tensor): Classification scores for all\
                    decoder layers, with shape (num_decoder, batch_size,\
                    num_queries, cls_out_channels). Note that\
                    `cls_out_channels` includes the background class.
                - all_mask_preds (Tensor): Mask logits for all decoder\
                    layers, with shape (num_decoder, batch_size,\
                    num_queries, h, w).
        """
        batch_size = len(img_metas)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        padding_mask = feats[-1].new_ones(
            (batch_size, input_img_h, input_img_w), dtype=torch.float32)
        for i in range(batch_size):
            img_h, img_w, _ = img_metas[i]['img_shape']
            padding_mask[i, :img_h, :img_w] = 0
        padding_mask = F.interpolate(
            padding_mask.unsqueeze(1),
            size=feats[-1].shape[-2:],
            mode='nearest').to(torch.bool).squeeze(1)
        # when the backbone is swin, memory is the output of the last stage
        # of swin; when the backbone is r50, memory is the output of the
        # transformer encoder.
        mask_features, memory = self.pixel_decoder(feats, img_metas)
        pos_embed = self.decoder_pe(padding_mask)
        memory = self.decoder_input_proj(memory)
        # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
        memory = memory.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        # shape (batch_size, h * w)
        padding_mask = padding_mask.flatten(1)
        # shape (num_queries, embed_dims)
        query_embed = self.query_embed.weight
        # shape (num_queries, batch_size, embed_dims)
        query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1)
        target = torch.zeros_like(query_embed)
        # shape (num_decoder, num_queries, batch_size, embed_dims)
        out_dec = self.transformer_decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=padding_mask)
        # shape (num_decoder, batch_size, num_queries, embed_dims)
        out_dec = out_dec.transpose(1, 2)

        # cls_scores
        all_cls_scores = self.cls_embed(out_dec)

        # mask_preds
        mask_embed = self.mask_embed(out_dec)
        all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
                                      mask_features)

        return all_cls_scores, all_mask_preds
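
    # Shape walk-through with hypothetical sizes (batch_size=2, a 512x512
    # padded input, the last feature map at stride 32, embed_dims=256,
    # num_queries=100 and a 6-layer decoder; purely illustrative, not fixed
    # by this class):
    #   memory:          (2, 256, 16, 16) -> flattened to (256, 2, 256)
    #   query_embed:     (100, 256)       -> repeated to  (100, 2, 256)
    #   out_dec:         (6, 100, 2, 256) -> transposed to (6, 2, 100, 256)
    #   all_cls_scores:  (6, 2, 100, num_classes + 1)
    #   all_mask_preds:  (6, 2, 100, h, w), with (h, w) from mask_features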

    def forward_train(self,
                      feats,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_masks,
                      gt_semantic_seg,
                      gt_bboxes_ignore=None):
        """Forward function for training mode.

        Args:
            feats (list[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            img_metas (list[dict]): List of image meta information.
            gt_bboxes (list[Tensor]): Each element is the ground truth bboxes
                of the image, with shape (num_gts, 4). Not used here.
            gt_labels (list[Tensor]): Each element is the ground truth labels
                of each box, with shape (num_gts, ).
            gt_masks (list[BitmapMasks]): Each element is the masks of the
                instances of an image, with shape (num_gts, h, w).
            gt_semantic_seg (list[Tensor] | None): Each element is the ground
                truth of semantic segmentation with the shape (N, H, W).
                [0, num_things_classes - 1] means things,
                [num_things_classes, num_classes - 1] means stuff,
                and 255 means VOID. It is None when training instance
                segmentation.
            gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
                ignored. Defaults to None.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        # MaskFormer does not support ignoring bboxes
        assert gt_bboxes_ignore is None

        # forward
        all_cls_scores, all_mask_preds = self(feats, img_metas)

        # preprocess ground truth
        gt_labels, gt_masks = self.preprocess_gt(gt_labels, gt_masks,
                                                 gt_semantic_seg, img_metas)

        # loss
        losses = self.loss(all_cls_scores, all_mask_preds, gt_labels,
                           gt_masks, img_metas)

        return losses

    def simple_test(self, feats, img_metas, **kwargs):
        """Test without augmentation.

        Args:
            feats (list[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            img_metas (list[dict]): List of image meta information.

        Returns:
            tuple: A tuple containing two tensors.

                - mask_cls_results (Tensor): Mask classification logits,\
                    shape (batch_size, num_queries, cls_out_channels).
                    Note that `cls_out_channels` includes the background
                    class.
                - mask_pred_results (Tensor): Mask logits, shape\
                    (batch_size, num_queries, h, w).
        """
        all_cls_scores, all_mask_preds = self(feats, img_metas)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks to the padded input resolution
        img_shape = img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        return mask_cls_results, mask_pred_results
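

# A minimal smoke-test sketch of the tensor interface, assuming an already
# built `head` (e.g. from a registry config) and four FPN-like feature maps;
# the shapes and the `ones`/`dict` inputs below are hypothetical
# placeholders, not values this module ships with:
#
#   feats = [torch.ones(2, c, 128 // s, 128 // s)
#            for c, s in [(256, 1), (512, 2), (1024, 4), (2048, 8)]]
#   img_metas = [dict(batch_input_shape=(512, 512), img_shape=(480, 500, 3))
#                for _ in range(2)]
#   all_cls_scores, all_mask_preds = head(feats, img_metas)
#   # all_cls_scores: (num_decoder, 2, num_queries, num_classes + 1)
#   # all_mask_preds: (num_decoder, 2, num_queries, h, w)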