# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.runner import ModuleList

from ..builder import HEADS
from ..utils import ConvUpsample
from .base_semantic_head import BaseSemanticHead


@HEADS.register_module()
class PanopticFPNHead(BaseSemanticHead):
    """PanopticFPNHead used in Panoptic FPN.

    In this head, the number of output channels is ``num_stuff_classes
    + 1``, including all stuff classes and one thing class. The stuff
    classes will be reset to the range ``0`` to ``num_stuff_classes - 1``
    and the thing classes will be merged into the
    ``num_stuff_classes``-th channel.

    Args:
        num_things_classes (int): Number of thing classes. Default: 80.
        num_stuff_classes (int): Number of stuff classes. Default: 53.
        num_classes (int): Number of classes, including all stuff
            classes and one thing class. This argument is deprecated,
            please use ``num_things_classes`` and ``num_stuff_classes``.
            The module will automatically infer ``num_classes`` as
            ``num_stuff_classes + 1``.
        in_channels (int): Number of channels in the input feature map.
        inner_channels (int): Number of channels in inner features.
        start_level (int): The start level of the input features used in
            PanopticFPN.
        end_level (int): The end level of the used features; the
            ``end_level``-th layer will not be used.
        fg_range (tuple): Range of the foreground classes. It starts
            from ``0`` to ``num_things_classes - 1``. Deprecated, please
            use ``num_things_classes`` directly.
        bg_range (tuple): Range of the background classes. It starts
            from ``num_things_classes`` to ``num_things_classes +
            num_stuff_classes - 1``. Deprecated, please use
            ``num_stuff_classes`` and ``num_things_classes`` directly.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Use ``GN`` by default.
        init_cfg (dict or list[dict], optional): Initialization config dict.
        loss_seg (dict): Config of the loss for the semantic head.
    """

    def __init__(self,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_classes=None,
                 in_channels=256,
                 inner_channels=128,
                 start_level=0,
                 end_level=4,
                 fg_range=None,
                 bg_range=None,
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 init_cfg=None,
                 loss_seg=dict(
                     type='CrossEntropyLoss', ignore_index=-1,
                     loss_weight=1.0)):
        if num_classes is not None:
            warnings.warn(
                '`num_classes` is deprecated now, please set '
                '`num_stuff_classes` directly, the `num_classes` will be '
                'set to `num_stuff_classes + 1`')
            # num_classes = num_stuff_classes + 1 for PanopticFPN.
            assert num_classes == num_stuff_classes + 1
        super(PanopticFPNHead, self).__init__(num_stuff_classes + 1, init_cfg,
                                              loss_seg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        if fg_range is not None and bg_range is not None:
            self.fg_range = fg_range
            self.bg_range = bg_range
            self.num_things_classes = fg_range[1] - fg_range[0] + 1
            self.num_stuff_classes = bg_range[1] - bg_range[0] + 1
            warnings.warn(
                '`fg_range` and `bg_range` are deprecated now, '
                f'please use `num_things_classes`={self.num_things_classes} '
                f'and `num_stuff_classes`={self.num_stuff_classes} instead.')
        # Used feature levels are [start_level, end_level).
        self.start_level = start_level
        self.end_level = end_level
        self.num_stages = end_level - start_level
        self.inner_channels = inner_channels

        # Each used level gets its own subnet; level i is upsampled i times
        # so that all levels reach the resolution of the first used level.
        self.conv_upsample_layers = ModuleList()
        for i in range(start_level, end_level):
            self.conv_upsample_layers.append(
                ConvUpsample(
                    in_channels,
                    inner_channels,
                    num_layers=i if i > 0 else 1,
                    num_upsample=i if i > 0 else 0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                ))
        self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)

    def _set_things_to_void(self, gt_semantic_seg):
        """Merge thing classes to one class.

        In PanopticFPN, the background labels will be reset to the range
        ``0`` to ``self.num_stuff_classes - 1`` and the foreground labels
        will be merged into the ``self.num_stuff_classes``-th channel.
        """
        gt_semantic_seg = gt_semantic_seg.int()
        fg_mask = gt_semantic_seg < self.num_things_classes
        bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
            gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)

        new_gt_seg = torch.clone(gt_semantic_seg)
        new_gt_seg = torch.where(bg_mask,
                                 gt_semantic_seg - self.num_things_classes,
                                 new_gt_seg)
        new_gt_seg = torch.where(fg_mask,
                                 fg_mask.int() * self.num_stuff_classes,
                                 new_gt_seg)
        return new_gt_seg
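    # Illustration of the remapping above (a minimal sketch assuming an
    # mmdetection v2.x environment and the default COCO panoptic split of
    # 80 thing + 53 stuff classes; labels outside both ranges, e.g. an
    # ignore value such as 255, fall in neither mask and pass through
    # unchanged):
    #
    #     >>> head = PanopticFPNHead(num_things_classes=80,
    #     ...                        num_stuff_classes=53)
    #     >>> gt = torch.tensor([[0, 79, 80, 132, 255]])
    #     >>> head._set_things_to_void(gt)
    #     tensor([[ 53,  53,   0,  52, 255]], dtype=torch.int32)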
    def loss(self, seg_preds, gt_semantic_seg):
        """The loss of PanopticFPN head.

        Thing classes will be merged to one class in PanopticFPN.
        """
        gt_semantic_seg = self._set_things_to_void(gt_semantic_seg)
        return super().loss(seg_preds, gt_semantic_seg)

    def init_weights(self):
        super().init_weights()
        nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
        self.conv_logits.bias.data.zero_()

    def forward(self, x):
        # The number of subnets must not be more than the number of
        # feature levels.
        assert self.num_stages <= len(x)

        feats = []
        for i, layer in enumerate(self.conv_upsample_layers):
            f = layer(x[self.start_level + i])
            feats.append(f)

        # All per-level features now share the same resolution; fuse them
        # by element-wise summation before the final 1x1 classifier.
        feats = torch.sum(torch.stack(feats, dim=0), dim=0)
        seg_preds = self.conv_logits(feats)
        out = dict(seg_preds=seg_preds, feats=feats)
        return out
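# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library file):
# builds the head via the mmdet registry with its default settings and runs
# a forward pass on dummy FPN features. It assumes an mmdetection v2.x
# install and a standard FPN whose level-i output has stride ``4 * 2**i``,
# so for a 256x256 image level 0 is 64x64 and the logits come out at
# (N, num_stuff_classes + 1, 64, 64).
#
#     import torch
#     from mmdet.models import build_head
#
#     head = build_head(
#         dict(
#             type='PanopticFPNHead',
#             num_things_classes=80,
#             num_stuff_classes=53,
#             in_channels=256,
#             inner_channels=128,
#             start_level=0,
#             end_level=4))
#     head.init_weights()
#
#     # Dummy FPN features for a 256x256 input: strides 4, 8, 16, 32.
#     feats = [
#         torch.rand(2, 256, 256 // 2**(i + 2), 256 // 2**(i + 2))
#         for i in range(4)
#     ]
#     out = head(feats)
#     assert out['seg_preds'].shape == (2, 54, 64, 64)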