FrozenSeg / datasets /prepare_pascal_ctx_sem_seg.py
xichen98cn's picture
Upload 135 files
3dac99f verified
raw
history blame
2.69 kB
import os
from pathlib import Path
import shutil
import numpy as np
import tqdm
from PIL import Image
import multiprocessing as mp
import functools
from detail import Detail
# fmt: off
# Raw label ids that are kept (60 entries, including 0), stored in ascending
# order so that generate_labels can look them up with np.digitize.
_mapping = np.array(sorted([
    0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
    158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
    440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
    85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115,
]))
# fmt: on

# Contiguous uint8 index (0, 1, 2, ...) for each position of _mapping.
_key = np.arange(len(_mapping), dtype="uint8")
def generate_labels(img_info, detail_api, out_dir):
    """Write the semantic-segmentation label PNG for a single image.

    Args:
        img_info: Image record; its ``"file_name"`` entry names the source image.
        detail_api: Object whose ``getMask(img_info)`` returns the raw label mask.
        out_dir: ``pathlib.Path`` of the directory that receives the PNG.

    Raises:
        ValueError: If the mask contains a raw id that is not listed in ``_mapping``.
    """

    def _class_to_index(mask, _mapping, _key):
        # Validate explicitly instead of with ``assert`` so the check still runs
        # under ``python -O``; an unknown id would otherwise be binned silently
        # into a neighbouring class by np.digitize.
        present = np.unique(mask)
        unknown = present[~np.isin(present, _mapping)]
        if unknown.size:
            raise ValueError(f"mask contains unmapped label ids: {unknown.tolist()}")
        # _mapping is sorted, so with right=True every known id is mapped to its
        # own position in _mapping.
        index = np.digitize(mask.ravel(), _mapping, right=True)
        return _key[index].reshape(mask.shape)

    sem_seg = _class_to_index(detail_api.getMask(img_info), _mapping=_mapping, _key=_key)
    sem_seg = sem_seg - 1  # 0 (ignore) becomes 255. others are shifted by 1

    # Swap only the file extension; a plain str.replace("jpg", "png") would also
    # rewrite "jpg" appearing elsewhere in the name and would miss ".JPG".
    png_name = Path(img_info["file_name"]).with_suffix(".png")
    Image.fromarray(sem_seg).save(out_dir / png_name)
def copy_images(img_info, img_dir, out_dir):
    """Copy one image from ``img_dir`` to ``out_dir`` under the same file name.

    Args:
        img_info: Image record; its ``"file_name"`` entry names the file to copy.
        img_dir: Directory path holding the source image.
        out_dir: Directory path that receives the copy.
    """
    name = img_info["file_name"]
    source = img_dir / name
    target = out_dir / name
    # copy2 copies the file contents and also tries to preserve file metadata.
    shutil.copy2(source, target)
if __name__ == "__main__":
    # Root of all datasets; overridable through the DETECTRON2_DATASETS
    # environment variable, defaulting to ./datasets.
    datasets_root = Path(os.getenv("DETECTRON2_DATASETS", "datasets"))
    dataset_dir = datasets_root / "pascal_ctx_d2"
    voc_dir = datasets_root / "VOCdevkit/VOC2010"
    img_dir = voc_dir / "JPEGImages"

    # (output split name, split name handed to Detail)
    for split, detail_split in (("training", "train"), ("validation", "val")):
        detail_api = Detail(voc_dir / "trainval_merged.json", img_dir, detail_split)
        img_infos = detail_api.getImgs()

        output_img_dir = dataset_dir / "images" / split
        output_ann_dir = dataset_dir / "annotations_ctx59" / split
        output_img_dir.mkdir(parents=True, exist_ok=True)
        output_ann_dir.mkdir(parents=True, exist_ok=True)

        # Use the pool as a context manager so its worker processes are shut
        # down at the end of each split instead of accumulating. Both map()
        # calls block until all work is done, so nothing is cut short on exit.
        with mp.Pool(processes=max(mp.cpu_count() // 2, 4)) as pool:
            pool.map(
                functools.partial(copy_images, img_dir=img_dir, out_dir=output_img_dir),
                tqdm.tqdm(img_infos, desc=f"Writing {split} images to {output_img_dir} ..."),
                chunksize=100,
            )
            pool.map(
                functools.partial(generate_labels, detail_api=detail_api, out_dir=output_ann_dir),
                tqdm.tqdm(img_infos, desc=f"Writing {split} labels to {output_ann_dir} ..."),
                chunksize=100,
            )