File size: 4,205 Bytes
88b0dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
@Date: 2021/07/18
@description:
"""
import numpy as np
import torch.utils.data
from dataset.mp3d_dataset import MP3DDataset
from dataset.pano_s2d3d_dataset import PanoS2D3DDataset
from dataset.pano_s2d3d_mix_dataset import PanoS2D3DMixDataset
from dataset.zind_dataset import ZindDataset


def build_loader(config, logger):
    """Build the training and evaluation data loaders described by ``config``.

    A training loader is only created when ``config.MODE == 'train'``. The
    evaluation loader uses the ``'test'`` split when ``config.MODE == 'test'``
    and the split named by ``config.VAL_NAME`` otherwise.

    When ``config.WORLD_SIZE > 1`` both loaders receive a
    ``DistributedSampler`` (shuffling for training, in order for evaluation).

    :param config: configuration object; reads ``MODE``, ``VAL_NAME``,
        ``WORLD_SIZE``, ``DEBUG`` and ``DATA.BATCH_SIZE`` / ``DATA.NUM_WORKERS``
        / ``DATA.PIN_MEMORY`` here, and is forwarded to ``build_dataset``.
    :param logger: logger used for ``info`` messages and forwarded to
        ``build_dataset``.
    :return: ``(train_data_loader, val_data_loader)``; the first element is
        ``None`` when no training loader was built.
    """
    ddp = config.WORLD_SIZE > 1
    train_dataset = None
    train_data_loader = None
    if config.MODE == 'train':
        train_dataset = build_dataset(mode='train', config=config, logger=logger)

    val_dataset = build_dataset(mode=config.VAL_NAME if config.MODE != 'test' else 'test', config=config, logger=logger)

    train_sampler = None
    val_sampler = None
    if ddp:
        if train_dataset:
            train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
        val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)

    batch_size = config.DATA.BATCH_SIZE
    # Debug runs load data in the main process (no worker subprocesses).
    num_workers = 0 if config.DEBUG else config.DATA.NUM_WORKERS
    pin_memory = config.DATA.PIN_MEMORY
    if train_dataset:
        logger.info(f'Train data loader batch size: {batch_size}')
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset, sampler=train_sampler,
            batch_size=batch_size,
            # DataLoader raises ValueError when `shuffle=True` is combined with
            # an explicit sampler, so only let the loader shuffle when no
            # DistributedSampler (which shuffles on its own) is in use.
            shuffle=train_sampler is None,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

    # Shrink the evaluation batch size to the largest value <= BATCH_SIZE that
    # divides len(val_dataset): the remainders for candidates
    # batch_size, batch_size - 1, ..., 1 are computed at once and the position
    # of the first zero remainder is how far the batch size must be reduced.
    batch_size = batch_size - (len(val_dataset) % np.arange(batch_size, 0, -1)).tolist().index(0)
    logger.info(f'Val data loader batch size: {batch_size}')
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, sampler=val_sampler,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
    logger.info(f'Build data loader: num_workers:{num_workers} pin_memory:{pin_memory}')
    return train_data_loader, val_data_loader


def build_dataset(mode, config, logger):
    """Instantiate the dataset selected by ``config.DATA.DATASET``.

    Supported names are ``'mp3d'``, ``'pano_s2d3d'``, ``'pano_s2d3d_mix'`` and
    ``'zind'``. Augmentation settings (``config.DATA.AUG``) are only passed
    through when ``mode == 'train'``; every other mode gets ``aug=None``.

    :param mode: split name handed to the dataset constructor.
    :param config: configuration object providing the ``DATA.*`` fields (and
        ``EVAL.POST_PROCESSING`` for the ``'zind'`` dataset).
    :param logger: logger forwarded to the dataset constructor.
    :return: the constructed dataset instance.
    :raises NotImplementedError: if the dataset name is not one of the above.
    """
    name = config.DATA.DATASET

    # Resolve the dataset class and its dataset-specific keyword arguments
    # first, so an unknown name is rejected before anything else is read.
    if name == 'mp3d':
        dataset_cls = MP3DDataset
        specific_kwargs = {}
    elif name == 'pano_s2d3d':
        dataset_cls = PanoS2D3DDataset
        specific_kwargs = {'subset': config.DATA.SUBSET}
    elif name == 'pano_s2d3d_mix':
        dataset_cls = PanoS2D3DMixDataset
        specific_kwargs = {'subset': config.DATA.SUBSET}
    elif name == 'zind':
        dataset_cls = ZindDataset
        post_processing = config.EVAL.POST_PROCESSING
        specific_kwargs = {
            'is_simple': True,
            'is_ceiling_flat': False,
            # vp_align is enabled only for 'manhattan' post-processing.
            'vp_align': post_processing is not None and 'manhattan' in post_processing,
        }
    else:
        raise NotImplementedError(f"Unknown dataset: {name}")

    # Keyword arguments shared by every supported dataset class.
    data_cfg = config.DATA
    shared_kwargs = {
        'root_dir': data_cfg.DIR,
        'mode': mode,
        'shape': data_cfg.SHAPE,
        'max_wall_num': data_cfg.WALL_NUM,
        'aug': data_cfg.AUG if mode == 'train' else None,
        'camera_height': data_cfg.CAMERA_HEIGHT,
        'logger': logger,
        'for_test_index': data_cfg.FOR_TEST_INDEX,
        'keys': data_cfg.KEYS,
    }
    return dataset_cls(**shared_kwargs, **specific_kwargs)