zhigangjiang's picture
no message
88b0dcb
raw
history blame
4.21 kB
"""
@Date: 2021/07/18
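@description: Factory helpers that build the layout datasets (mp3d, pano_s2d3d, pano_s2d3d_mix, zind) and the torch DataLoaders that wrap them.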
"""
import numpy as np
import torch.utils.data
from dataset.mp3d_dataset import MP3DDataset
from dataset.pano_s2d3d_dataset import PanoS2D3DDataset
from dataset.pano_s2d3d_mix_dataset import PanoS2D3DMixDataset
from dataset.zind_dataset import ZindDataset
def _largest_divisor_at_most(n, limit):
    """Return the largest integer d with 1 <= d <= limit that divides n.

    Used to pick a validation batch size that splits the validation set into
    equally sized batches (no short final batch). When n == 0 every candidate
    divides n, so ``limit`` itself is returned.

    Raises:
        ValueError: if ``limit`` is smaller than 1 (no valid candidate exists).
    """
    for divisor in range(limit, 0, -1):
        if n % divisor == 0:
            return divisor
    raise ValueError(f'limit must be >= 1, got {limit}')


def build_loader(config, logger):
    """Build the training and validation/test DataLoaders described by ``config``.

    A training loader is only built when ``config.MODE == 'train'``. The
    evaluation split is ``'test'`` in test mode and ``config.VAL_NAME``
    otherwise. When ``config.WORLD_SIZE > 1`` both loaders use a
    ``DistributedSampler``.

    Args:
        config: Experiment config. Reads WORLD_SIZE, MODE, VAL_NAME, DEBUG,
            DATA.BATCH_SIZE, DATA.NUM_WORKERS and DATA.PIN_MEMORY; the dataset
            fields are read by ``build_dataset``.
        logger: Logger used for progress messages (also passed to the datasets).

    Returns:
        ``(train_data_loader, val_data_loader)``. ``train_data_loader`` is
        ``None`` when no training dataset was built.
    """
    ddp = config.WORLD_SIZE > 1

    train_dataset = None
    train_data_loader = None
    if config.MODE == 'train':
        train_dataset = build_dataset(mode='train', config=config, logger=logger)
    val_mode = config.VAL_NAME if config.MODE != 'test' else 'test'
    val_dataset = build_dataset(mode=val_mode, config=config, logger=logger)

    train_sampler = None
    val_sampler = None
    if ddp:
        if train_dataset is not None:
            train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
        val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)

    batch_size = config.DATA.BATCH_SIZE
    # Worker processes make debugging painful, so load in the main process in DEBUG mode.
    num_workers = 0 if config.DEBUG else config.DATA.NUM_WORKERS
    pin_memory = config.DATA.PIN_MEMORY

    if train_dataset is not None:
        logger.info(f'Train data loader batch size: {batch_size}')
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset, sampler=train_sampler,
            batch_size=batch_size,
            # DataLoader rejects shuffle=True combined with an explicit sampler;
            # under DDP the DistributedSampler already does the shuffling.
            shuffle=train_sampler is None,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

    # Shrink the validation batch size to the largest value <= BATCH_SIZE that
    # divides the validation set evenly, so every batch has the same size.
    # NOTE(review): under DDP each rank sees a padded shard of the set, so the
    # even split is only guaranteed for the full dataset length -- confirm intent.
    batch_size = _largest_divisor_at_most(len(val_dataset), batch_size)
    logger.info(f'Val data loader batch size: {batch_size}')
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, sampler=val_sampler,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
    logger.info(f'Build data loader: num_workers:{num_workers} pin_memory:{pin_memory}')
    return train_data_loader, val_data_loader
def build_dataset(mode, config, logger):
    """Instantiate the dataset selected by ``config.DATA.DATASET`` for ``mode``.

    Every dataset receives the same core arguments taken from ``config.DATA``.
    Augmentation (``config.DATA.AUG``) is applied only when ``mode == 'train'``.
    The pano datasets additionally get ``subset``. The zind dataset gets fixed
    ``is_simple``/``is_ceiling_flat`` flags, plus ``vp_align`` enabled when
    ``config.EVAL.POST_PROCESSING`` mentions ``'manhattan'``.

    Args:
        mode: Split name forwarded to the dataset (e.g. ``'train'``, ``'test'``).
        config: Experiment config providing the ``DATA`` (and, for zind, ``EVAL``) fields.
        logger: Logger forwarded to the dataset.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: if ``config.DATA.DATASET`` is not a known dataset name.
    """
    data_cfg = config.DATA
    name = data_cfg.DATASET

    # Pick the dataset class and its dataset-specific keyword arguments first,
    # so an unknown name fails before any other config field is read.
    if name == 'mp3d':
        dataset_cls, extra_kwargs = MP3DDataset, {}
    elif name in ('pano_s2d3d', 'pano_s2d3d_mix'):
        dataset_cls = PanoS2D3DDataset if name == 'pano_s2d3d' else PanoS2D3DMixDataset
        extra_kwargs = {'subset': data_cfg.SUBSET}
    elif name == 'zind':
        post_processing = config.EVAL.POST_PROCESSING
        dataset_cls = ZindDataset
        extra_kwargs = {
            'is_simple': True,
            'is_ceiling_flat': False,
            'vp_align': post_processing is not None and 'manhattan' in post_processing,
        }
    else:
        raise NotImplementedError(f"Unknown dataset: {name}")

    return dataset_cls(
        root_dir=data_cfg.DIR,
        mode=mode,
        shape=data_cfg.SHAPE,
        max_wall_num=data_cfg.WALL_NUM,
        aug=data_cfg.AUG if mode == 'train' else None,
        camera_height=data_cfg.CAMERA_HEIGHT,
        logger=logger,
        for_test_index=data_cfg.FOR_TEST_INDEX,
        keys=data_cfg.KEYS,
        **extra_kwargs,
    )