# Copyright (c) OpenMMLab. All rights reserved. import copy import mmcv import numpy as np from mmdet.core import INSTANCE_OFFSET, bbox2result from mmdet.core.visualization import imshow_det_bboxes from ..builder import DETECTORS, build_backbone, build_head, build_neck from .single_stage import SingleStageDetector @DETECTORS.register_module() class MaskFormer(SingleStageDetector): r"""Implementation of `Per-Pixel Classification is NOT All You Need for Semantic Segmentation `_.""" def __init__(self, backbone, neck=None, panoptic_head=None, panoptic_fusion_head=None, train_cfg=None, test_cfg=None, init_cfg=None): super(SingleStageDetector, self).__init__(init_cfg=init_cfg) self.backbone = build_backbone(backbone) if neck is not None: self.neck = build_neck(neck) panoptic_head_ = copy.deepcopy(panoptic_head) panoptic_head_.update(train_cfg=train_cfg) panoptic_head_.update(test_cfg=test_cfg) self.panoptic_head = build_head(panoptic_head_) panoptic_fusion_head_ = copy.deepcopy(panoptic_fusion_head) panoptic_fusion_head_.update(test_cfg=test_cfg) self.panoptic_fusion_head = build_head(panoptic_fusion_head_) self.num_things_classes = self.panoptic_head.num_things_classes self.num_stuff_classes = self.panoptic_head.num_stuff_classes self.num_classes = self.panoptic_head.num_classes self.train_cfg = train_cfg self.test_cfg = test_cfg # BaseDetector.show_result default for instance segmentation if self.num_stuff_classes > 0: self.show_result = self._show_pan_result def forward_dummy(self, img, img_metas): """Used for computing network flops. See `mmdetection/tools/analysis_tools/get_flops.py` Args: img (Tensor): of shape (N, C, H, W) encoding input images. Typically these should be mean centered and std scaled. img_metas (list[Dict]): list of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmdet/datasets/pipelines/formatting.py:Collect`. """ super(SingleStageDetector, self).forward_train(img, img_metas) x = self.extract_feat(img) outs = self.panoptic_head(x, img_metas) return outs def forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_masks, gt_semantic_seg=None, gt_bboxes_ignore=None, **kargs): """ Args: img (Tensor): of shape (N, C, H, W) encoding input images. Typically these should be mean centered and std scaled. img_metas (list[Dict]): list of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmdet/datasets/pipelines/formatting.py:Collect`. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. gt_labels (list[Tensor]): class indices corresponding to each box. gt_masks (list[BitmapMasks]): true segmentation masks for each box used if the architecture supports a segmentation task. gt_semantic_seg (list[tensor]): semantic segmentation mask for images for panoptic segmentation. Defaults to None for instance segmentation. gt_bboxes_ignore (list[Tensor]): specify which bounding boxes can be ignored when computing the loss. Defaults to None. Returns: dict[str, Tensor]: a dictionary of loss components """ # add batch_input_shape in img_metas super(SingleStageDetector, self).forward_train(img, img_metas) x = self.extract_feat(img) losses = self.panoptic_head.forward_train(x, img_metas, gt_bboxes, gt_labels, gt_masks, gt_semantic_seg, gt_bboxes_ignore) return losses def simple_test(self, imgs, img_metas, **kwargs): """Test without augmentation. Args: imgs (Tensor): A batch of images. img_metas (list[dict]): List of image information. Returns: list[dict[str, np.array | tuple[list]] | tuple[list]]: Semantic segmentation results and panoptic segmentation \ results of each image for panoptic segmentation, or formatted \ bbox and mask results of each image for instance segmentation. .. code-block:: none [ # panoptic segmentation { 'pan_results': np.array, # shape = [h, w] 'ins_results': tuple[list], # semantic segmentation results are not supported yet 'sem_results': np.array }, ... ] or .. code-block:: none [ # instance segmentation ( bboxes, # list[np.array] masks # list[list[np.array]] ), ... ] """ feats = self.extract_feat(imgs) mask_cls_results, mask_pred_results = self.panoptic_head.simple_test( feats, img_metas, **kwargs) results = self.panoptic_fusion_head.simple_test( mask_cls_results, mask_pred_results, img_metas, **kwargs) for i in range(len(results)): if 'pan_results' in results[i]: results[i]['pan_results'] = results[i]['pan_results'].detach( ).cpu().numpy() if 'ins_results' in results[i]: labels_per_image, bboxes, mask_pred_binary = results[i][ 'ins_results'] bbox_results = bbox2result(bboxes, labels_per_image, self.num_things_classes) mask_results = [[] for _ in range(self.num_things_classes)] for j, label in enumerate(labels_per_image): mask = mask_pred_binary[j].detach().cpu().numpy() mask_results[label].append(mask) results[i]['ins_results'] = bbox_results, mask_results assert 'sem_results' not in results[i], 'segmantic segmentation '\ 'results are not supported yet.' if self.num_stuff_classes == 0: results = [res['ins_results'] for res in results] return results def aug_test(self, imgs, img_metas, **kwargs): raise NotImplementedError def onnx_export(self, img, img_metas): raise NotImplementedError def _show_pan_result(self, img, result, score_thr=0.3, bbox_color=(72, 101, 241), text_color=(72, 101, 241), mask_color=None, thickness=2, font_size=13, win_name='', show=False, wait_time=0, out_file=None): """Draw `panoptic result` over `img`. Args: img (str or Tensor): The image to be displayed. result (dict): The results. score_thr (float, optional): Minimum score of bboxes to be shown. Default: 0.3. bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines. The tuple of color should be in BGR order. Default: 'green'. text_color (str or tuple(int) or :obj:`Color`):Color of texts. The tuple of color should be in BGR order. Default: 'green'. mask_color (None or str or tuple(int) or :obj:`Color`): Color of masks. The tuple of color should be in BGR order. Default: None. thickness (int): Thickness of lines. Default: 2. font_size (int): Font size of texts. Default: 13. win_name (str): The window name. Default: ''. wait_time (float): Value of waitKey param. Default: 0. show (bool): Whether to show the image. Default: False. out_file (str or None): The filename to write the image. Default: None. Returns: img (Tensor): Only if not `show` or `out_file`. """ img = mmcv.imread(img) img = img.copy() pan_results = result['pan_results'] # keep objects ahead ids = np.unique(pan_results)[::-1] legal_indices = ids != self.num_classes # for VOID label ids = ids[legal_indices] labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) segms = (pan_results[None] == ids[:, None, None]) # if out_file specified, do not show image in window if out_file is not None: show = False # draw bounding boxes img = imshow_det_bboxes( img, segms=segms, labels=labels, class_names=self.CLASSES, bbox_color=bbox_color, text_color=text_color, mask_color=mask_color, thickness=thickness, font_size=font_size, win_name=win_name, show=show, wait_time=wait_time, out_file=out_file) if not (show or out_file): return img