|
""" |
|
@Date: 2021/07/18 |
|
@description: Build datasets and (optionally distributed) train/val data loaders from an experiment config.
|
""" |
|
import numpy as np |
|
import torch.utils.data |
|
from dataset.mp3d_dataset import MP3DDataset |
|
from dataset.pano_s2d3d_dataset import PanoS2D3DDataset |
|
from dataset.pano_s2d3d_mix_dataset import PanoS2D3DMixDataset |
|
from dataset.zind_dataset import ZindDataset |
|
|
|
|
|
def _largest_divisor_at_most(count, limit):
    """Return the largest ``b`` in ``[1, limit]`` that evenly divides ``count``.

    When ``count`` is 0 every ``b`` divides it, so ``limit`` is returned.
    If no candidate exists (``limit < 1``), ``limit`` is returned unchanged.
    """
    return next((b for b in range(limit, 0, -1) if count % b == 0), limit)


def build_loader(config, logger):
    """Build the train and validation/test data loaders described by ``config``.

    A training loader is only built when ``config.MODE == 'train'``. The
    evaluation split is ``'test'`` when ``config.MODE == 'test'`` and
    ``config.VAL_NAME`` otherwise. When ``config.WORLD_SIZE > 1`` both loaders
    are driven by a ``torch.utils.data.DistributedSampler``.

    Args:
        config: experiment config. Reads MODE, VAL_NAME, WORLD_SIZE, DEBUG and
            DATA.{BATCH_SIZE, NUM_WORKERS, PIN_MEMORY}, plus whatever
            ``build_dataset`` reads.
        logger: used for info messages and forwarded to ``build_dataset``.

    Returns:
        tuple: ``(train_data_loader, val_data_loader)``. ``train_data_loader``
        is ``None`` when no training dataset was built.
    """
    distributed = config.WORLD_SIZE > 1

    train_dataset = None
    if config.MODE == 'train':
        train_dataset = build_dataset(mode='train', config=config, logger=logger)

    val_mode = 'test' if config.MODE == 'test' else config.VAL_NAME
    val_dataset = build_dataset(mode=val_mode, config=config, logger=logger)

    train_sampler = None
    val_sampler = None
    if distributed:
        if train_dataset is not None:
            train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
        val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)

    batch_size = config.DATA.BATCH_SIZE
    # Worker processes are disabled in debug mode.
    num_workers = 0 if config.DEBUG else config.DATA.NUM_WORKERS
    pin_memory = config.DATA.PIN_MEMORY

    train_data_loader = None
    if train_dataset is not None:
        logger.info(f'Train data loader batch size: {batch_size}')
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=batch_size,
            # DataLoader rejects shuffle=True combined with an explicit sampler;
            # in the distributed case the DistributedSampler (shuffle=True above)
            # handles shuffling instead.
            shuffle=train_sampler is None,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

    # Use the largest batch size <= BATCH_SIZE that evenly divides the val set,
    # so that the final val batch is not partial.
    # NOTE(review): under DDP each rank iterates len(val_sampler) samples, not
    # len(val_dataset); confirm which count the divisor should be based on.
    batch_size = _largest_divisor_at_most(len(val_dataset), batch_size)
    logger.info(f'Val data loader batch size: {batch_size}')
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=val_sampler,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    logger.info(f'Build data loader: num_workers:{num_workers} pin_memory:{pin_memory}')
    return train_data_loader, val_data_loader
|
|
|
|
|
def build_dataset(mode, config, logger):
    """Instantiate the dataset selected by ``config.DATA.DATASET``.

    Args:
        mode: split name forwarded to the dataset constructor.
            ``config.DATA.AUG`` is only forwarded when ``mode == 'train'``;
            otherwise ``aug=None`` is passed.
        config: experiment config. Reads ``DATA.*`` fields, and additionally
            ``EVAL.POST_PROCESSING`` for the ``'zind'`` dataset.
        logger: forwarded to the dataset constructor.

    Returns:
        The constructed dataset object.

    Raises:
        NotImplementedError: if ``config.DATA.DATASET`` is not one of
            ``'mp3d'``, ``'pano_s2d3d'``, ``'pano_s2d3d_mix'`` or ``'zind'``.
    """
    dataset_name = config.DATA.DATASET
    # Reject unknown names up front, before touching any other config field.
    if dataset_name not in ('mp3d', 'pano_s2d3d', 'pano_s2d3d_mix', 'zind'):
        raise NotImplementedError(f"Unknown dataset: {dataset_name}")

    data_cfg = config.DATA
    # Constructor arguments shared by every supported dataset class.
    common_kwargs = {
        'root_dir': data_cfg.DIR,
        'mode': mode,
        'shape': data_cfg.SHAPE,
        'max_wall_num': data_cfg.WALL_NUM,
        'aug': data_cfg.AUG if mode == 'train' else None,
        'camera_height': data_cfg.CAMERA_HEIGHT,
        'logger': logger,
        'for_test_index': data_cfg.FOR_TEST_INDEX,
        'keys': data_cfg.KEYS,
    }

    if dataset_name == 'mp3d':
        return MP3DDataset(**common_kwargs)

    if dataset_name == 'pano_s2d3d':
        return PanoS2D3DDataset(subset=data_cfg.SUBSET, **common_kwargs)

    if dataset_name == 'pano_s2d3d_mix':
        return PanoS2D3DMixDataset(subset=data_cfg.SUBSET, **common_kwargs)

    # Only 'zind' remains at this point.
    post_processing = config.EVAL.POST_PROCESSING
    return ZindDataset(
        is_simple=True,
        is_ceiling_flat=False,
        # vp_align is enabled only when the post-processing setting mentions 'manhattan'.
        vp_align=post_processing is not None and 'manhattan' in post_processing,
        **common_kwargs
    )
|
|