File size: 7,198 Bytes
915f69b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import tyro
from dataclasses import dataclass
from typing import Tuple, Literal, Dict, Optional


@dataclass
class Options:
    """Flat configuration record for model, dataset, training and testing.

    Every field has a default, so ``Options()`` yields the baseline
    configuration. The module-level ``config_defaults`` presets override
    individual fields by keyword.
    """

    ### general / inference
    seed: Optional[int] = None
    is_crop: bool = True
    is_fix_views: bool = False
    specific_demo: Optional[str] = None
    # True = text prompts (False presumably means image prompts -- TODO confirm)
    txt_or_image: Optional[bool] = False
    infer_render_size: int = 256
    # True for mvdream, False for zero123plus
    mvdream_or_zero123: Optional[bool] = True
    # True for rar
    rar_data: bool = True

    ### model
    # Unet image input size
    input_size: int = 512
    # Unet definition
    down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
    down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
    mid_attention: bool = True
    up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
    up_attention: Tuple[bool, ...] = (True, True, True, False)
    # Unet output size, dependent on the input_size and U-Net structure!
    splat_size: int = 64
    # svd render size
    output_size: Optional[int] = 128

    # for tensor
    density_n_comp: int = 8
    app_n_comp: int = 32
    app_dim: int = 27
    density_dim: int = 8
    # only 'MLP_Fea' is accepted
    shadingMode: Literal['MLP_Fea'] = 'MLP_Fea'
    view_pe: int = 2
    fea_pe: int = 2
    pos_pe: int = 6
    # points number sampled per ray
    n_sample: int = 64

    # model type: 'TRF_Mesh' or 'TRF_NeRF'.
    # NOTE(review): the previous comment listed TRF / TRF_GS / TRI_GS, which
    # do not match the accepted values -- confirm the intended meaning.
    volume_mode: Literal['TRF_Mesh', 'TRF_NeRF'] = 'TRF_NeRF'

    # for LRM_Net
    camera_embed_dim: int = 1024
    transformer_dim: int = 1024
    transformer_layers: int = 16
    transformer_heads: int = 16
    triplane_low_res: int = 32
    triplane_high_res: int = 64
    triplane_dim: int = 32
    encoder_type: str = 'dinov2'
    # alternative seen in earlier revisions: 'dinov2_vits14_reg'
    encoder_model_name: str = 'dinov2_vitb14_reg'
    encoder_feat_dim: int = 768
    encoder_freeze: bool = False

    # training
    over_fit: Optional[bool] = False
    is_grid_sample: bool = False

    ### dataset
    # data mode. 's6' is accepted because the 'tiny_trf_trans_mesh' preset in
    # this module selects it.
    # NOTE(review): the previous comment said only 's3' is supported, yet the
    # default is 's4' -- confirm which modes the data loader really handles.
    data_mode: Literal['s3', 's4', 's5', 's6'] = 's4'
    data_path: str = 'train_data'
    data_debug_list: str = 'dataset_debug/gobj_merged_debug.json'
    # alternative seen in earlier revisions: 'dataset_debug/gobj_merged_debug.json'
    data_list_path: str = 'dataset_debug/gobj_merged_debug_selected.json'
    # fovy of the dataset (an earlier revision used 49.1)
    fovy: float = 39.6
    # camera near plane
    znear: float = 0.5
    # camera far plane
    zfar: float = 2.5
    # number of all views (input + output)
    num_views: int = 12
    # number of input views
    num_input_views: int = 4
    # camera radius, to better use [-1, 1]^3 space
    cam_radius: float = 1.5
    # num workers
    num_workers: int = 8
    # whether to consider the view of a single viewpoint
    training_view_plane: bool = False
    is_certainty: bool = False

    ### training
    # workspace
    workspace: str = './workspace_test'
    # resume
    resume: Optional[str] = None
    ckpt_nerf: Optional[str] = None
    # batch size (per-GPU)
    batch_size: int = 8
    # gradient accumulation
    gradient_accumulation_steps: Optional[int] = 1
    # training epochs
    num_epochs: int = 50
    # lpips loss weight
    lambda_lpips: float = 2.0
    # gradient clip
    gradient_clip: float = 1.0
    # mixed precision
    mixed_precision: str = 'bf16'
    # learning rate
    lr: Optional[float] = 4e-4
    lr_scheduler: str = 'OneCycleLR'
    warmup_real_iters: int = 3000

    # augmentation prob for grid distortion
    prob_grid_distortion: float = 0.5
    # augmentation prob for camera jitter
    prob_cam_jitter: float = 0.5

    ### testing
    # test image path
    test_path: Optional[str] = None

    ### misc
    # nvdiffrast backend setting
    force_cuda_rast: bool = False
    # render fancy video with gaussian scaling effect
    fancy_video: bool = False
    

