|
import os |
|
import shutil |
|
from typing import List, Tuple |
|
|
|
from PIL import Image |
|
from datasets import load_dataset |
|
|
|
|
|
# Directory under which each sample's image files are written.
SAMPLES_DIR = "samples"

# Loaded once at import time; downloaded files are cached under ./data.
dataset = load_dataset(
    "RGBD-SOD/test",
    "v1",
    split="train",
    cache_dir="data",
)
|
|
|
|
|
def prepare_samples() -> List[List[str]]:
    """Write every dataset sample to disk and return the saved file paths.

    For each sample in the module-level ``dataset``, the directory
    ``SAMPLES_DIR/<name>`` is removed (if present) and recreated, and the
    images stored under the sample's ``rgb``, ``depth`` and ``gt`` keys are
    saved into it as ``rgb.jpg``, ``depth.jpg`` and ``gt.png`` respectively.

    Returns:
        One ``[rgb_path, depth_path, gt_path]`` list per sample, in dataset
        iteration order.
    """
    samples: List[List[str]] = []

    for sample in dataset:
        rgb: Image.Image = sample["rgb"]
        depth: Image.Image = sample["depth"]
        gt: Image.Image = sample["gt"]
        name: str = sample["name"]

        # Start from an empty directory so files from a previous run never
        # linger; ignore_errors also covers the directory not existing yet.
        dir_path = os.path.join(SAMPLES_DIR, name)
        shutil.rmtree(dir_path, ignore_errors=True)
        os.makedirs(dir_path, exist_ok=True)

        # NOTE(review): saving as JPEG presumably relies on the rgb/depth
        # images being in a JPEG-compatible mode - confirm against the dataset.
        rgb_path = os.path.join(dir_path, "rgb.jpg")
        rgb.save(rgb_path)

        depth_path = os.path.join(dir_path, "depth.jpg")
        depth.save(depth_path)

        gt_path = os.path.join(dir_path, "gt.png")
        gt.save(gt_path)

        # Kept as a list (not a tuple) so the returned structure is exactly
        # what existing callers already receive.
        samples.append([rgb_path, depth_path, gt_path])

    return samples
|
|