|
""" |
|
@Date: 2021/09/22 |
|
@description: |
|
""" |
|
import os |
|
import json |
|
import math |
|
import numpy as np |
|
|
|
from dataset.communal.read import read_image, read_label, read_zind |
|
from dataset.communal.base_dataset import BaseDataset |
|
from utils.logger import get_logger |
|
from preprocessing.filter import filter_center, filter_boundary, filter_self_intersection |
|
from utils.boundary import calc_rotation |
|
|
|
|
|
class ZindDataset(BaseDataset):
    """Dataset of ZInD panoramas and their room-corner annotations.

    The constructor reads the panorama records through ``read_zind``, drops
    records that fail the validity filters below, and keeps the survivors in
    ``self.data``.  ``__getitem__`` loads the image of one record and passes
    the record plus image to ``BaseDataset.process_data``.

    Attributes such as ``self.mode``, ``self.shape``, ``self.max_wall_num``
    and ``self.patch_num`` are read here but not assigned; they are assumed
    to be set by ``BaseDataset.__init__`` -- TODO confirm.
    """

    def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
                 split_list=None, patch_num=256, keys=None, for_test_index=None,
                 is_simple=True, is_ceiling_flat=False, vp_align=False):
        """Build the list of usable panoramas.

        :param root_dir: directory holding ``zind_partition.json`` and
            ``room_shape_simplicity_labels.json``; also passed to
            ``read_zind`` as ``data_dir``.
        :param mode: passed to ``BaseDataset`` and to ``read_zind``.
        :param shape: passed to ``BaseDataset``; ``self.shape`` is later used
            as the target shape for ``read_image``.
        :param max_wall_num: corner-count filter.  ``0`` disables it, a value
            ``>= 10`` keeps rooms with at least that many corners, any other
            value keeps rooms with exactly that many corners.
        :param aug: passed to ``BaseDataset``.
        :param camera_height: passed to ``BaseDataset``.
        :param logger: logger for warnings and the summary line; defaults to
            ``get_logger()``.
        :param split_list: if truthy, only panoramas whose ``'id'`` is
            contained in it are kept.
        :param patch_num: passed to ``BaseDataset``.
        :param keys: passed to ``BaseDataset``.
        :param for_test_index: if not ``None``, only the first
            ``for_test_index`` records returned by ``read_zind`` are used.
        :param is_simple: passed to ``read_zind``.
        :param is_ceiling_flat: passed to ``read_zind``.
        :param vp_align: if true, ``__getitem__`` rolls the image and shifts
            the corners by the rotation returned by ``calc_rotation``.
        """
        super().__init__(mode, shape, max_wall_num, aug, camera_height, patch_num, keys)
        if logger is None:
            logger = get_logger()
        self.root_dir = root_dir
        self.vp_align = vp_align

        # Same value the original single-argument os.path.join(root_dir) produced.
        data_dir = os.fspath(root_dir)

        pano_list = read_zind(partition_path=os.path.join(data_dir, "zind_partition.json"),
                              simplicity_path=os.path.join(data_dir, "room_shape_simplicity_labels.json"),
                              data_dir=data_dir, mode=mode, is_simple=is_simple, is_ceiling_flat=is_ceiling_flat)

        if for_test_index is not None:
            pano_list = pano_list[:for_test_index]
        if split_list:
            pano_list = [pano for pano in pano_list if pano['id'] in split_list]

        self.data = []
        invalid_num = 0
        for pano in pano_list:
            if not os.path.exists(pano['img_path']):
                logger.warning(f"{pano['img_path']} not exists")
                invalid_num += 1
                continue

            corners = pano['corners']

            # Records rejected by filter_center are skipped silently: they are
            # neither logged nor counted in invalid_num.
            if not filter_center(corners):
                continue

            wall_num = len(corners)
            if self.max_wall_num >= 10:
                # Thresholds of 10 or more mean "at least this many corners".
                if wall_num < self.max_wall_num:
                    invalid_num += 1
                    continue
            elif self.max_wall_num != 0 and wall_num != self.max_wall_num:
                # Non-zero thresholds below 10 require an exact corner count.
                invalid_num += 1
                continue

            if not filter_boundary(corners):
                logger.warning(f"{pano['id']} boundary cross")
                invalid_num += 1
                continue

            if not filter_self_intersection(corners):
                logger.warning(f"{pano['id']} self_intersection")
                invalid_num += 1
                continue

            self.data.append(pano)

        logger.info(
            f"Build dataset mode: {self.mode} max_wall_num: {self.max_wall_num} valid: {len(self.data)} invalid: {invalid_num}")

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and return ``process_data``'s output.

        :param idx: index into ``self.data``.
        :return: whatever ``self.process_data(label, image, self.patch_num)``
            returns.
        """
        pano = self.data[idx]
        label = pano
        image = read_image(pano['img_path'], self.shape)

        if self.vp_align:
            rotation = calc_rotation(corners=pano['corners'])
            # Fraction of a full turn, wrapped into [0, 1).
            shift = math.modf(rotation / (2 * np.pi) + 1)[0]
            image = np.roll(image, round(shift * self.shape[1]), axis=1)

            # Shift a copy of the record: writing into pano['corners'] would
            # alter the entry cached in self.data, so a later fetch of the same
            # index would start from already-shifted corners.
            # Assumes each record is a mapping and 'corners' is a 2-D ndarray
            # (the original indexed it with [:, 0]) -- TODO confirm.
            label = dict(pano)
            shifted_corners = pano['corners'].copy()
            # Column 0 is wrapped into [0, 1) after the shift; presumably the
            # normalized horizontal coordinate -- TODO confirm.
            shifted_corners[:, 0] = np.modf(shifted_corners[:, 0] + shift)[0]
            label['corners'] = shifted_corners

        output = self.process_data(label, image, self.patch_num)
        return output
|
|
|
|
|
if __name__ == "__main__":
    # Manual visual check: build the dataset with every augmentation switched
    # off and display the boundary overlay and the floorplan of each sample.
    from tqdm import tqdm
    from visualization.boundary import draw_boundaries
    from visualization.floorplan import draw_floorplan
    from utils.boundary import depth2boundaries
    from utils.conversion import uv2xyz
    from models.other.init_env import init_env

    # 123 is presumably a random seed -- TODO confirm against init_env.
    init_env(123)

    modes = ['val']
    for mode in modes:
        print(mode)
        zind_dataset = ZindDataset(root_dir='../src/dataset/zind', mode=mode, aug={
            'STRETCH': False,
            'ROTATE': False,
            'FLIP': False,
            'GAMMA': False
        })

        bar = tqdm(zind_dataset, ncols=100)
        for data in bar:
            bar.set_description(f"Processing {data['id']}")
            boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)

            # The image is transposed from (0, 1, 2) to (1, 2, 0) axis order
            # before drawing; show=True displays the result.
            draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=True)

            # Only the first boundary is converted; [..., ::2] keeps components
            # 0 and 2 of each converted point.
            draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=True,
                           marker_color=None, center_color=0.2)
|
|
|
|
|
|