Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import mmcv | |
import numpy as np | |
def palette_val(palette):
    """Rescale integer color tuples to the float form matplotlib expects.

    Args:
        palette (list[tuple]): Color tuples whose channels are expected on a
            0-255 scale.

    Returns:
        list[tuple[float]]: The same colors, in the same order, with every
        channel divided by 255.
    """
    return [tuple(channel / 255 for channel in color) for color in palette]
def get_palette(palette, num_classes):
    """Get palette from various inputs.

    Args:
        palette (list[tuple] | str | tuple | :obj:`Color` | None): Palette
            input. Supported forms:

            - ``list``: returned as-is (the same object), one color per class.
            - ``tuple``: a single color repeated ``num_classes`` times.
            - ``None`` or ``'random'``: ``num_classes`` random colors drawn
              from a generator with the fixed seed 42, so the result is
              reproducible across calls.
            - ``'coco'``, ``'citys'``, ``'voc'``: the ``PALETTE`` of the
              matching mmdet dataset. For ``'coco'`` the COCO panoptic
              palette is used when the detection palette has fewer than
              ``num_classes`` entries.
            - any other ``str`` or an ``mmcv.Color`` member: a single color
              resolved by ``mmcv.color_val`` (which returns BGR), reversed to
              RGB and repeated ``num_classes`` times.
        num_classes (int): The number of classes.

    Returns:
        list[tuple[int]]: A list of color tuples with at least
        ``num_classes`` entries. For the random palette the channel values
        are NumPy integer scalars.

    Raises:
        TypeError: If ``palette`` is of an unsupported type.
        AssertionError: If ``num_classes`` is not an ``int``, or the resolved
            palette has fewer than ``num_classes`` colors.
    """
    assert isinstance(num_classes, int)

    # Match preset names only when ``palette`` really is a string. For
    # array-like inputs ``==`` is element-wise and cannot be used as a
    # condition, so such inputs must fall through to the TypeError below.
    name = palette if isinstance(palette, str) else None

    if isinstance(palette, list):
        dataset_palette = palette
    elif isinstance(palette, tuple):
        dataset_palette = [palette] * num_classes
    elif palette is None or name == 'random':
        # A private generator seeded with 42 yields exactly the colors that
        # seeding the global NumPy RNG would. It never mutates the global
        # random state, not even when ``randint`` raises (e.g. for a
        # negative ``num_classes``).
        rng = np.random.RandomState(42)
        colors = rng.randint(0, 256, size=(num_classes, 3))
        dataset_palette = [tuple(c) for c in colors]
    elif name == 'coco':
        from mmdet.datasets import CocoDataset, CocoPanopticDataset
        dataset_palette = CocoDataset.PALETTE
        # The detection palette is shorter than the panoptic one; fall back
        # when more classes are requested.
        if len(dataset_palette) < num_classes:
            dataset_palette = CocoPanopticDataset.PALETTE
    elif name == 'citys':
        from mmdet.datasets import CityscapesDataset
        dataset_palette = CityscapesDataset.PALETTE
    elif name == 'voc':
        from mmdet.datasets import VOCDataset
        dataset_palette = VOCDataset.PALETTE
    elif mmcv.is_str(palette) or isinstance(palette, mmcv.Color):
        # mmcv.color_val returns a BGR tuple; reverse it to RGB.
        dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes
    else:
        raise TypeError(f'Invalid type for palette: {type(palette)}')

    assert len(dataset_palette) >= num_classes, \
        'The length of palette should not be less than `num_classes`.'
    return dataset_palette