RockeyCoss
add code files”
51f6859
raw
history blame
10.4 kB
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import mmcv
import numpy as np
from mmdet.core import INSTANCE_OFFSET, bbox2result
from mmdet.core.visualization import imshow_det_bboxes
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class MaskFormer(SingleStageDetector):
    r"""Implementation of `Per-Pixel Classification is
    NOT All You Need for Semantic Segmentation
    <https://arxiv.org/pdf/2107.06278>`_.

    Args:
        backbone (dict): Config passed to ``build_backbone``.
        neck (dict, optional): Config passed to ``build_neck``. When None,
            no ``neck`` attribute is set. Defaults to None.
        panoptic_head (dict): Config of the panoptic head. ``train_cfg`` and
            ``test_cfg`` are injected into a deep copy of it before it is
            passed to ``build_head``. Although the default is None, a config
            is required: None fails when the copy is updated.
        panoptic_fusion_head (dict): Config of the panoptic fusion head.
            ``test_cfg`` is injected into a deep copy of it before it is
            passed to ``build_head``. Required for the same reason as
            ``panoptic_head``.
        train_cfg (dict, optional): Training config, stored on the detector
            and forwarded to the panoptic head. Defaults to None.
        test_cfg (dict, optional): Testing config, stored on the detector
            and forwarded to both heads. Defaults to None.
        init_cfg (dict, optional): Initialization config forwarded to the
            parent constructor. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 panoptic_head=None,
                 panoptic_fusion_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        # Start the MRO lookup *after* SingleStageDetector so that its own
        # __init__ is skipped; the heads are built explicitly below.
        super(SingleStageDetector, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)

        # Work on deep copies so the caller's config dicts are not mutated.
        panoptic_head_ = copy.deepcopy(panoptic_head)
        panoptic_head_.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.panoptic_head = build_head(panoptic_head_)

        panoptic_fusion_head_ = copy.deepcopy(panoptic_fusion_head)
        panoptic_fusion_head_.update(test_cfg=test_cfg)
        self.panoptic_fusion_head = build_head(panoptic_fusion_head_)

        # Class counts are taken from the built panoptic head.
        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # BaseDetector.show_result default for instance segmentation
        if self.num_stuff_classes > 0:
            # With stuff classes present, results are panoptic, so the
            # panoptic drawing routine replaces ``show_result``.
            self.show_result = self._show_pan_result

    def forward_dummy(self, img, img_metas):
        """Used for computing network flops. See
        `mmdetection/tools/analysis_tools/get_flops.py`

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

        Returns:
            The output of calling ``self.panoptic_head`` on the extracted
            features and ``img_metas``.
        """
        # Same parent call as in ``forward_train`` (there it is noted to
        # add batch_input_shape to img_metas).
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        outs = self.panoptic_head(x, img_metas)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_masks,
                      gt_semantic_seg=None,
                      gt_bboxes_ignore=None,
                      **kargs):
        """Compute the training losses of the panoptic head.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_masks (list[BitmapMasks]): true segmentation masks for each box
                used if the architecture supports a segmentation task.
            gt_semantic_seg (list[tensor]): semantic segmentation mask for
                images for panoptic segmentation.
                Defaults to None for instance segmentation.
            gt_bboxes_ignore (list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
                Defaults to None.
            **kargs: Extra keyword arguments; accepted but not used here.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # add batch_input_shape in img_metas
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.panoptic_head.forward_train(x, img_metas, gt_bboxes,
                                                  gt_labels, gt_masks,
                                                  gt_semantic_seg,
                                                  gt_bboxes_ignore)
        return losses

    def simple_test(self, imgs, img_metas, **kwargs):
        """Test without augmentation.

        Args:
            imgs (Tensor): A batch of images.
            img_metas (list[dict]): List of image information.
            **kwargs: Forwarded to ``simple_test`` of both heads.

        Returns:
            list[dict[str, np.array | tuple[list]] | tuple[list]]:
                Semantic segmentation results and panoptic segmentation \
                results of each image for panoptic segmentation, or formatted \
                bbox and mask results of each image for instance segmentation.

            .. code-block:: none

                [
                    # panoptic segmentation
                    {
                        'pan_results': np.array, # shape = [h, w]
                        'ins_results': tuple[list],
                        # semantic segmentation results are not supported yet
                        'sem_results': np.array
                    },
                    ...
                ]

            or

            .. code-block:: none

                [
                    # instance segmentation
                    (
                        bboxes, # list[np.array]
                        masks # list[list[np.array]]
                    ),
                    ...
                ]
        """
        feats = self.extract_feat(imgs)
        mask_cls_results, mask_pred_results = self.panoptic_head.simple_test(
            feats, img_metas, **kwargs)
        results = self.panoptic_fusion_head.simple_test(
            mask_cls_results, mask_pred_results, img_metas, **kwargs)

        # Convert each per-image result dict in place to numpy-based output.
        for result in results:
            if 'pan_results' in result:
                result['pan_results'] = result['pan_results'].detach().cpu(
                ).numpy()

            if 'ins_results' in result:
                labels_per_image, bboxes, mask_pred_binary = result[
                    'ins_results']
                bbox_results = bbox2result(bboxes, labels_per_image,
                                           self.num_things_classes)
                # One list of masks per thing class, filled by label.
                mask_results = [[] for _ in range(self.num_things_classes)]
                for j, label in enumerate(labels_per_image):
                    mask = mask_pred_binary[j].detach().cpu().numpy()
                    mask_results[label].append(mask)
                result['ins_results'] = bbox_results, mask_results

            assert 'sem_results' not in result, 'segmantic segmentation '\
                'results are not supported yet.'

        # Without stuff classes only the instance results are returned.
        if self.num_stuff_classes == 0:
            results = [res['ins_results'] for res in results]

        return results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentation; not implemented for this detector."""
        raise NotImplementedError

    def onnx_export(self, img, img_metas):
        """ONNX export; not implemented for this detector."""
        raise NotImplementedError

    def _show_pan_result(self,
                         img,
                         result,
                         score_thr=0.3,
                         bbox_color=(72, 101, 241),
                         text_color=(72, 101, 241),
                         mask_color=None,
                         thickness=2,
                         font_size=13,
                         win_name='',
                         show=False,
                         wait_time=0,
                         out_file=None):
        """Draw `panoptic result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (dict): The results; only ``result['pan_results']`` is
                read.
            score_thr (float, optional): Minimum score of bboxes to be shown.
                Default: 0.3. Currently unused by this method.
            bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
                The tuple of color should be in BGR order.
                Default: (72, 101, 241).
            text_color (str or tuple(int) or :obj:`Color`):Color of texts.
                The tuple of color should be in BGR order.
                Default: (72, 101, 241).
            mask_color (None or str or tuple(int) or :obj:`Color`):
                Color of masks. The tuple of color should be in BGR order.
                Default: None.
            thickness (int): Thickness of lines. Default: 2.
            font_size (int): Font size of texts. Default: 13.
            win_name (str): The window name. Default: ''.
            wait_time (float): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            The image returned by ``imshow_det_bboxes``, only when neither
            `show` nor `out_file` is set; otherwise None.
        """
        img = mmcv.imread(img)
        img = img.copy()
        pan_results = result['pan_results']
        # keep objects ahead
        ids = np.unique(pan_results)[::-1]
        legal_indices = ids != self.num_classes  # for VOID label
        ids = ids[legal_indices]
        # The class label is the segment id modulo INSTANCE_OFFSET.
        labels = np.array([seg_id % INSTANCE_OFFSET for seg_id in ids],
                          dtype=np.int64)
        # One boolean mask per remaining segment id.
        segms = (pan_results[None] == ids[:, None, None])

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw the segment masks and labels (no bboxes are passed)
        img = imshow_det_bboxes(
            img,
            segms=segms,
            labels=labels,
            class_names=self.CLASSES,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            return img