Spaces:
Running
on
A10G
Running
on
A10G
import os.path as osp | |
import tempfile | |
import annotator.uniformer.mmcv as mmcv | |
import numpy as np | |
from annotator.uniformer.mmcv.utils import print_log | |
from PIL import Image | |
from .builder import DATASETS | |
from .custom import CustomDataset | |
class CityscapesDataset(CustomDataset):
    """Cityscapes dataset.

    The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix``
    is fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
    """

    # NOTE(review): ``DATASETS`` is imported by this module but is not
    # referenced anywhere in this class. If the dataset is expected to be
    # available through the registry, confirm where it gets registered.

    # Class names, listed in the same order as the colours in ``PALETTE``.
    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
               'traffic light', 'traffic sign', 'vegetation', 'terrain',
               'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
               'motorcycle', 'bicycle')

    # One RGB triple per entry of ``CLASSES``.
    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
               [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
               [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
               [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
               [0, 80, 100], [0, 0, 230], [119, 11, 32]]

    def __init__(self, **kwargs):
        """Build the dataset with the Cityscapes file suffixes.

        Args:
            **kwargs: Forwarded unchanged to ``CustomDataset.__init__``.
        """
        super().__init__(
            img_suffix='_leftImg8bit.png',
            seg_map_suffix='_gtFine_labelTrainIds.png',
            **kwargs)

    @staticmethod
    def _convert_to_label_id(result):
        """Convert trainId to id for cityscapes.

        Declared as a static method because it takes no ``self`` but is
        invoked through the instance (``self._convert_to_label_id(...)``).

        Args:
            result (ndarray | str): Array of trainIds, or a path that is
                loaded with ``np.load``.

        Returns:
            ndarray: A copy of ``result`` in which every value equal to a
                known trainId is replaced by the corresponding label id.
        """
        if isinstance(result, str):
            result = np.load(result)
        import cityscapesscripts.helpers.labels as CSLabels
        result_copy = result.copy()
        # Masks are always computed on the untouched ``result`` so that an
        # already-remapped pixel is never remapped a second time.
        for trainId, label in CSLabels.trainId2label.items():
            result_copy[result == trainId] = label.id

        return result_copy

    def results2img(self, results, imgfile_prefix, to_label_id):
        """Write the segmentation results to images.

        Args:
            results (list[list | tuple | ndarray]): Testing results of the
                dataset.
            imgfile_prefix (str): Directory the png files are written to
                (created if missing). For the prefix "somepath", an image
                named "xxx" is saved as "somepath/xxx.png".
            to_label_id (bool): whether convert output to label_id for
                submission

        Returns:
            list[str]: Paths of the written png files, one per dataset item.
        """
        mmcv.mkdir_or_exist(imgfile_prefix)

        # The palette is identical for every image, so build it once
        # instead of once per result.
        import cityscapesscripts.helpers.labels as CSLabels
        palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
        for label_id, label in CSLabels.id2label.items():
            palette[label_id] = label.color

        result_files = []
        prog_bar = mmcv.ProgressBar(len(self))
        for idx in range(len(self)):
            result = results[idx]
            if to_label_id:
                result = self._convert_to_label_id(result)
            filename = self.img_infos[idx]['filename']
            basename = osp.splitext(osp.basename(filename))[0]
            png_filename = osp.join(imgfile_prefix, f'{basename}.png')

            output = Image.fromarray(result.astype(np.uint8)).convert('P')
            output.putpalette(palette)
            output.save(png_filename)

            result_files.append(png_filename)
            prog_bar.update()

        return result_files

    def format_results(self, results, imgfile_prefix=None, to_label_id=True):
        """Format the results into dir (standard format for Cityscapes
        evaluation).

        Args:
            results (list): Testing results of the dataset.
            imgfile_prefix (str | None): The prefix of images files. It
                includes the file path and the prefix of filename, e.g.,
                "a/b/prefix". If not specified, a temporary directory will
                be created. Default: None.
            to_label_id (bool): whether convert output to label_id for
                submission. Default: True

        Returns:
            tuple: (result_files, tmp_dir), result_files is a list containing
                the image paths, tmp_dir is the ``TemporaryDirectory``
                created for saving png files when ``imgfile_prefix`` is not
                specified, otherwise None. The caller is responsible for
                calling ``tmp_dir.cleanup()``.
        """
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')

        if imgfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            imgfile_prefix = tmp_dir.name
        else:
            tmp_dir = None

        try:
            result_files = self.results2img(results, imgfile_prefix,
                                            to_label_id)
        except Exception:
            # The caller never receives ``tmp_dir`` on failure, so release
            # it here rather than leaving a half-written directory behind.
            if tmp_dir is not None:
                tmp_dir.cleanup()
            raise

        return result_files, tmp_dir

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 imgfile_prefix=None,
                 efficient_test=False):
        """Evaluation in Cityscapes/default protocol.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. The entry
                'cityscapes' is handled by this class; every other entry is
                passed on to the parent class.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            imgfile_prefix (str | None): The prefix of output image file,
                for cityscapes evaluation only. It includes the file path and
                the prefix of filename, e.g., "a/b/prefix".
                If results are evaluated with cityscapes protocol, it would be
                the prefix of output png files. The output files would be
                png images under folder "a/b/prefix/xxx.png", where "xxx" is
                the image name of cityscapes. If not specified, a temporary
                directory will be created for evaluation.
                Default: None.
            efficient_test (bool): Forwarded unchanged to the parent class
                ``evaluate``. Default: False.

        Returns:
            dict[str, float]: Cityscapes/default metrics.
        """
        eval_results = dict()
        # Work on a copy so the caller's list is not modified by ``remove``.
        metrics = metric.copy() if isinstance(metric, list) else [metric]
        if 'cityscapes' in metrics:
            eval_results.update(
                self._evaluate_cityscapes(results, logger, imgfile_prefix))
            metrics.remove('cityscapes')
        if len(metrics) > 0:
            eval_results.update(
                super().evaluate(results, metrics, logger, efficient_test))

        return eval_results

    def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
        """Evaluation in Cityscapes protocol.

        Args:
            results (list): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            imgfile_prefix (str | None): The prefix of output image file

        Returns:
            dict[str: float]: Cityscapes evaluation results.

        Raises:
            ImportError: If ``cityscapesscripts`` is not installed.
        """
        try:
            import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
        except ImportError as err:
            raise ImportError('Please run "pip install cityscapesscripts" to '
                              'install cityscapesscripts first.') from err
        msg = 'Evaluating in Cityscapes style'
        if logger is None:
            msg = '\n' + msg
        print_log(msg, logger=logger)

        result_files, tmp_dir = self.format_results(results, imgfile_prefix)

        # Release the temporary directory even when the evaluation raises.
        try:
            if tmp_dir is None:
                result_dir = imgfile_prefix
            else:
                result_dir = tmp_dir.name

            eval_results = dict()
            print_log(
                f'Evaluating results under {result_dir} ...', logger=logger)

            # These settings are written onto the module-level ``args``
            # object of the evaluation script.
            CSEval.args.evalInstLevelScore = True
            CSEval.args.predictionPath = osp.abspath(result_dir)
            CSEval.args.evalPixelAccuracy = True
            CSEval.args.JSONOutput = False

            seg_map_list = []
            pred_list = []

            # when evaluating with official cityscapesscripts,
            # **_gtFine_labelIds.png is used
            for seg_map in mmcv.scandir(
                    self.ann_dir, 'gtFine_labelIds.png', recursive=True):
                seg_map_list.append(osp.join(self.ann_dir, seg_map))
                pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))

            eval_results.update(
                CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
        finally:
            if tmp_dir is not None:
                tmp_dir.cleanup()

        return eval_results