File size: 4,205 Bytes
88b0dcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
"""
@Date: 2021/07/18
@description: Build layout datasets and their PyTorch data loaders from the experiment config.
"""
import numpy as np
import torch.utils.data
from dataset.mp3d_dataset import MP3DDataset
from dataset.pano_s2d3d_dataset import PanoS2D3DDataset
from dataset.pano_s2d3d_mix_dataset import PanoS2D3DMixDataset
from dataset.zind_dataset import ZindDataset
def _largest_divisor_at_most(n, limit):
    """Return the largest ``d`` in ``[1, limit]`` that divides ``n``.

    For ``n == 0`` every candidate divides it, so ``limit`` itself is returned.

    Raises:
        ValueError: if ``limit`` is smaller than 1.
    """
    if limit < 1:
        raise ValueError(f'batch size must be >= 1, got {limit}')
    # range() always reaches 1, which divides everything, so next() cannot run dry.
    return next(d for d in range(limit, 0, -1) if n % d == 0)


def build_loader(config, logger):
    """Build the training and evaluation data loaders described by ``config``.

    A training dataset and loader are only built when ``config.MODE == 'train'``.
    The evaluation dataset uses the split named by ``config.VAL_NAME``, or the
    ``'test'`` split when ``config.MODE == 'test'``. When ``config.WORLD_SIZE > 1``
    each loader is driven by a ``torch.utils.data.DistributedSampler``.

    Args:
        config: Experiment config. This function reads ``MODE``, ``VAL_NAME``,
            ``WORLD_SIZE``, ``DEBUG`` and ``DATA.{BATCH_SIZE, NUM_WORKERS,
            PIN_MEMORY}``; ``build_dataset`` reads further ``DATA``/``EVAL`` keys.
        logger: Logger for informational messages; also forwarded to the datasets.

    Returns:
        ``(train_data_loader, val_data_loader)``. ``train_data_loader`` is
        ``None`` when no training dataset was built.
    """
    ddp = config.WORLD_SIZE > 1

    train_dataset = None
    if config.MODE == 'train':
        train_dataset = build_dataset(mode='train', config=config, logger=logger)
    val_mode = 'test' if config.MODE == 'test' else config.VAL_NAME
    val_dataset = build_dataset(mode=val_mode, config=config, logger=logger)

    train_sampler = None
    val_sampler = None
    if ddp:
        if train_dataset is not None:
            train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
        val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)

    batch_size = config.DATA.BATCH_SIZE
    # In DEBUG mode load data in the main process (num_workers=0).
    num_workers = 0 if config.DEBUG else config.DATA.NUM_WORKERS
    pin_memory = config.DATA.PIN_MEMORY

    train_data_loader = None
    if train_dataset is not None:
        logger.info(f'Train data loader batch size: {batch_size}')
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=batch_size,
            # DataLoader raises ValueError when a sampler is combined with
            # shuffle=True; under DDP the DistributedSampler does the shuffling.
            shuffle=train_sampler is None,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

    # Shrink the eval batch size to the largest divisor of the dataset length
    # (capped at the configured batch size) so all eval batches are equally sized.
    # NOTE(review): under DDP each rank iterates len(val_sampler) (padded) samples,
    # not len(val_dataset), so per-rank batches may still be ragged — confirm
    # whether the divisor should be computed over len(val_sampler) instead.
    val_batch_size = _largest_divisor_at_most(len(val_dataset), batch_size)
    logger.info(f'Val data loader batch size: {val_batch_size}')
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=val_sampler,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    logger.info(f'Build data loader: num_workers:{num_workers} pin_memory:{pin_memory}')
    return train_data_loader, val_data_loader
def build_dataset(mode, config, logger):
    """Instantiate the dataset selected by ``config.DATA.DATASET`` for split ``mode``.

    Supported names are ``'mp3d'``, ``'pano_s2d3d'``, ``'pano_s2d3d_mix'`` and
    ``'zind'``. Augmentation (``config.DATA.AUG``) is only passed through when
    ``mode == 'train'``; every other split gets ``aug=None``. The two PanoS2D3D
    variants additionally receive ``config.DATA.SUBSET``. ZInD is built with
    ``is_simple=True`` and ``is_ceiling_flat=False``, and with ``vp_align``
    enabled when ``config.EVAL.POST_PROCESSING`` is set and contains ``'manhattan'``.

    Args:
        mode: Split name forwarded to the dataset constructor.
        config: Experiment config providing the ``DATA`` (and, for ZInD, ``EVAL``) keys.
        logger: Logger forwarded to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: if ``config.DATA.DATASET`` is not a supported name.
    """
    dataset_name = config.DATA.DATASET
    # Reject unknown names before reading any other config keys.
    if dataset_name not in ('mp3d', 'pano_s2d3d', 'pano_s2d3d_mix', 'zind'):
        raise NotImplementedError(f"Unknown dataset: {dataset_name}")

    data_cfg = config.DATA
    # Constructor arguments shared by every supported dataset class.
    shared_kwargs = dict(
        root_dir=data_cfg.DIR,
        mode=mode,
        shape=data_cfg.SHAPE,
        max_wall_num=data_cfg.WALL_NUM,
        aug=data_cfg.AUG if mode == 'train' else None,
        camera_height=data_cfg.CAMERA_HEIGHT,
        logger=logger,
        for_test_index=data_cfg.FOR_TEST_INDEX,
        keys=data_cfg.KEYS,
    )

    if dataset_name == 'mp3d':
        return MP3DDataset(**shared_kwargs)

    if dataset_name in ('pano_s2d3d', 'pano_s2d3d_mix'):
        pano_cls = PanoS2D3DDataset if dataset_name == 'pano_s2d3d' else PanoS2D3DMixDataset
        return pano_cls(subset=data_cfg.SUBSET, **shared_kwargs)

    # Only 'zind' remains at this point.
    post_processing = config.EVAL.POST_PROCESSING
    use_vp_align = post_processing is not None and 'manhattan' in post_processing
    return ZindDataset(
        is_simple=True,
        is_ceiling_flat=False,
        vp_align=use_vp_align,
        **shared_kwargs,
    )
|