import os
import shutil
from typing import Any, Iterable, List, Mapping, Optional, Tuple

from PIL import Image
from datasets import load_dataset

# NOTE(review): the dataset is loaded (and possibly downloaded into ./data) as a
# side effect of importing this module. Kept at module level because other code
# may import `dataset` directly.
dataset = load_dataset("RGBD-SOD/test", "v1", split="train", cache_dir="data")

SAMPLES_DIR = "samples"

# Image modes Pillow's JPEG encoder can write without conversion.
_JPEG_WRITABLE_MODES = {"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"}


def _sample_dir(samples_dir: str, name: str) -> str:
    """Return the per-sample directory path, refusing names that escape *samples_dir*.

    The directory is passed to ``shutil.rmtree``, so an empty name (which would
    resolve to *samples_dir* itself) or a name containing ``..`` / an absolute
    path would delete far more than the one sample directory.

    Raises:
        ValueError: if *name* does not resolve to a strict subdirectory of
            *samples_dir*.
    """
    root = os.path.abspath(samples_dir)
    target = os.path.abspath(os.path.join(root, name))
    if target == root or os.path.commonpath([root, target]) != root:
        raise ValueError(f"unsafe sample name {name!r}")
    # Keep the same (possibly relative) form the original code produced.
    return os.path.join(samples_dir, name)


def _save_jpeg(image: Image.Image, path: str, fallback_mode: str) -> None:
    """Save *image* as JPEG at *path*.

    If the image's mode cannot be written as JPEG (e.g. RGBA, P, I, I;16), it is
    first converted to *fallback_mode*; otherwise it is saved unchanged.
    """
    if image.mode not in _JPEG_WRITABLE_MODES:
        # Without this, Pillow raises OSError ("cannot write mode ... as JPEG").
        # NOTE(review): converting 32/16-bit depth to "L" clips values > 255;
        # confirm whether depth maps need normalisation instead.
        image = image.convert(fallback_mode)
    image.save(path)


def prepare_samples(
    ds: Optional[Iterable[Mapping[str, Any]]] = None,
    samples_dir: str = SAMPLES_DIR,
) -> List[List[str]]:
    """Write every sample's images to disk and return their file paths.

    For each sample, a directory ``<samples_dir>/<name>`` is recreated from
    scratch (any existing contents are removed) and filled with
    ``rgb.jpg``, ``depth.jpg`` and ``gt.png``.

    Args:
        ds: Iterable of samples, each a mapping with PIL images under
            ``"rgb"``, ``"depth"`` and ``"gt"`` and a directory name under
            ``"name"``. Defaults to the module-level ``dataset``.
        samples_dir: Root directory for the per-sample folders. Defaults to
            ``SAMPLES_DIR``.

    Returns:
        One ``[rgb_path, depth_path, gt_path]`` list per sample, in iteration
        order.

    Raises:
        ValueError: if a sample's name is empty or would escape *samples_dir*.
    """
    if ds is None:
        ds = dataset

    samples: List[List[str]] = []
    for sample in ds:
        rgb: Image.Image = sample["rgb"]
        depth: Image.Image = sample["depth"]
        gt: Image.Image = sample["gt"]
        name: str = sample["name"]

        dir_path = _sample_dir(samples_dir, name)
        # Start from an empty directory so stale files from earlier runs vanish.
        shutil.rmtree(dir_path, ignore_errors=True)
        os.makedirs(dir_path, exist_ok=True)

        rgb_path = os.path.join(dir_path, "rgb.jpg")
        _save_jpeg(rgb, rgb_path, fallback_mode="RGB")

        depth_path = os.path.join(dir_path, "depth.jpg")
        _save_jpeg(depth, depth_path, fallback_mode="L")

        # PNG is lossless, so the ground-truth mask is saved as-is.
        gt_path = os.path.join(dir_path, "gt.png")
        gt.save(gt_path)

        samples.append([rgb_path, depth_path, gt_path])
    return samples