Spaces:
Runtime error
Runtime error
File size: 2,536 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def preprocess_panoptic_gt(gt_labels, gt_masks, gt_semantic_seg, num_things,
                           num_stuff, img_metas):
    """Preprocess the ground truth for an image.

    Args:
        gt_labels (Tensor): Ground truth labels of each bbox,
            with shape (num_gts, ).
        gt_masks (BitmapMasks): Ground truth masks of each instance
            of an image, shape (num_gts, h, w).
        gt_semantic_seg (Tensor | None): Ground truth of semantic
            segmentation with the shape (1, h, w).
            [0, num_thing_class - 1] means things,
            [num_thing_class, num_class-1] means stuff,
            255 means VOID. It's None when training instance segmentation.
        num_things (int): Number of thing classes. Semantic labels below
            this value are not turned into stuff masks.
        num_stuff (int): Number of stuff classes. Semantic labels equal to
            or above ``num_things + num_stuff`` are ignored.
        img_metas (dict): Image meta information. Only the first two
            entries of ``img_metas['pad_shape']`` are read here, as the
            target size for padding ``gt_masks``.

    Returns:
        tuple: a tuple containing the following targets.

            - labels (Tensor): Ground truth class indices for an
              image, with shape (n, ), n is the sum of number
              of stuff type and number of instance in an image.
            - masks (Tensor): Ground truth mask for an image, with
              shape (n, h, w). Contains stuff and things when training
              panoptic segmentation, and things only when training
              instance segmentation.
    """
    num_classes = num_things + num_stuff

    # Pad the instance masks to ``pad_shape`` and turn them into a bool
    # tensor living on the same device as the labels.
    things_masks = gt_masks.pad(
        img_metas['pad_shape'][:2], pad_val=0).to_tensor(
            dtype=torch.bool, device=gt_labels.device)

    # Instance segmentation: there is no semantic map, so things only.
    if gt_semantic_seg is None:
        return gt_labels, things_masks.long()

    gt_semantic_seg = gt_semantic_seg.squeeze(0)
    semantic_labels = torch.unique(
        gt_semantic_seg,
        sorted=False,
        return_inverse=False,
        return_counts=False)

    # Keep only labels inside the stuff range [num_things, num_classes).
    # Doing this as one tensor op avoids a tensor -> bool conversion per
    # label, which is what a Python-level ``if`` on each element would do.
    in_stuff_range = (semantic_labels >= num_things) & (
        semantic_labels < num_classes)
    stuff_labels = semantic_labels[in_stuff_range]

    # No stuff class present: hand back the thing targets unchanged.
    if stuff_labels.numel() == 0:
        return gt_labels, things_masks.long()

    # Compare the semantic map against every stuff label at once; the
    # labels are reshaped to (k, 1, ..., 1) so they broadcast over the map
    # and yield one binary mask per stuff label, in ``stuff_labels`` order.
    broadcast_shape = (-1, ) + (1, ) * gt_semantic_seg.dim()
    stuff_masks = gt_semantic_seg.unsqueeze(0) == stuff_labels.reshape(
        broadcast_shape)

    labels = torch.cat([gt_labels, stuff_labels], dim=0)
    masks = torch.cat([things_masks, stuff_masks], dim=0)
    return labels, masks.long()
|