# Registry of named presets. `config_defaults` maps each preset name to an
# Options instance; `config_doc` maps the same names to a one-line description.
# Key order is kept identical in both mappings.
#
# NOTE(review): every 'tiny_*' preset passes a 4-entry `up_attention` together
# with a 3-entry `up_channels` -- confirm whether the extra entry is ignored.
config_defaults: Dict[str, Options] = {
    # baseline: every field left at its Options default
    'lrm': Options(),
    'small': Options(
        input_size=256,
        splat_size=64,
        output_size=256,
        batch_size=8,
        gradient_accumulation_steps=1,
        mixed_precision='bf16',
    ),
    'big': Options(
        input_size=256,
        # one more decoder
        up_channels=(1024, 1024, 512, 256, 128),
        up_attention=(True, True, True, False, False),
        splat_size=128,
        # render & supervise Gaussians at a higher resolution.
        output_size=512,
        batch_size=8,
        num_views=8,
        gradient_accumulation_steps=1,
        mixed_precision='bf16',
    ),
    'tiny_trf_trans_mesh': Options(
        input_size=512,
        down_channels=(32, 64, 128, 256, 512),
        down_attention=(False, False, False, False, True),
        up_channels=(512, 256, 128),
        up_attention=(True, False, False, False),
        volume_mode='TRF_Mesh',
        # ckpt_nerf='workspace_debug/0428_02/last.ckpt',
        splat_size=64,
        output_size=512,
        data_mode='s6',
        batch_size=1,  # previously 8
        num_views=8,
        gradient_accumulation_steps=1,  # previously 2
        mixed_precision='no',
    ),
    'tiny_trf_trans_nerf': Options(
        input_size=512,
        down_channels=(32, 64, 128, 256, 512),
        down_attention=(False, False, False, False, True),
        up_channels=(512, 256, 128),
        up_attention=(True, False, False, False),
        volume_mode='TRF_NeRF',
        splat_size=64,
        output_size=62,  # crop patch
        data_mode='s5',
        batch_size=4,  # previously 8
        num_views=8,
        gradient_accumulation_steps=1,  # previously 2
        mixed_precision='bf16',
    ),
    'tiny_trf_trans_nerf_123plus': Options(
        input_size=512,
        down_channels=(32, 64, 128, 256, 512),
        down_attention=(False, False, False, False, True),
        up_channels=(512, 256, 128),
        up_attention=(True, False, False, False),
        volume_mode='TRF_NeRF',
        splat_size=64,
        output_size=116,  # crop patch
        data_mode='s5',
        mvdream_or_zero123=False,
        batch_size=1,  # previously 8
        num_views=10,
        num_input_views=6,
        gradient_accumulation_steps=1,  # previously 2
        mixed_precision='bf16',
    ),
    'tiny_trf_trans_nerf_nocrop': Options(
        input_size=512,
        down_channels=(32, 64, 128, 256, 512),
        down_attention=(False, False, False, False, True),
        up_channels=(512, 256, 128),
        up_attention=(True, False, False, False),
        volume_mode='TRF_NeRF',
        splat_size=64,
        output_size=62,  # crop patch
        data_mode='s5',
        batch_size=4,  # previously 8
        is_crop=False,
        num_views=8,
        gradient_accumulation_steps=1,  # previously 2
        mixed_precision='bf16',
    ),
}

config_doc: Dict[str, str] = {
    'lrm': 'the default settings for LGM',
    'small': 'small model with lower resolution Gaussians',
    'big': 'big model with higher resolution Gaussians',
    'tiny_trf_trans_mesh': 'tiny model for ablation',
    'tiny_trf_trans_nerf': 'tiny model for ablation',
    'tiny_trf_trans_nerf_123plus': 'tiny model for ablation',
    'tiny_trf_trans_nerf_nocrop': 'tiny model for ablation',
}

# Hand both mappings to tyro, which derives the subcommand type from them.
AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)