xrg commited on
Commit
915f69b
1 Parent(s): 16ef2cb

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +2 -0
  2. .vscode/launch.json +23 -0
  3. app.py +385 -0
  4. core/__init__.py +0 -0
  5. core/block.py +124 -0
  6. core/embedder.py +37 -0
  7. core/encoders/__init__.py +15 -0
  8. core/encoders/dino_wrapper.py +68 -0
  9. core/encoders/dinov2/__init__.py +15 -0
  10. core/encoders/dinov2/hub/__init__.py +4 -0
  11. core/encoders/dinov2/hub/backbones.py +166 -0
  12. core/encoders/dinov2/hub/classifiers.py +268 -0
  13. core/encoders/dinov2/hub/depth/__init__.py +7 -0
  14. core/encoders/dinov2/hub/depth/decode_heads.py +747 -0
  15. core/encoders/dinov2/hub/depth/encoder_decoder.py +351 -0
  16. core/encoders/dinov2/hub/depth/ops.py +28 -0
  17. core/encoders/dinov2/hub/depthers.py +246 -0
  18. core/encoders/dinov2/hub/utils.py +39 -0
  19. core/encoders/dinov2/layers/__init__.py +20 -0
  20. core/encoders/dinov2/layers/attention.py +89 -0
  21. core/encoders/dinov2/layers/block.py +296 -0
  22. core/encoders/dinov2/layers/dino_head.py +58 -0
  23. core/encoders/dinov2/layers/drop_path.py +34 -0
  24. core/encoders/dinov2/layers/layer_scale.py +27 -0
  25. core/encoders/dinov2/layers/mlp.py +40 -0
  26. core/encoders/dinov2/layers/patch_embed.py +88 -0
  27. core/encoders/dinov2/layers/swiglu_ffn.py +72 -0
  28. core/encoders/dinov2/models/__init__.py +43 -0
  29. core/encoders/dinov2/models/vision_transformer.py +443 -0
  30. core/encoders/dinov2_wrapper.py +67 -0
  31. core/geometry/__init__.py +7 -0
  32. core/geometry/camera/__init__.py +16 -0
  33. core/geometry/camera/perspective_camera.py +51 -0
  34. core/geometry/render/__init__.py +8 -0
  35. core/geometry/render/neural_render.py +121 -0
  36. core/geometry/rep_3d/__init__.py +18 -0
  37. core/geometry/rep_3d/dmtet.py +504 -0
  38. core/geometry/rep_3d/dmtet_utils.py +20 -0
  39. core/geometry/rep_3d/extract_texture_map.py +40 -0
  40. core/geometry/rep_3d/flexicubes.py +579 -0
  41. core/geometry/rep_3d/flexicubes_geometry.py +120 -0
  42. core/geometry/rep_3d/tables.py +791 -0
  43. core/instant_utils/__init__.py +0 -0
  44. core/instant_utils/camera_util.py +111 -0
  45. core/instant_utils/infer_util.py +97 -0
  46. core/instant_utils/mesh_util.py +181 -0
  47. core/instant_utils/train_util.py +26 -0
  48. core/lrm_reconstructor.py +158 -0
  49. core/models.py +783 -0
  50. core/modulate.py +43 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ workspace_test
2
+ __pycache__
.vscode/launch.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "app",
6
+ "type": "debugpy",
7
+ "request": "launch",
8
+ "program": "./app.py",
9
+ "console": "integratedTerminal",
10
+ "env": {
11
+ "CUDA_VISIBLE_DEVICES": "2"
12
+ },
13
+ // "args": [
14
+ // "tiny_trf_trans_nerf",//"tiny_trf_trans_nerf" tiny_trf_trans_nerf_123plus
15
+ // "--resume",
16
+ // "pretrained/last6view060804_24.ckpt",//"pretrained/last_060302_49.ckpt",//"pretrained/last_060302_49.ckpt",
17
+ // "--output_size",
18
+ // "64"
19
+ // ],
20
+ "justMyCode": true
21
+ },
22
+ ]
23
+ }
app.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tyro
3
+ import imageio
4
+ import numpy as np
5
+ import tqdm
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from safetensors.torch import load_file
11
+ import rembg
12
+ import gradio as gr
13
+
14
+ import kiui
15
+ from kiui.op import recenter
16
+ from kiui.cam import orbit_camera
17
+ from core.utils import get_rays, grid_distortion, orbit_camera_jitter
18
+
19
+ from core.options import AllConfigs, Options
20
+ from core.models import LTRFM_Mesh,LTRFM_NeRF
21
+ from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
22
+ from mvdream.pipeline_mvdream import MVDreamPipeline
23
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
27
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
28
+ GRADIO_VIDEO_PATH = 'gradio_output.mp4'
29
+ GRADIO_OBJ_PATH = 'gradio_output_rgb.obj'
30
+ GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj'
31
+ GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj'
32
+
33
+ #opt = tyro.cli(AllConfigs)
34
+
35
+ ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM6v01.ckpt")
36
+
37
+ opt = Options(
38
+ input_size=512,
39
+ down_channels=(32, 64, 128, 256, 512),
40
+ down_attention=(False, False, False, False, True),
41
+ up_channels=(512, 256, 128),
42
+ up_attention=(True, False, False, False),
43
+ volume_mode='TRF_NeRF',
44
+ splat_size=64,
45
+ output_size=62, #crop patch
46
+ data_mode='s5',
47
+ num_views=8,
48
+ gradient_accumulation_steps=1, #2
49
+ mixed_precision='bf16',
50
+ resume=ckpt_path,
51
+ )
52
+
53
+
54
+ # model
55
+ if opt.volume_mode == 'TRF_Mesh':
56
+ model = LTRFM_Mesh(opt)
57
+ elif opt.volume_mode == 'TRF_NeRF':
58
+ model = LTRFM_NeRF(opt)
59
+ else:
60
+ model = LGM(opt)
61
+
62
+ # resume pretrained checkpoint
63
+ if opt.resume is not None:
64
+ if opt.resume.endswith('safetensors'):
65
+ ckpt = load_file(opt.resume, device='cpu')
66
+ else: #ckpt
67
+ ckpt_dict = torch.load(opt.resume, map_location='cpu')
68
+ ckpt=ckpt_dict["model"]
69
+
70
+ state_dict = model.state_dict()
71
+ for k, v in ckpt.items():
72
+ k=k.replace('module.', '')
73
+ if k in state_dict:
74
+ if state_dict[k].shape == v.shape:
75
+ state_dict[k].copy_(v)
76
+ else:
77
+ print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
78
+ else:
79
+ print(f'[WARN] unexpected param {k}: {v.shape}')
80
+ print(f'[INFO] load resume success!')
81
+
82
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83
+ model = model.half().to(device)
84
+ model.eval()
85
+
86
+ tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
87
+ proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
88
+ proj_matrix[0, 0] = 1 / tan_half_fov
89
+ proj_matrix[1, 1] = 1 / tan_half_fov
90
+ proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
91
+ proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
92
+ proj_matrix[2, 3] = 1
93
+
94
+ # load dreams
95
+ pipe_text = MVDreamPipeline.from_pretrained(
96
+ 'ashawkey/mvdream-sd2.1-diffusers', # remote weights
97
+ torch_dtype=torch.float16,
98
+ trust_remote_code=True,
99
+ # local_files_only=True,
100
+ )
101
+ pipe_text = pipe_text.to(device)
102
+
103
+ # mvdream
104
+ pipe_image = MVDreamPipeline.from_pretrained(
105
+ "ashawkey/imagedream-ipmv-diffusers", # remote weights
106
+ torch_dtype=torch.float16,
107
+ trust_remote_code=True,
108
+ # local_files_only=True,
109
+ )
110
+ pipe_image = pipe_image.to(device)
111
+
112
+
113
+ print('Loading 123plus model ...')
114
+ pipe_image_plus = DiffusionPipeline.from_pretrained(
115
+ "sudo-ai/zero123plus-v1.2",
116
+ custom_pipeline="zero123plus",
117
+ torch_dtype=torch.float16,
118
+ trust_remote_code=True,
119
+ local_files_only=True,
120
+ )
121
+ pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
122
+ pipe_image_plus.scheduler.config, timestep_spacing='trailing'
123
+ )
124
+
125
+ unet_path='./pretrained/diffusion_pytorch_model.bin'
126
+
127
+ print('Loading custom white-background unet ...')
128
+ if os.path.exists(unet_path):
129
+ unet_ckpt_path = unet_path
130
+ else:
131
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
132
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
133
+ pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
134
+ pipe_image_plus = pipe_image_plus.to(device)
135
+
136
+ # load rembg
137
+ bg_remover = rembg.new_session()
138
+
139
+ # process function
140
+ def process(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
141
+
142
+ # seed
143
+ kiui.seed_everything(input_seed)
144
+
145
+ os.makedirs(os.path.join(opt.workspace, "gradio"), exist_ok=True)
146
+ output_video_path = os.path.join(opt.workspace,"gradio", GRADIO_VIDEO_PATH)
147
+ output_obj_rgb_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_PATH)
148
+ output_obj_albedo_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_ALBEDO_PATH)
149
+ output_obj_shading_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_SHADING_PATH)
150
+
151
+ # text-conditioned
152
+ if condition_input_image is None:
153
+ mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
154
+ mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
155
+ # bg removal
156
+ mv_image = []
157
+ for i in range(4):
158
+ image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4]
159
+ # to white bg
160
+ image = image.astype(np.float32) / 255
161
+ image = recenter(image, image[..., 0] > 0, border_ratio=0.2)
162
+ image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
163
+ mv_image.append(image)
164
+
165
+ mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
166
+ input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
167
+
168
+ processed_image=None
169
+ # image-conditioned (may also input text, but no text usually works too)
170
+ else:
171
+ condition_input_image = np.array(condition_input_image) # uint8
172
+ # bg removal
173
+ carved_image = rembg.remove(condition_input_image, session=bg_remover) # [H, W, 4]
174
+ mask = carved_image[..., -1] > 0
175
+ image = recenter(carved_image, mask, border_ratio=0.2)
176
+ image = image.astype(np.float32) / 255.0
177
+ processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
178
+
179
+ if mv_moedl_option=='mvdream':
180
+ mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)
181
+
182
+ mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1)
183
+ input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
184
+ else:
185
+ from PIL import Image
186
+ from einops import rearrange, repeat
187
+
188
+ # input_image=input_image* 255
189
+ processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
190
+ mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
191
+ mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
192
+ mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
193
+ mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
194
+ mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()
195
+ input_image = mv_image
196
+
197
+ # generate gaussians
198
+ # [4, 256, 256, 3], float32
199
+ input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
200
+ input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
201
+
202
+ images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False)
203
+
204
+ data = {}
205
+ input_image = input_image.unsqueeze(0) # [1, 4, 9, H, W]
206
+ images_input_vit=images_input_vit.unsqueeze(0)
207
+ data['input_vit']=images_input_vit
208
+
209
+ elevation = 0
210
+ cam_poses =[]
211
+ if mv_moedl_option=='mvdream' or condition_input_image is None:
212
+ azimuth = np.arange(0, 360, 90, dtype=np.int32)
213
+ for azi in tqdm.tqdm(azimuth):
214
+ cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
215
+ cam_poses.append(cam_pose)
216
+ else:
217
+ azimuth = np.arange(30, 360, 60, dtype=np.int32)
218
+ cnt = 0
219
+ for azi in tqdm.tqdm(azimuth):
220
+ if (cnt+1) % 2!= 0:
221
+ elevation=-20
222
+ else:
223
+ elevation=30
224
+ cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
225
+ cam_poses.append(cam_pose)
226
+ cnt=cnt+1
227
+
228
+ cam_poses = torch.cat(cam_poses,0)
229
+ radius = torch.norm(cam_poses[0, :3, 3])
230
+ cam_poses[:, :3, 3] *= opt.cam_radius / radius
231
+ transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32).to(device) @ torch.inverse(cam_poses[0])
232
+ cam_poses = transform.unsqueeze(0) @ cam_poses
233
+
234
+ cam_poses=cam_poses.unsqueeze(0)
235
+ data['source_camera']=cam_poses
236
+
237
+ with torch.no_grad():
238
+ if opt.volume_mode == 'TRF_Mesh':
239
+ with torch.autocast(device_type='cuda', dtype=torch.float32):
240
+ svd_volume = model.forward_svd_volume(input_image,data)
241
+ else:
242
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
243
+ svd_volume = model.forward_svd_volume(input_image,data)
244
+
245
+ #time-consuming
246
+ export_texmap=False
247
+
248
+ mesh_out = model.extract_mesh(svd_volume,use_texture_map=export_texmap)
249
+
250
+ if export_texmap:
251
+ vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
252
+
253
+ for i in range(len(tex_map)):
254
+ mesh_path=os.path.join(opt.workspace, name + str(i) + '_'+ str(seed)+ '.obj')
255
+ save_obj_with_mtl(
256
+ vertices.data.cpu().numpy(),
257
+ uvs.data.cpu().numpy(),
258
+ faces.data.cpu().numpy(),
259
+ mesh_tex_idx.data.cpu().numpy(),
260
+ tex_map[i].permute(1, 2, 0).data.cpu().numpy(),
261
+ mesh_path,
262
+ )
263
+ else:
264
+ vertices, faces, vertex_colors = mesh_out
265
+
266
+ save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path)
267
+ save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path)
268
+ save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path)
269
+
270
+
271
+ return mv_image_grid, processed_image, output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path
272
+
273
+ # gradio UI
274
+
275
+ _TITLE = '''LDM: Large Tensorial SDF Model for Textured Mesh Generation'''
276
+
277
+ _DESCRIPTION = '''
278
+
279
+
280
+ * Input can be text prompt, image.
281
+ * If you find the output unsatisfying, try using different seeds!
282
+ '''
283
+
284
+ block = gr.Blocks(title=_TITLE).queue()
285
+ with block:
286
+ with gr.Row():
287
+ with gr.Column(scale=1):
288
+ gr.Markdown('# ' + _TITLE)
289
+ gr.Markdown(_DESCRIPTION)
290
+
291
+ with gr.Row(variant='panel'):
292
+ with gr.Column(scale=1):
293
+ with gr.Tab("Image-to-3D"):
294
+ # input image
295
+ with gr.Row():
296
+ condition_input_image = gr.Image(
297
+ label="Input Image",
298
+ image_mode="RGBA",
299
+ type="pil"
300
+ )
301
+
302
+ processed_image = gr.Image(
303
+ label="Processed Image",
304
+ image_mode="RGBA",
305
+ type="pil",
306
+ interactive=False
307
+ )
308
+
309
+
310
+ with gr.Row():
311
+ mv_moedl_option = gr.Radio([
312
+ "zero123plus",
313
+ "mvdream"
314
+ ], value="zero123plus",
315
+ label="Multi-view Diffusion")
316
+
317
+ with gr.Row(variant="panel"):
318
+ gr.Examples(
319
+ examples=[
320
+ os.path.join("example", img_name) for img_name in sorted(os.listdir("example"))
321
+ ],
322
+ inputs=[condition_input_image],
323
+ fn=lambda x: process(condition_input_image=x, prompt=''),
324
+ cache_examples=False,
325
+ examples_per_page=20,
326
+ label='Image-to-3D Examples'
327
+ )
328
+
329
+ with gr.Tab("Text-to-3D"):
330
+ # input prompt
331
+ with gr.Row():
332
+ input_text = gr.Textbox(label="prompt")
333
+ # negative prompt
334
+ with gr.Row():
335
+ input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')
336
+
337
+ with gr.Row(variant="panel"):
338
+ gr.Examples(
339
+ examples=[
340
+ "a hamburger",
341
+ "a furry red fox head",
342
+ "a teddy bear",
343
+ "a motorbike",
344
+ ],
345
+ inputs=[input_text],
346
+ fn=lambda x: process(condition_input_image=None, prompt=x),
347
+ cache_examples=False,
348
+ label='Text-to-3D Examples'
349
+ )
350
+
351
+ # elevation
352
+ input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
353
+ # inference steps
354
+ input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
355
+ # random seed
356
+ input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
357
+ # gen button
358
+ button_gen = gr.Button("Generate")
359
+
360
+
361
+ with gr.Column(scale=1):
362
+ with gr.Row():
363
+ # multi-view results
364
+ mv_image_grid = gr.Image(interactive=False, show_label=False)
365
+ with gr.Row():
366
+ output_obj_rgb_path = gr.Model3D(
367
+ label="RGB Model (OBJ Format)",
368
+ interactive=False,
369
+ )
370
+ with gr.Row():
371
+ output_obj_albedo_path = gr.Model3D(
372
+ label="Albedo Model (OBJ Format)",
373
+ interactive=False,
374
+ )
375
+ with gr.Row():
376
+ output_obj_shading_path = gr.Model3D(
377
+ label="Shading Model (OBJ Format)",
378
+ interactive=False,
379
+ )
380
+
381
+
382
+ button_gen.click(process, inputs=[condition_input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed,mv_moedl_option], outputs=[mv_image_grid,processed_image, output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path])
383
+
384
+
385
+ block.launch(server_name="0.0.0.0", share=False)
core/__init__.py ADDED
File without changes
core/block.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch.nn as nn
17
+
18
+ from .modulate import ModLN
19
+
20
+
21
+ class BasicBlock(nn.Module):
22
+ """
23
+ Transformer block that is in its simplest form.
24
+ Designed for PF-LRM architecture.
25
+ """
26
+ # Block contains a self-attention layer and an MLP
27
+ def __init__(self, inner_dim: int, num_heads: int, eps: float,
28
+ attn_drop: float = 0., attn_bias: bool = False,
29
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
30
+ super().__init__()
31
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
32
+ self.self_attn = nn.MultiheadAttention(
33
+ embed_dim=inner_dim, num_heads=num_heads,
34
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
35
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
36
+ self.mlp = nn.Sequential(
37
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
38
+ nn.GELU(),
39
+ nn.Dropout(mlp_drop),
40
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
41
+ nn.Dropout(mlp_drop),
42
+ )
43
+
44
+ def forward(self, x):
45
+ # x: [N, L, D]
46
+ before_sa = self.norm1(x)
47
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
48
+ x = x + self.mlp(self.norm2(x))
49
+ return x
50
+
51
+
52
+ class ConditionBlock(nn.Module):
53
+ """
54
+ Transformer block that takes in a cross-attention condition.
55
+ Designed for SparseLRM architecture.
56
+ """
57
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
58
+ def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
59
+ attn_drop: float = 0., attn_bias: bool = False,
60
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
61
+ super().__init__()
62
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
63
+ self.cross_attn = nn.MultiheadAttention(
64
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
65
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
66
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
67
+ self.self_attn = nn.MultiheadAttention(
68
+ embed_dim=inner_dim, num_heads=num_heads,
69
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
70
+ self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
71
+ self.mlp = nn.Sequential(
72
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
73
+ nn.GELU(),
74
+ nn.Dropout(mlp_drop),
75
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
76
+ nn.Dropout(mlp_drop),
77
+ )
78
+
79
+ def forward(self, x, cond):
80
+ # x: [N, L, D]
81
+ # cond: [N, L_cond, D_cond]
82
+ x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
83
+ before_sa = self.norm2(x)
84
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
85
+ x = x + self.mlp(self.norm3(x))
86
+ return x
87
+
88
+
89
+ class ConditionModulationBlock(nn.Module):
90
+ """
91
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
92
+ Designed for raw LRM architecture.
93
+ """
94
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
95
+ def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
96
+ attn_drop: float = 0., attn_bias: bool = False,
97
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
98
+ super().__init__()
99
+ self.norm1 = ModLN(inner_dim, mod_dim, eps)
100
+ self.cross_attn = nn.MultiheadAttention(
101
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
102
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
103
+ self.norm2 = ModLN(inner_dim, mod_dim, eps)
104
+ self.self_attn = nn.MultiheadAttention(
105
+ embed_dim=inner_dim, num_heads=num_heads,
106
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
107
+ self.norm3 = ModLN(inner_dim, mod_dim, eps)
108
+ self.mlp = nn.Sequential(
109
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
110
+ nn.GELU(),
111
+ nn.Dropout(mlp_drop),
112
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
113
+ nn.Dropout(mlp_drop),
114
+ )
115
+
116
+ def forward(self, x, cond, mod):
117
+ # x: [N, L, D]
118
+ # cond: [N, L_cond, D_cond]
119
+ # mod: [N, D_mod]
120
+ x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
121
+ before_sa = self.norm2(x, mod)
122
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
123
+ x = x + self.mlp(self.norm3(x, mod))
124
+ return x
core/embedder.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ class CameraEmbedder(nn.Module):
21
+ """
22
+ Embed camera features to a high-dimensional vector.
23
+
24
+ Reference:
25
+ DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
26
+ """
27
+ def __init__(self, raw_dim: int, embed_dim: int):
28
+ super().__init__()
29
+ self.mlp = nn.Sequential(
30
+ nn.Linear(raw_dim, embed_dim),
31
+ nn.SiLU(),
32
+ nn.Linear(embed_dim, embed_dim),
33
+ )
34
+
35
+ @torch.compile
36
+ def forward(self, x):
37
+ return self.mlp(x)
core/encoders/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Empty
core/encoders/dino_wrapper.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from transformers import ViTImageProcessor, ViTModel
19
+ from accelerate.logging import get_logger
20
+
21
+
22
+ logger = get_logger(__name__)
23
+
24
+
25
+ class DinoWrapper(nn.Module):
26
+ """
27
+ Dino v1 wrapper using huggingface transformer implementation.
28
+ """
29
+ def __init__(self, model_name: str, freeze: bool = True):
30
+ super().__init__()
31
+ self.model, self.processor = self._build_dino(model_name)
32
+ if freeze:
33
+ self._freeze()
34
+
35
+ @torch.compile
36
+ def forward_model(self, inputs):
37
+ return self.model(**inputs, interpolate_pos_encoding=True)
38
+
39
+ def forward(self, image):
40
+ # image: [N, C, H, W], on cpu
41
+ # RGB image with [0,1] scale and properly sized
42
+ inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
43
+ # This resampling of positional embedding uses bicubic interpolation
44
+ outputs = self.forward_model(inputs)
45
+ last_hidden_states = outputs.last_hidden_state
46
+ return last_hidden_states
47
+
48
+ def _freeze(self):
49
+ logger.warning(f"======== Freezing DinoWrapper ========")
50
+ self.model.eval()
51
+ for name, param in self.model.named_parameters():
52
+ param.requires_grad = False
53
+
54
+ @staticmethod
55
+ def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
56
+ import requests
57
+ try:
58
+ model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
59
+ processor = ViTImageProcessor.from_pretrained(model_name)
60
+ return model, processor
61
+ except requests.exceptions.ProxyError as err:
62
+ if proxy_error_retries > 0:
63
+ print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...")
64
+ import time
65
+ time.sleep(proxy_error_cooldown)
66
+ return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
67
+ else:
68
+ raise err
core/encoders/dinov2/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Empty
core/encoders/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
core/encoders/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
+ class Weights(Enum):
15
+ LVD142M = "LVD142M"
16
+
17
+
18
+ def _make_dinov2_model(
19
+ *,
20
+ arch_name: str = "vit_large",
21
+ img_size: int = 518,
22
+ patch_size: int = 14,
23
+ init_values: float = 1.0,
24
+ ffn_layer: str = "mlp",
25
+ block_chunks: int = 0,
26
+ num_register_tokens: int = 0,
27
+ interpolate_antialias: bool = False,
28
+ interpolate_offset: float = 0.1,
29
+ pretrained: bool = True,
30
+ weights: Union[Weights, str] = Weights.LVD142M,
31
+ **kwargs,
32
+ ):
33
+ from ..models import vision_transformer as vits
34
+
35
+ if isinstance(weights, str):
36
+ try:
37
+ weights = Weights[weights]
38
+ except KeyError:
39
+ raise AssertionError(f"Unsupported weights: {weights}")
40
+
41
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
42
+ vit_kwargs = dict(
43
+ img_size=img_size,
44
+ patch_size=patch_size,
45
+ init_values=init_values,
46
+ ffn_layer=ffn_layer,
47
+ block_chunks=block_chunks,
48
+ num_register_tokens=num_register_tokens,
49
+ interpolate_antialias=interpolate_antialias,
50
+ interpolate_offset=interpolate_offset,
51
+ )
52
+ vit_kwargs.update(**kwargs)
53
+ model = vits.__dict__[arch_name](**vit_kwargs)
54
+
55
+ if pretrained:
56
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
57
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
58
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
59
+ # ********** Modified by Zexin He in 2023-2024 **********
60
+ state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern
61
+ if vit_kwargs.get("modulation_dim") is not None:
62
+ state_dict = {
63
+ k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v
64
+ for k, v in state_dict.items()
65
+ }
66
+ model.load_state_dict(state_dict, strict=False)
67
+ else:
68
+ model.load_state_dict(state_dict, strict=True)
69
+ # ********************************************************
70
+
71
+ return model
72
+
73
+
74
+ def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
75
+ """
76
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
77
+ """
78
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
79
+
80
+
81
+ def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
82
+ """
83
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
84
+ """
85
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
86
+
87
+
88
+ def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
89
+ """
90
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
91
+ """
92
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
93
+
94
+
95
+ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
96
+ """
97
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
98
+ """
99
+ return _make_dinov2_model(
100
+ arch_name="vit_giant2",
101
+ ffn_layer="swiglufused",
102
+ weights=weights,
103
+ pretrained=pretrained,
104
+ **kwargs,
105
+ )
106
+
107
+
108
+ def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
109
+ """
110
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
111
+ """
112
+ return _make_dinov2_model(
113
+ arch_name="vit_small",
114
+ pretrained=pretrained,
115
+ weights=weights,
116
+ num_register_tokens=4,
117
+ interpolate_antialias=True,
118
+ interpolate_offset=0.0,
119
+ **kwargs,
120
+ )
121
+
122
+
123
+ def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
124
+ """
125
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
126
+ """
127
+ return _make_dinov2_model(
128
+ arch_name="vit_base",
129
+ pretrained=pretrained,
130
+ weights=weights,
131
+ num_register_tokens=4,
132
+ interpolate_antialias=True,
133
+ interpolate_offset=0.0,
134
+ **kwargs,
135
+ )
136
+
137
+
138
+ def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
139
+ """
140
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
141
+ """
142
+ return _make_dinov2_model(
143
+ arch_name="vit_large",
144
+ pretrained=pretrained,
145
+ weights=weights,
146
+ num_register_tokens=4,
147
+ interpolate_antialias=True,
148
+ interpolate_offset=0.0,
149
+ **kwargs,
150
+ )
151
+
152
+
153
+ def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
154
+ """
155
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
156
+ """
157
+ return _make_dinov2_model(
158
+ arch_name="vit_giant2",
159
+ ffn_layer="swiglufused",
160
+ weights=weights,
161
+ pretrained=pretrained,
162
+ num_register_tokens=4,
163
+ interpolate_antialias=True,
164
+ interpolate_offset=0.0,
165
+ **kwargs,
166
+ )
core/encoders/dinov2/hub/classifiers.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .backbones import _make_dinov2_model
13
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
14
+
15
+
16
+ class Weights(Enum):
17
+ IMAGENET1K = "IMAGENET1K"
18
+
19
+
20
+ def _make_dinov2_linear_classification_head(
21
+ *,
22
+ arch_name: str = "vit_large",
23
+ patch_size: int = 14,
24
+ embed_dim: int = 1024,
25
+ layers: int = 4,
26
+ pretrained: bool = True,
27
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
28
+ num_register_tokens: int = 0,
29
+ **kwargs,
30
+ ):
31
+ if layers not in (1, 4):
32
+ raise AssertionError(f"Unsupported number of layers: {layers}")
33
+ if isinstance(weights, str):
34
+ try:
35
+ weights = Weights[weights]
36
+ except KeyError:
37
+ raise AssertionError(f"Unsupported weights: {weights}")
38
+
39
+ linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
40
+
41
+ if pretrained:
42
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
43
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
44
+ layers_str = str(layers) if layers == 4 else ""
45
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
46
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
47
+ linear_head.load_state_dict(state_dict, strict=True)
48
+
49
+ return linear_head
50
+
51
+
52
+ class _LinearClassifierWrapper(nn.Module):
53
+ def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
54
+ super().__init__()
55
+ self.backbone = backbone
56
+ self.linear_head = linear_head
57
+ self.layers = layers
58
+
59
+ def forward(self, x):
60
+ if self.layers == 1:
61
+ x = self.backbone.forward_features(x)
62
+ cls_token = x["x_norm_clstoken"]
63
+ patch_tokens = x["x_norm_patchtokens"]
64
+ # fmt: off
65
+ linear_input = torch.cat([
66
+ cls_token,
67
+ patch_tokens.mean(dim=1),
68
+ ], dim=1)
69
+ # fmt: on
70
+ elif self.layers == 4:
71
+ x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
72
+ # fmt: off
73
+ linear_input = torch.cat([
74
+ x[0][1],
75
+ x[1][1],
76
+ x[2][1],
77
+ x[3][1],
78
+ x[3][0].mean(dim=1),
79
+ ], dim=1)
80
+ # fmt: on
81
+ else:
82
+ assert False, f"Unsupported number of layers: {self.layers}"
83
+ return self.linear_head(linear_input)
84
+
85
+
86
+ def _make_dinov2_linear_classifier(
87
+ *,
88
+ arch_name: str = "vit_large",
89
+ layers: int = 4,
90
+ pretrained: bool = True,
91
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
92
+ num_register_tokens: int = 0,
93
+ interpolate_antialias: bool = False,
94
+ interpolate_offset: float = 0.1,
95
+ **kwargs,
96
+ ):
97
+ backbone = _make_dinov2_model(
98
+ arch_name=arch_name,
99
+ pretrained=pretrained,
100
+ num_register_tokens=num_register_tokens,
101
+ interpolate_antialias=interpolate_antialias,
102
+ interpolate_offset=interpolate_offset,
103
+ **kwargs,
104
+ )
105
+
106
+ embed_dim = backbone.embed_dim
107
+ patch_size = backbone.patch_size
108
+ linear_head = _make_dinov2_linear_classification_head(
109
+ arch_name=arch_name,
110
+ patch_size=patch_size,
111
+ embed_dim=embed_dim,
112
+ layers=layers,
113
+ pretrained=pretrained,
114
+ weights=weights,
115
+ num_register_tokens=num_register_tokens,
116
+ )
117
+
118
+ return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
119
+
120
+
121
+ def dinov2_vits14_lc(
122
+ *,
123
+ layers: int = 4,
124
+ pretrained: bool = True,
125
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
126
+ **kwargs,
127
+ ):
128
+ """
129
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
130
+ """
131
+ return _make_dinov2_linear_classifier(
132
+ arch_name="vit_small",
133
+ layers=layers,
134
+ pretrained=pretrained,
135
+ weights=weights,
136
+ **kwargs,
137
+ )
138
+
139
+
140
+ def dinov2_vitb14_lc(
141
+ *,
142
+ layers: int = 4,
143
+ pretrained: bool = True,
144
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
145
+ **kwargs,
146
+ ):
147
+ """
148
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
149
+ """
150
+ return _make_dinov2_linear_classifier(
151
+ arch_name="vit_base",
152
+ layers=layers,
153
+ pretrained=pretrained,
154
+ weights=weights,
155
+ **kwargs,
156
+ )
157
+
158
+
159
+ def dinov2_vitl14_lc(
160
+ *,
161
+ layers: int = 4,
162
+ pretrained: bool = True,
163
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
164
+ **kwargs,
165
+ ):
166
+ """
167
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
168
+ """
169
+ return _make_dinov2_linear_classifier(
170
+ arch_name="vit_large",
171
+ layers=layers,
172
+ pretrained=pretrained,
173
+ weights=weights,
174
+ **kwargs,
175
+ )
176
+
177
+
178
+ def dinov2_vitg14_lc(
179
+ *,
180
+ layers: int = 4,
181
+ pretrained: bool = True,
182
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
183
+ **kwargs,
184
+ ):
185
+ """
186
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
187
+ """
188
+ return _make_dinov2_linear_classifier(
189
+ arch_name="vit_giant2",
190
+ layers=layers,
191
+ ffn_layer="swiglufused",
192
+ pretrained=pretrained,
193
+ weights=weights,
194
+ **kwargs,
195
+ )
196
+
197
+
198
+ def dinov2_vits14_reg_lc(
199
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
200
+ ):
201
+ """
202
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
203
+ """
204
+ return _make_dinov2_linear_classifier(
205
+ arch_name="vit_small",
206
+ layers=layers,
207
+ pretrained=pretrained,
208
+ weights=weights,
209
+ num_register_tokens=4,
210
+ interpolate_antialias=True,
211
+ interpolate_offset=0.0,
212
+ **kwargs,
213
+ )
214
+
215
+
216
+ def dinov2_vitb14_reg_lc(
217
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
218
+ ):
219
+ """
220
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
221
+ """
222
+ return _make_dinov2_linear_classifier(
223
+ arch_name="vit_base",
224
+ layers=layers,
225
+ pretrained=pretrained,
226
+ weights=weights,
227
+ num_register_tokens=4,
228
+ interpolate_antialias=True,
229
+ interpolate_offset=0.0,
230
+ **kwargs,
231
+ )
232
+
233
+
234
+ def dinov2_vitl14_reg_lc(
235
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
236
+ ):
237
+ """
238
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
239
+ """
240
+ return _make_dinov2_linear_classifier(
241
+ arch_name="vit_large",
242
+ layers=layers,
243
+ pretrained=pretrained,
244
+ weights=weights,
245
+ num_register_tokens=4,
246
+ interpolate_antialias=True,
247
+ interpolate_offset=0.0,
248
+ **kwargs,
249
+ )
250
+
251
+
252
+ def dinov2_vitg14_reg_lc(
253
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
254
+ ):
255
+ """
256
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
257
+ """
258
+ return _make_dinov2_linear_classifier(
259
+ arch_name="vit_giant2",
260
+ layers=layers,
261
+ ffn_layer="swiglufused",
262
+ pretrained=pretrained,
263
+ weights=weights,
264
+ num_register_tokens=4,
265
+ interpolate_antialias=True,
266
+ interpolate_offset=0.0,
267
+ **kwargs,
268
+ )
core/encoders/dinov2/hub/depth/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .decode_heads import BNHead, DPTHead
7
+ from .encoder_decoder import DepthEncoderDecoder
core/encoders/dinov2/hub/depth/decode_heads.py ADDED
@@ -0,0 +1,747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+ from functools import partial
8
+ import math
9
+ import warnings
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .ops import resize
15
+
16
+
17
+ # XXX: (Untested) replacement for mmcv.imdenormalize()
18
+ def _imdenormalize(img, mean, std, to_bgr=True):
19
+ import numpy as np
20
+
21
+ mean = mean.reshape(1, -1).astype(np.float64)
22
+ std = std.reshape(1, -1).astype(np.float64)
23
+ img = (img * std) + mean
24
+ if to_bgr:
25
+ img = img[::-1]
26
+ return img
27
+
28
+
29
+ class DepthBaseDecodeHead(nn.Module):
30
+ """Base class for BaseDecodeHead.
31
+
32
+ Args:
33
+ in_channels (List): Input channels.
34
+ channels (int): Channels after modules, before conv_depth.
35
+ conv_layer (nn.Module): Conv layers. Default: None.
36
+ act_layer (nn.Module): Activation layers. Default: nn.ReLU.
37
+ loss_decode (dict): Config of decode loss.
38
+ Default: ().
39
+ sampler (dict|None): The config of depth map sampler.
40
+ Default: None.
41
+ align_corners (bool): align_corners argument of F.interpolate.
42
+ Default: False.
43
+ min_depth (int): Min depth in dataset setting.
44
+ Default: 1e-3.
45
+ max_depth (int): Max depth in dataset setting.
46
+ Default: None.
47
+ norm_layer (dict|None): Norm layers.
48
+ Default: None.
49
+ classify (bool): Whether predict depth in a cls.-reg. manner.
50
+ Default: False.
51
+ n_bins (int): The number of bins used in cls. step.
52
+ Default: 256.
53
+ bins_strategy (str): The discrete strategy used in cls. step.
54
+ Default: 'UD'.
55
+ norm_strategy (str): The norm strategy on cls. probability
56
+ distribution. Default: 'linear'
57
+ scale_up (str): Whether predict depth in a scale-up manner.
58
+ Default: False.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ in_channels,
64
+ conv_layer=None,
65
+ act_layer=nn.ReLU,
66
+ channels=96,
67
+ loss_decode=(),
68
+ sampler=None,
69
+ align_corners=False,
70
+ min_depth=1e-3,
71
+ max_depth=None,
72
+ norm_layer=None,
73
+ classify=False,
74
+ n_bins=256,
75
+ bins_strategy="UD",
76
+ norm_strategy="linear",
77
+ scale_up=False,
78
+ ):
79
+ super(DepthBaseDecodeHead, self).__init__()
80
+
81
+ self.in_channels = in_channels
82
+ self.channels = channels
83
+ self.conf_layer = conv_layer
84
+ self.act_layer = act_layer
85
+ self.loss_decode = loss_decode
86
+ self.align_corners = align_corners
87
+ self.min_depth = min_depth
88
+ self.max_depth = max_depth
89
+ self.norm_layer = norm_layer
90
+ self.classify = classify
91
+ self.n_bins = n_bins
92
+ self.scale_up = scale_up
93
+
94
+ if self.classify:
95
+ assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
96
+ assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
97
+
98
+ self.bins_strategy = bins_strategy
99
+ self.norm_strategy = norm_strategy
100
+ self.softmax = nn.Softmax(dim=1)
101
+ self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
102
+ else:
103
+ self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
104
+
105
+ self.relu = nn.ReLU()
106
+ self.sigmoid = nn.Sigmoid()
107
+
108
+ def forward(self, inputs, img_metas):
109
+ """Placeholder of forward function."""
110
+ pass
111
+
112
+ def forward_train(self, img, inputs, img_metas, depth_gt):
113
+ """Forward function for training.
114
+ Args:
115
+ inputs (list[Tensor]): List of multi-level img features.
116
+ img_metas (list[dict]): List of image info dict where each dict
117
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
118
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
119
+ For details on the values of these keys see
120
+ `depth/datasets/pipelines/formatting.py:Collect`.
121
+ depth_gt (Tensor): GT depth
122
+
123
+ Returns:
124
+ dict[str, Tensor]: a dictionary of loss components
125
+ """
126
+ depth_pred = self.forward(inputs, img_metas)
127
+ losses = self.losses(depth_pred, depth_gt)
128
+
129
+ log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
130
+ losses.update(**log_imgs)
131
+
132
+ return losses
133
+
134
+ def forward_test(self, inputs, img_metas):
135
+ """Forward function for testing.
136
+ Args:
137
+ inputs (list[Tensor]): List of multi-level img features.
138
+ img_metas (list[dict]): List of image info dict where each dict
139
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
140
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
141
+ For details on the values of these keys see
142
+ `depth/datasets/pipelines/formatting.py:Collect`.
143
+
144
+ Returns:
145
+ Tensor: Output depth map.
146
+ """
147
+ return self.forward(inputs, img_metas)
148
+
149
+ def depth_pred(self, feat):
150
+ """Prediction each pixel."""
151
+ if self.classify:
152
+ logit = self.conv_depth(feat)
153
+
154
+ if self.bins_strategy == "UD":
155
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
156
+ elif self.bins_strategy == "SID":
157
+ bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
158
+
159
+ # following Adabins, default linear
160
+ if self.norm_strategy == "linear":
161
+ logit = torch.relu(logit)
162
+ eps = 0.1
163
+ logit = logit + eps
164
+ logit = logit / logit.sum(dim=1, keepdim=True)
165
+ elif self.norm_strategy == "softmax":
166
+ logit = torch.softmax(logit, dim=1)
167
+ elif self.norm_strategy == "sigmoid":
168
+ logit = torch.sigmoid(logit)
169
+ logit = logit / logit.sum(dim=1, keepdim=True)
170
+
171
+ output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
172
+
173
+ else:
174
+ if self.scale_up:
175
+ output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
176
+ else:
177
+ output = self.relu(self.conv_depth(feat)) + self.min_depth
178
+ return output
179
+
180
+ def losses(self, depth_pred, depth_gt):
181
+ """Compute depth loss."""
182
+ loss = dict()
183
+ depth_pred = resize(
184
+ input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
185
+ )
186
+ if not isinstance(self.loss_decode, nn.ModuleList):
187
+ losses_decode = [self.loss_decode]
188
+ else:
189
+ losses_decode = self.loss_decode
190
+ for loss_decode in losses_decode:
191
+ if loss_decode.loss_name not in loss:
192
+ loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
193
+ else:
194
+ loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
195
+ return loss
196
+
197
+ def log_images(self, img_path, depth_pred, depth_gt, img_meta):
198
+ import numpy as np
199
+
200
+ show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
201
+ show_img = show_img.numpy().astype(np.float32)
202
+ show_img = _imdenormalize(
203
+ show_img,
204
+ img_meta["img_norm_cfg"]["mean"],
205
+ img_meta["img_norm_cfg"]["std"],
206
+ img_meta["img_norm_cfg"]["to_rgb"],
207
+ )
208
+ show_img = np.clip(show_img, 0, 255)
209
+ show_img = show_img.astype(np.uint8)
210
+ show_img = show_img[:, :, ::-1]
211
+ show_img = show_img.transpose(0, 2, 1)
212
+ show_img = show_img.transpose(1, 0, 2)
213
+
214
+ depth_pred = depth_pred / torch.max(depth_pred)
215
+ depth_gt = depth_gt / torch.max(depth_gt)
216
+
217
+ depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
218
+ depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
219
+
220
+ return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
221
+
222
+
223
+ class BNHead(DepthBaseDecodeHead):
224
+ """Just a batchnorm."""
225
+
226
+ def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
227
+ super().__init__(**kwargs)
228
+ self.input_transform = input_transform
229
+ self.in_index = in_index
230
+ self.upsample = upsample
231
+ # self.bn = nn.SyncBatchNorm(self.in_channels)
232
+ if self.classify:
233
+ self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
234
+ else:
235
+ self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
236
+
237
+ def _transform_inputs(self, inputs):
238
+ """Transform inputs for decoder.
239
+ Args:
240
+ inputs (list[Tensor]): List of multi-level img features.
241
+ Returns:
242
+ Tensor: The transformed inputs
243
+ """
244
+
245
+ if "concat" in self.input_transform:
246
+ inputs = [inputs[i] for i in self.in_index]
247
+ if "resize" in self.input_transform:
248
+ inputs = [
249
+ resize(
250
+ input=x,
251
+ size=[s * self.upsample for s in inputs[0].shape[2:]],
252
+ mode="bilinear",
253
+ align_corners=self.align_corners,
254
+ )
255
+ for x in inputs
256
+ ]
257
+ inputs = torch.cat(inputs, dim=1)
258
+ elif self.input_transform == "multiple_select":
259
+ inputs = [inputs[i] for i in self.in_index]
260
+ else:
261
+ inputs = inputs[self.in_index]
262
+
263
+ return inputs
264
+
265
+ def _forward_feature(self, inputs, img_metas=None, **kwargs):
266
+ """Forward function for feature maps before classifying each pixel with
267
+ ``self.cls_seg`` fc.
268
+ Args:
269
+ inputs (list[Tensor]): List of multi-level img features.
270
+ Returns:
271
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
272
+ H, W) which is feature map for last layer of decoder head.
273
+ """
274
+ # accept lists (for cls token)
275
+ inputs = list(inputs)
276
+ for i, x in enumerate(inputs):
277
+ if len(x) == 2:
278
+ x, cls_token = x[0], x[1]
279
+ if len(x.shape) == 2:
280
+ x = x[:, :, None, None]
281
+ cls_token = cls_token[:, :, None, None].expand_as(x)
282
+ inputs[i] = torch.cat((x, cls_token), 1)
283
+ else:
284
+ x = x[0]
285
+ if len(x.shape) == 2:
286
+ x = x[:, :, None, None]
287
+ inputs[i] = x
288
+ x = self._transform_inputs(inputs)
289
+ # feats = self.bn(x)
290
+ return x
291
+
292
+ def forward(self, inputs, img_metas=None, **kwargs):
293
+ """Forward function."""
294
+ output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
295
+ output = self.depth_pred(output)
296
+ return output
297
+
298
+
299
+ class ConvModule(nn.Module):
300
+ """A conv block that bundles conv/norm/activation layers.
301
+
302
+ This block simplifies the usage of convolution layers, which are commonly
303
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
304
+ It is based upon three build methods: `build_conv_layer()`,
305
+ `build_norm_layer()` and `build_activation_layer()`.
306
+
307
+ Besides, we add some additional features in this module.
308
+ 1. Automatically set `bias` of the conv layer.
309
+ 2. Spectral norm is supported.
310
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
311
+ supports zero and circular padding, and we add "reflect" padding mode.
312
+
313
+ Args:
314
+ in_channels (int): Number of channels in the input feature map.
315
+ Same as that in ``nn._ConvNd``.
316
+ out_channels (int): Number of channels produced by the convolution.
317
+ Same as that in ``nn._ConvNd``.
318
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
319
+ Same as that in ``nn._ConvNd``.
320
+ stride (int | tuple[int]): Stride of the convolution.
321
+ Same as that in ``nn._ConvNd``.
322
+ padding (int | tuple[int]): Zero-padding added to both sides of
323
+ the input. Same as that in ``nn._ConvNd``.
324
+ dilation (int | tuple[int]): Spacing between kernel elements.
325
+ Same as that in ``nn._ConvNd``.
326
+ groups (int): Number of blocked connections from input channels to
327
+ output channels. Same as that in ``nn._ConvNd``.
328
+ bias (bool | str): If specified as `auto`, it will be decided by the
329
+ norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
330
+ False. Default: "auto".
331
+ conv_layer (nn.Module): Convolution layer. Default: None,
332
+ which means using conv2d.
333
+ norm_layer (nn.Module): Normalization layer. Default: None.
334
+ act_layer (nn.Module): Activation layer. Default: nn.ReLU.
335
+ inplace (bool): Whether to use inplace mode for activation.
336
+ Default: True.
337
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
338
+ Default: False.
339
+ padding_mode (str): If the `padding_mode` has not been supported by
340
+ current `Conv2d` in PyTorch, we will use our own padding layer
341
+ instead. Currently, we support ['zeros', 'circular'] with official
342
+ implementation and ['reflect'] with our own implementation.
343
+ Default: 'zeros'.
344
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
345
+ sequence of "conv", "norm" and "act". Common examples are
346
+ ("conv", "norm", "act") and ("act", "conv", "norm").
347
+ Default: ('conv', 'norm', 'act').
348
+ """
349
+
350
+ _abbr_ = "conv_block"
351
+
352
+ def __init__(
353
+ self,
354
+ in_channels,
355
+ out_channels,
356
+ kernel_size,
357
+ stride=1,
358
+ padding=0,
359
+ dilation=1,
360
+ groups=1,
361
+ bias="auto",
362
+ conv_layer=nn.Conv2d,
363
+ norm_layer=None,
364
+ act_layer=nn.ReLU,
365
+ inplace=True,
366
+ with_spectral_norm=False,
367
+ padding_mode="zeros",
368
+ order=("conv", "norm", "act"),
369
+ ):
370
+ super(ConvModule, self).__init__()
371
+ official_padding_mode = ["zeros", "circular"]
372
+ self.conv_layer = conv_layer
373
+ self.norm_layer = norm_layer
374
+ self.act_layer = act_layer
375
+ self.inplace = inplace
376
+ self.with_spectral_norm = with_spectral_norm
377
+ self.with_explicit_padding = padding_mode not in official_padding_mode
378
+ self.order = order
379
+ assert isinstance(self.order, tuple) and len(self.order) == 3
380
+ assert set(order) == set(["conv", "norm", "act"])
381
+
382
+ self.with_norm = norm_layer is not None
383
+ self.with_activation = act_layer is not None
384
+ # if the conv layer is before a norm layer, bias is unnecessary.
385
+ if bias == "auto":
386
+ bias = not self.with_norm
387
+ self.with_bias = bias
388
+
389
+ if self.with_explicit_padding:
390
+ if padding_mode == "zeros":
391
+ padding_layer = nn.ZeroPad2d
392
+ else:
393
+ raise AssertionError(f"Unsupported padding mode: {padding_mode}")
394
+ self.pad = padding_layer(padding)
395
+
396
+ # reset padding to 0 for conv module
397
+ conv_padding = 0 if self.with_explicit_padding else padding
398
+ # build convolution layer
399
+ self.conv = self.conv_layer(
400
+ in_channels,
401
+ out_channels,
402
+ kernel_size,
403
+ stride=stride,
404
+ padding=conv_padding,
405
+ dilation=dilation,
406
+ groups=groups,
407
+ bias=bias,
408
+ )
409
+ # export the attributes of self.conv to a higher level for convenience
410
+ self.in_channels = self.conv.in_channels
411
+ self.out_channels = self.conv.out_channels
412
+ self.kernel_size = self.conv.kernel_size
413
+ self.stride = self.conv.stride
414
+ self.padding = padding
415
+ self.dilation = self.conv.dilation
416
+ self.transposed = self.conv.transposed
417
+ self.output_padding = self.conv.output_padding
418
+ self.groups = self.conv.groups
419
+
420
+ if self.with_spectral_norm:
421
+ self.conv = nn.utils.spectral_norm(self.conv)
422
+
423
+ # build normalization layers
424
+ if self.with_norm:
425
+ # norm layer is after conv layer
426
+ if order.index("norm") > order.index("conv"):
427
+ norm_channels = out_channels
428
+ else:
429
+ norm_channels = in_channels
430
+ norm = partial(norm_layer, num_features=norm_channels)
431
+ self.add_module("norm", norm)
432
+ if self.with_bias:
433
+ from torch.nnModules.batchnorm import _BatchNorm
434
+ from torch.nnModules.instancenorm import _InstanceNorm
435
+
436
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
437
+ warnings.warn("Unnecessary conv bias before batch/instance norm")
438
+ else:
439
+ self.norm_name = None
440
+
441
+ # build activation layer
442
+ if self.with_activation:
443
+ # nn.Tanh has no 'inplace' argument
444
+ # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU)
445
+ if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)):
446
+ act_layer = partial(act_layer, inplace=inplace)
447
+ self.activate = act_layer()
448
+
449
+ # Use msra init by default
450
+ self.init_weights()
451
+
452
+ @property
453
+ def norm(self):
454
+ if self.norm_name:
455
+ return getattr(self, self.norm_name)
456
+ else:
457
+ return None
458
+
459
+ def init_weights(self):
460
+ # 1. It is mainly for customized conv layers with their own
461
+ # initialization manners by calling their own ``init_weights()``,
462
+ # and we do not want ConvModule to override the initialization.
463
+ # 2. For customized conv layers without their own initialization
464
+ # manners (that is, they don't have their own ``init_weights()``)
465
+ # and PyTorch's conv layers, they will be initialized by
466
+ # this method with default ``kaiming_init``.
467
+ # Note: For PyTorch's conv layers, they will be overwritten by our
468
+ # initialization implementation using default ``kaiming_init``.
469
+ if not hasattr(self.conv, "init_weights"):
470
+ if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU):
471
+ nonlinearity = "leaky_relu"
472
+ a = 0.01 # XXX: default negative_slope
473
+ else:
474
+ nonlinearity = "relu"
475
+ a = 0
476
+ if hasattr(self.conv, "weight") and self.conv.weight is not None:
477
+ nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
478
+ if hasattr(self.conv, "bias") and self.conv.bias is not None:
479
+ nn.init.constant_(self.conv.bias, 0)
480
+ if self.with_norm:
481
+ if hasattr(self.norm, "weight") and self.norm.weight is not None:
482
+ nn.init.constant_(self.norm.weight, 1)
483
+ if hasattr(self.norm, "bias") and self.norm.bias is not None:
484
+ nn.init.constant_(self.norm.bias, 0)
485
+
486
+ def forward(self, x, activate=True, norm=True):
487
+ for layer in self.order:
488
+ if layer == "conv":
489
+ if self.with_explicit_padding:
490
+ x = self.pad(x)
491
+ x = self.conv(x)
492
+ elif layer == "norm" and norm and self.with_norm:
493
+ x = self.norm(x)
494
+ elif layer == "act" and activate and self.with_activation:
495
+ x = self.activate(x)
496
+ return x
497
+
498
+
499
+ class Interpolate(nn.Module):
500
+ def __init__(self, scale_factor, mode, align_corners=False):
501
+ super(Interpolate, self).__init__()
502
+ self.interp = nn.functional.interpolate
503
+ self.scale_factor = scale_factor
504
+ self.mode = mode
505
+ self.align_corners = align_corners
506
+
507
+ def forward(self, x):
508
+ x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
509
+ return x
510
+
511
+
512
+ class HeadDepth(nn.Module):
513
+ def __init__(self, features):
514
+ super(HeadDepth, self).__init__()
515
+ self.head = nn.Sequential(
516
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
517
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
518
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
519
+ nn.ReLU(),
520
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
521
+ )
522
+
523
+ def forward(self, x):
524
+ x = self.head(x)
525
+ return x
526
+
527
+
528
+ class ReassembleBlocks(nn.Module):
529
+ """ViTPostProcessBlock, process cls_token in ViT backbone output and
530
+ rearrange the feature vector to feature map.
531
+ Args:
532
+ in_channels (int): ViT feature channels. Default: 768.
533
+ out_channels (List): output channels of each stage.
534
+ Default: [96, 192, 384, 768].
535
+ readout_type (str): Type of readout operation. Default: 'ignore'.
536
+ patch_size (int): The patch size. Default: 16.
537
+ """
538
+
539
+ def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
540
+ super(ReassembleBlocks, self).__init__()
541
+
542
+ assert readout_type in ["ignore", "add", "project"]
543
+ self.readout_type = readout_type
544
+ self.patch_size = patch_size
545
+
546
+ self.projects = nn.ModuleList(
547
+ [
548
+ ConvModule(
549
+ in_channels=in_channels,
550
+ out_channels=out_channel,
551
+ kernel_size=1,
552
+ act_layer=None,
553
+ )
554
+ for out_channel in out_channels
555
+ ]
556
+ )
557
+
558
+ self.resize_layers = nn.ModuleList(
559
+ [
560
+ nn.ConvTranspose2d(
561
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
562
+ ),
563
+ nn.ConvTranspose2d(
564
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
565
+ ),
566
+ nn.Identity(),
567
+ nn.Conv2d(
568
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
569
+ ),
570
+ ]
571
+ )
572
+ if self.readout_type == "project":
573
+ self.readout_projects = nn.ModuleList()
574
+ for _ in range(len(self.projects)):
575
+ self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
576
+
577
+ def forward(self, inputs):
578
+ assert isinstance(inputs, list)
579
+ out = []
580
+ for i, x in enumerate(inputs):
581
+ assert len(x) == 2
582
+ x, cls_token = x[0], x[1]
583
+ feature_shape = x.shape
584
+ if self.readout_type == "project":
585
+ x = x.flatten(2).permute((0, 2, 1))
586
+ readout = cls_token.unsqueeze(1).expand_as(x)
587
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
588
+ x = x.permute(0, 2, 1).reshape(feature_shape)
589
+ elif self.readout_type == "add":
590
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
591
+ x = x.reshape(feature_shape)
592
+ else:
593
+ pass
594
+ x = self.projects[i](x)
595
+ x = self.resize_layers[i](x)
596
+ out.append(x)
597
+ return out
598
+
599
+
600
+ class PreActResidualConvUnit(nn.Module):
601
+ """ResidualConvUnit, pre-activate residual unit.
602
+ Args:
603
+ in_channels (int): number of channels in the input feature map.
604
+ act_layer (nn.Module): activation layer.
605
+ norm_layer (nn.Module): norm layer.
606
+ stride (int): stride of the first block. Default: 1
607
+ dilation (int): dilation rate for convs layers. Default: 1.
608
+ """
609
+
610
+ def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
611
+ super(PreActResidualConvUnit, self).__init__()
612
+
613
+ self.conv1 = ConvModule(
614
+ in_channels,
615
+ in_channels,
616
+ 3,
617
+ stride=stride,
618
+ padding=dilation,
619
+ dilation=dilation,
620
+ norm_layer=norm_layer,
621
+ act_layer=act_layer,
622
+ bias=False,
623
+ order=("act", "conv", "norm"),
624
+ )
625
+
626
+ self.conv2 = ConvModule(
627
+ in_channels,
628
+ in_channels,
629
+ 3,
630
+ padding=1,
631
+ norm_layer=norm_layer,
632
+ act_layer=act_layer,
633
+ bias=False,
634
+ order=("act", "conv", "norm"),
635
+ )
636
+
637
+ def forward(self, inputs):
638
+ inputs_ = inputs.clone()
639
+ x = self.conv1(inputs)
640
+ x = self.conv2(x)
641
+ return x + inputs_
642
+
643
+
644
+ class FeatureFusionBlock(nn.Module):
645
+ """FeatureFusionBlock, merge feature map from different stages.
646
+ Args:
647
+ in_channels (int): Input channels.
648
+ act_layer (nn.Module): activation layer for ResidualConvUnit.
649
+ norm_layer (nn.Module): normalization layer.
650
+ expand (bool): Whether expand the channels in post process block.
651
+ Default: False.
652
+ align_corners (bool): align_corner setting for bilinear upsample.
653
+ Default: True.
654
+ """
655
+
656
+ def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
657
+ super(FeatureFusionBlock, self).__init__()
658
+
659
+ self.in_channels = in_channels
660
+ self.expand = expand
661
+ self.align_corners = align_corners
662
+
663
+ self.out_channels = in_channels
664
+ if self.expand:
665
+ self.out_channels = in_channels // 2
666
+
667
+ self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)
668
+
669
+ self.res_conv_unit1 = PreActResidualConvUnit(
670
+ in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
671
+ )
672
+ self.res_conv_unit2 = PreActResidualConvUnit(
673
+ in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
674
+ )
675
+
676
+ def forward(self, *inputs):
677
+ x = inputs[0]
678
+ if len(inputs) == 2:
679
+ if x.shape != inputs[1].shape:
680
+ res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
681
+ else:
682
+ res = inputs[1]
683
+ x = x + self.res_conv_unit1(res)
684
+ x = self.res_conv_unit2(x)
685
+ x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
686
+ x = self.project(x)
687
+ return x
688
+
689
+
690
+ class DPTHead(DepthBaseDecodeHead):
691
+ """Vision Transformers for Dense Prediction.
692
+ This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
693
+ Args:
694
+ embed_dims (int): The embed dimension of the ViT backbone.
695
+ Default: 768.
696
+ post_process_channels (List): Out channels of post process conv
697
+ layers. Default: [96, 192, 384, 768].
698
+ readout_type (str): Type of readout operation. Default: 'ignore'.
699
+ patch_size (int): The patch size. Default: 16.
700
+ expand_channels (bool): Whether expand the channels in post process
701
+ block. Default: False.
702
+ """
703
+
704
+ def __init__(
705
+ self,
706
+ embed_dims=768,
707
+ post_process_channels=[96, 192, 384, 768],
708
+ readout_type="ignore",
709
+ patch_size=16,
710
+ expand_channels=False,
711
+ **kwargs,
712
+ ):
713
+ super(DPTHead, self).__init__(**kwargs)
714
+
715
+ self.in_channels = self.in_channels
716
+ self.expand_channels = expand_channels
717
+ self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
718
+
719
+ self.post_process_channels = [
720
+ channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
721
+ ]
722
+ self.convs = nn.ModuleList()
723
+ for channel in self.post_process_channels:
724
+ self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
725
+ self.fusion_blocks = nn.ModuleList()
726
+ for _ in range(len(self.convs)):
727
+ self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
728
+ self.fusion_blocks[0].res_conv_unit1 = None
729
+ self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
730
+ self.num_fusion_blocks = len(self.fusion_blocks)
731
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
732
+ self.num_post_process_channels = len(self.post_process_channels)
733
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
734
+ assert self.num_reassemble_blocks == self.num_post_process_channels
735
+ self.conv_depth = HeadDepth(self.channels)
736
+
737
+ def forward(self, inputs, img_metas):
738
+ assert len(inputs) == self.num_reassemble_blocks
739
+ x = [inp for inp in inputs]
740
+ x = self.reassemble_blocks(x)
741
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
742
+ out = self.fusion_blocks[0](x[-1])
743
+ for i in range(1, len(self.fusion_blocks)):
744
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
745
+ out = self.project(out)
746
+ out = self.depth_pred(out)
747
+ return out
core/encoders/dinov2/hub/depth/encoder_decoder.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from .ops import resize
13
+
14
+
15
+ def add_prefix(inputs, prefix):
16
+ """Add prefix for dict.
17
+
18
+ Args:
19
+ inputs (dict): The input dict with str keys.
20
+ prefix (str): The prefix to add.
21
+
22
+ Returns:
23
+
24
+ dict: The dict with keys updated with ``prefix``.
25
+ """
26
+
27
+ outputs = dict()
28
+ for name, value in inputs.items():
29
+ outputs[f"{prefix}.{name}"] = value
30
+
31
+ return outputs
32
+
33
+
34
+ class DepthEncoderDecoder(nn.Module):
35
+ """Encoder Decoder depther.
36
+
37
+ EncoderDecoder typically consists of backbone and decode_head.
38
+ """
39
+
40
+ def __init__(self, backbone, decode_head):
41
+ super(DepthEncoderDecoder, self).__init__()
42
+
43
+ self.backbone = backbone
44
+ self.decode_head = decode_head
45
+ self.align_corners = self.decode_head.align_corners
46
+
47
+ def extract_feat(self, img):
48
+ """Extract features from images."""
49
+ return self.backbone(img)
50
+
51
+ def encode_decode(self, img, img_metas, rescale=True, size=None):
52
+ """Encode images with backbone and decode into a depth estimation
53
+ map of the same size as input."""
54
+ x = self.extract_feat(img)
55
+ out = self._decode_head_forward_test(x, img_metas)
56
+ # crop the pred depth to the certain range.
57
+ out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
58
+ if rescale:
59
+ if size is None:
60
+ if img_metas is not None:
61
+ size = img_metas[0]["ori_shape"][:2]
62
+ else:
63
+ size = img.shape[2:]
64
+ out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
65
+ return out
66
+
67
+ def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
68
+ """Run forward function and calculate loss for decode head in
69
+ training."""
70
+ losses = dict()
71
+ loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
72
+ losses.update(add_prefix(loss_decode, "decode"))
73
+ return losses
74
+
75
+ def _decode_head_forward_test(self, x, img_metas):
76
+ """Run forward function and calculate loss for decode head in
77
+ inference."""
78
+ depth_pred = self.decode_head.forward_test(x, img_metas)
79
+ return depth_pred
80
+
81
+ def forward_dummy(self, img):
82
+ """Dummy forward function."""
83
+ depth = self.encode_decode(img, None)
84
+
85
+ return depth
86
+
87
+ def forward_train(self, img, img_metas, depth_gt, **kwargs):
88
+ """Forward function for training.
89
+
90
+ Args:
91
+ img (Tensor): Input images.
92
+ img_metas (list[dict]): List of image info dict where each dict
93
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
94
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
95
+ For details on the values of these keys see
96
+ `depth/datasets/pipelines/formatting.py:Collect`.
97
+ depth_gt (Tensor): Depth gt
98
+ used if the architecture supports depth estimation task.
99
+
100
+ Returns:
101
+ dict[str, Tensor]: a dictionary of loss components
102
+ """
103
+
104
+ x = self.extract_feat(img)
105
+
106
+ losses = dict()
107
+
108
+ # the last of x saves the info from neck
109
+ loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
110
+
111
+ losses.update(loss_decode)
112
+
113
+ return losses
114
+
115
+ def whole_inference(self, img, img_meta, rescale, size=None):
116
+ """Inference with full image."""
117
+ return self.encode_decode(img, img_meta, rescale, size=size)
118
+
119
+ def slide_inference(self, img, img_meta, rescale, stride, crop_size):
120
+ """Inference by sliding-window with overlap.
121
+
122
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
123
+ decode without padding.
124
+ """
125
+
126
+ h_stride, w_stride = stride
127
+ h_crop, w_crop = crop_size
128
+ batch_size, _, h_img, w_img = img.size()
129
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
130
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
131
+ preds = img.new_zeros((batch_size, 1, h_img, w_img))
132
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
133
+ for h_idx in range(h_grids):
134
+ for w_idx in range(w_grids):
135
+ y1 = h_idx * h_stride
136
+ x1 = w_idx * w_stride
137
+ y2 = min(y1 + h_crop, h_img)
138
+ x2 = min(x1 + w_crop, w_img)
139
+ y1 = max(y2 - h_crop, 0)
140
+ x1 = max(x2 - w_crop, 0)
141
+ crop_img = img[:, :, y1:y2, x1:x2]
142
+ depth_pred = self.encode_decode(crop_img, img_meta, rescale)
143
+ preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
144
+
145
+ count_mat[:, :, y1:y2, x1:x2] += 1
146
+ assert (count_mat == 0).sum() == 0
147
+ if torch.onnx.is_in_onnx_export():
148
+ # cast count_mat to constant while exporting to ONNX
149
+ count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
150
+ preds = preds / count_mat
151
+ return preds
152
+
153
+ def inference(self, img, img_meta, rescale, size=None, mode="whole"):
154
+ """Inference with slide/whole style.
155
+
156
+ Args:
157
+ img (Tensor): The input image of shape (N, 3, H, W).
158
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
159
+ 'scale_factor', 'flip', and may also contain
160
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
161
+ For details on the values of these keys see
162
+ `depth/datasets/pipelines/formatting.py:Collect`.
163
+ rescale (bool): Whether rescale back to original shape.
164
+
165
+ Returns:
166
+ Tensor: The output depth map.
167
+ """
168
+
169
+ assert mode in ["slide", "whole"]
170
+ ori_shape = img_meta[0]["ori_shape"]
171
+ assert all(_["ori_shape"] == ori_shape for _ in img_meta)
172
+ if mode == "slide":
173
+ depth_pred = self.slide_inference(img, img_meta, rescale)
174
+ else:
175
+ depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
176
+ output = depth_pred
177
+ flip = img_meta[0]["flip"]
178
+ if flip:
179
+ flip_direction = img_meta[0]["flip_direction"]
180
+ assert flip_direction in ["horizontal", "vertical"]
181
+ if flip_direction == "horizontal":
182
+ output = output.flip(dims=(3,))
183
+ elif flip_direction == "vertical":
184
+ output = output.flip(dims=(2,))
185
+
186
+ return output
187
+
188
+ def simple_test(self, img, img_meta, rescale=True):
189
+ """Simple test with single image."""
190
+ depth_pred = self.inference(img, img_meta, rescale)
191
+ if torch.onnx.is_in_onnx_export():
192
+ # our inference backend only support 4D output
193
+ depth_pred = depth_pred.unsqueeze(0)
194
+ return depth_pred
195
+ depth_pred = depth_pred.cpu().numpy()
196
+ # unravel batch dim
197
+ depth_pred = list(depth_pred)
198
+ return depth_pred
199
+
200
+ def aug_test(self, imgs, img_metas, rescale=True):
201
+ """Test with augmentations.
202
+
203
+ Only rescale=True is supported.
204
+ """
205
+ # aug_test rescale all imgs back to ori_shape for now
206
+ assert rescale
207
+ # to save memory, we get augmented depth logit inplace
208
+ depth_pred = self.inference(imgs[0], img_metas[0], rescale)
209
+ for i in range(1, len(imgs)):
210
+ cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
211
+ depth_pred += cur_depth_pred
212
+ depth_pred /= len(imgs)
213
+ depth_pred = depth_pred.cpu().numpy()
214
+ # unravel batch dim
215
+ depth_pred = list(depth_pred)
216
+ return depth_pred
217
+
218
+ def forward_test(self, imgs, img_metas, **kwargs):
219
+ """
220
+ Args:
221
+ imgs (List[Tensor]): the outer list indicates test-time
222
+ augmentations and inner Tensor should have a shape NxCxHxW,
223
+ which contains all images in the batch.
224
+ img_metas (List[List[dict]]): the outer list indicates test-time
225
+ augs (multiscale, flip, etc.) and the inner list indicates
226
+ images in a batch.
227
+ """
228
+ for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
229
+ if not isinstance(var, list):
230
+ raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
231
+ num_augs = len(imgs)
232
+ if num_augs != len(img_metas):
233
+ raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
234
+ # all images in the same aug batch all of the same ori_shape and pad
235
+ # shape
236
+ for img_meta in img_metas:
237
+ ori_shapes = [_["ori_shape"] for _ in img_meta]
238
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
239
+ img_shapes = [_["img_shape"] for _ in img_meta]
240
+ assert all(shape == img_shapes[0] for shape in img_shapes)
241
+ pad_shapes = [_["pad_shape"] for _ in img_meta]
242
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
243
+
244
+ if num_augs == 1:
245
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
246
+ else:
247
+ return self.aug_test(imgs, img_metas, **kwargs)
248
+
249
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
250
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
251
+ on whether ``return_loss`` is ``True``.
252
+
253
+ Note this setting will change the expected inputs. When
254
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
255
+ and List[dict]), and when ``resturn_loss=False``, img and img_meta
256
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
257
+ the outer list indicating test time augmentations.
258
+ """
259
+ if return_loss:
260
+ return self.forward_train(img, img_metas, **kwargs)
261
+ else:
262
+ return self.forward_test(img, img_metas, **kwargs)
263
+
264
+ def train_step(self, data_batch, optimizer, **kwargs):
265
+ """The iteration step during training.
266
+
267
+ This method defines an iteration step during training, except for the
268
+ back propagation and optimizer updating, which are done in an optimizer
269
+ hook. Note that in some complicated cases or models, the whole process
270
+ including back propagation and optimizer updating is also defined in
271
+ this method, such as GAN.
272
+
273
+ Args:
274
+ data (dict): The output of dataloader.
275
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
276
+ runner is passed to ``train_step()``. This argument is unused
277
+ and reserved.
278
+
279
+ Returns:
280
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
281
+ ``num_samples``.
282
+ ``loss`` is a tensor for back propagation, which can be a
283
+ weighted sum of multiple losses.
284
+ ``log_vars`` contains all the variables to be sent to the
285
+ logger.
286
+ ``num_samples`` indicates the batch size (when the model is
287
+ DDP, it means the batch size on each GPU), which is used for
288
+ averaging the logs.
289
+ """
290
+ losses = self(**data_batch)
291
+
292
+ # split losses and images
293
+ real_losses = {}
294
+ log_imgs = {}
295
+ for k, v in losses.items():
296
+ if "img" in k:
297
+ log_imgs[k] = v
298
+ else:
299
+ real_losses[k] = v
300
+
301
+ loss, log_vars = self._parse_losses(real_losses)
302
+
303
+ outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
304
+
305
+ return outputs
306
+
307
+ def val_step(self, data_batch, **kwargs):
308
+ """The iteration step during validation.
309
+
310
+ This method shares the same signature as :func:`train_step`, but used
311
+ during val epochs. Note that the evaluation after training epochs is
312
+ not implemented with this method, but an evaluation hook.
313
+ """
314
+ output = self(**data_batch, **kwargs)
315
+ return output
316
+
317
+ @staticmethod
318
+ def _parse_losses(losses):
319
+ import torch.distributed as dist
320
+
321
+ """Parse the raw outputs (losses) of the network.
322
+
323
+ Args:
324
+ losses (dict): Raw output of the network, which usually contain
325
+ losses and other necessary information.
326
+
327
+ Returns:
328
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
329
+ which may be a weighted sum of all losses, log_vars contains
330
+ all the variables to be sent to the logger.
331
+ """
332
+ log_vars = OrderedDict()
333
+ for loss_name, loss_value in losses.items():
334
+ if isinstance(loss_value, torch.Tensor):
335
+ log_vars[loss_name] = loss_value.mean()
336
+ elif isinstance(loss_value, list):
337
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
338
+ else:
339
+ raise TypeError(f"{loss_name} is not a tensor or list of tensors")
340
+
341
+ loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
342
+
343
+ log_vars["loss"] = loss
344
+ for loss_name, loss_value in log_vars.items():
345
+ # reduce loss when distributed training
346
+ if dist.is_available() and dist.is_initialized():
347
+ loss_value = loss_value.data.clone()
348
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
349
+ log_vars[loss_name] = loss_value.item()
350
+
351
+ return loss, log_vars
core/encoders/dinov2/hub/depth/ops.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import warnings
7
+
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
12
+ if warning:
13
+ if size is not None and align_corners:
14
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
15
+ output_h, output_w = tuple(int(x) for x in size)
16
+ if output_h > input_h or output_w > output_h:
17
+ if (
18
+ (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
19
+ and (output_h - 1) % (input_h - 1)
20
+ and (output_w - 1) % (input_w - 1)
21
+ ):
22
+ warnings.warn(
23
+ f"When align_corners={align_corners}, "
24
+ "the output would more aligned if "
25
+ f"input size {(input_h, input_w)} is `x+1` and "
26
+ f"out size {(output_h, output_w)} is `nx+1`"
27
+ )
28
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
core/encoders/dinov2/hub/depthers.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from functools import partial
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+
12
+ from .backbones import _make_dinov2_model
13
+ from .depth import BNHead, DepthEncoderDecoder, DPTHead
14
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
15
+
16
+
17
+ class Weights(Enum):
18
+ NYU = "NYU"
19
+ KITTI = "KITTI"
20
+
21
+
22
+ def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
23
+ if not pretrained: # Default
24
+ return (0.001, 10.0)
25
+
26
+ # Pretrained, set according to the training dataset for the provided weights
27
+ if weights == Weights.KITTI:
28
+ return (0.001, 80.0)
29
+
30
+ if weights == Weights.NYU:
31
+ return (0.001, 10.0)
32
+
33
+ return (0.001, 10.0)
34
+
35
+
36
+ def _make_dinov2_linear_depth_head(
37
+ *,
38
+ embed_dim: int,
39
+ layers: int,
40
+ min_depth: float,
41
+ max_depth: float,
42
+ **kwargs,
43
+ ):
44
+ if layers not in (1, 4):
45
+ raise AssertionError(f"Unsupported number of layers: {layers}")
46
+
47
+ if layers == 1:
48
+ in_index = [0]
49
+ else:
50
+ assert layers == 4
51
+ in_index = [0, 1, 2, 3]
52
+
53
+ return BNHead(
54
+ classify=True,
55
+ n_bins=256,
56
+ bins_strategy="UD",
57
+ norm_strategy="linear",
58
+ upsample=4,
59
+ in_channels=[embed_dim] * len(in_index),
60
+ in_index=in_index,
61
+ input_transform="resize_concat",
62
+ channels=embed_dim * len(in_index) * 2,
63
+ align_corners=False,
64
+ min_depth=0.001,
65
+ max_depth=80,
66
+ loss_decode=(),
67
+ )
68
+
69
+
70
+ def _make_dinov2_linear_depther(
71
+ *,
72
+ arch_name: str = "vit_large",
73
+ layers: int = 4,
74
+ pretrained: bool = True,
75
+ weights: Union[Weights, str] = Weights.NYU,
76
+ depth_range: Optional[Tuple[float, float]] = None,
77
+ **kwargs,
78
+ ):
79
+ if layers not in (1, 4):
80
+ raise AssertionError(f"Unsupported number of layers: {layers}")
81
+ if isinstance(weights, str):
82
+ try:
83
+ weights = Weights[weights]
84
+ except KeyError:
85
+ raise AssertionError(f"Unsupported weights: {weights}")
86
+
87
+ if depth_range is None:
88
+ depth_range = _get_depth_range(pretrained, weights)
89
+ min_depth, max_depth = depth_range
90
+
91
+ backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
92
+
93
+ embed_dim = backbone.embed_dim
94
+ patch_size = backbone.patch_size
95
+ model_name = _make_dinov2_model_name(arch_name, patch_size)
96
+ linear_depth_head = _make_dinov2_linear_depth_head(
97
+ embed_dim=embed_dim,
98
+ layers=layers,
99
+ min_depth=min_depth,
100
+ max_depth=max_depth,
101
+ )
102
+
103
+ layer_count = {
104
+ "vit_small": 12,
105
+ "vit_base": 12,
106
+ "vit_large": 24,
107
+ "vit_giant2": 40,
108
+ }[arch_name]
109
+
110
+ if layers == 4:
111
+ out_index = {
112
+ "vit_small": [2, 5, 8, 11],
113
+ "vit_base": [2, 5, 8, 11],
114
+ "vit_large": [4, 11, 17, 23],
115
+ "vit_giant2": [9, 19, 29, 39],
116
+ }[arch_name]
117
+ else:
118
+ assert layers == 1
119
+ out_index = [layer_count - 1]
120
+
121
+ model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
122
+ model.backbone.forward = partial(
123
+ backbone.get_intermediate_layers,
124
+ n=out_index,
125
+ reshape=True,
126
+ return_class_token=True,
127
+ norm=False,
128
+ )
129
+ model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))
130
+
131
+ if pretrained:
132
+ layers_str = str(layers) if layers == 4 else ""
133
+ weights_str = weights.value.lower()
134
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
135
+ checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
136
+ if "state_dict" in checkpoint:
137
+ state_dict = checkpoint["state_dict"]
138
+ model.load_state_dict(state_dict, strict=False)
139
+
140
+ return model
141
+
142
+
143
+ def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
144
+ return _make_dinov2_linear_depther(
145
+ arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
146
+ )
147
+
148
+
149
+ def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
150
+ return _make_dinov2_linear_depther(
151
+ arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
152
+ )
153
+
154
+
155
+ def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
156
+ return _make_dinov2_linear_depther(
157
+ arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
158
+ )
159
+
160
+
161
+ def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
162
+ return _make_dinov2_linear_depther(
163
+ arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
164
+ )
165
+
166
+
167
+ def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
168
+ return DPTHead(
169
+ in_channels=[embed_dim] * 4,
170
+ channels=256,
171
+ embed_dims=embed_dim,
172
+ post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
173
+ readout_type="project",
174
+ min_depth=min_depth,
175
+ max_depth=max_depth,
176
+ loss_decode=(),
177
+ )
178
+
179
+
180
+ def _make_dinov2_dpt_depther(
181
+ *,
182
+ arch_name: str = "vit_large",
183
+ pretrained: bool = True,
184
+ weights: Union[Weights, str] = Weights.NYU,
185
+ depth_range: Optional[Tuple[float, float]] = None,
186
+ **kwargs,
187
+ ):
188
+ if isinstance(weights, str):
189
+ try:
190
+ weights = Weights[weights]
191
+ except KeyError:
192
+ raise AssertionError(f"Unsupported weights: {weights}")
193
+
194
+ if depth_range is None:
195
+ depth_range = _get_depth_range(pretrained, weights)
196
+ min_depth, max_depth = depth_range
197
+
198
+ backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
199
+
200
+ model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
201
+ dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)
202
+
203
+ out_index = {
204
+ "vit_small": [2, 5, 8, 11],
205
+ "vit_base": [2, 5, 8, 11],
206
+ "vit_large": [4, 11, 17, 23],
207
+ "vit_giant2": [9, 19, 29, 39],
208
+ }[arch_name]
209
+
210
+ model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
211
+ model.backbone.forward = partial(
212
+ backbone.get_intermediate_layers,
213
+ n=out_index,
214
+ reshape=True,
215
+ return_class_token=True,
216
+ norm=False,
217
+ )
218
+ model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))
219
+
220
+ if pretrained:
221
+ weights_str = weights.value.lower()
222
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
223
+ checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
224
+ if "state_dict" in checkpoint:
225
+ state_dict = checkpoint["state_dict"]
226
+ model.load_state_dict(state_dict, strict=False)
227
+
228
+ return model
229
+
230
+
231
+ def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
232
+ return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
233
+
234
+
235
+ def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
236
+ return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
237
+
238
+
239
+ def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
240
+ return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
241
+
242
+
243
+ def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
244
+ return _make_dinov2_dpt_depther(
245
+ arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
246
+ )
core/encoders/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18
+ compact_arch_name = arch_name.replace("_", "")[:4]
19
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21
+
22
+
23
+ class CenterPadding(nn.Module):
24
+ def __init__(self, multiple):
25
+ super().__init__()
26
+ self.multiple = multiple
27
+
28
+ def _get_pad(self, size):
29
+ new_size = math.ceil(size / self.multiple) * self.multiple
30
+ pad_size = new_size - size
31
+ pad_size_left = pad_size // 2
32
+ pad_size_right = pad_size - pad_size_left
33
+ return pad_size_left, pad_size_right
34
+
35
+ @torch.inference_mode()
36
+ def forward(self, x):
37
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
38
+ output = F.pad(x, pads)
39
+ return output
core/encoders/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # ******************************************************************************
7
+ # Code modified by Zexin He in 2023-2024.
8
+ # Modifications are marked with clearly visible comments
9
+ # licensed under the Apache License, Version 2.0.
10
+ # ******************************************************************************
11
+
12
+ from .dino_head import DINOHead
13
+ from .mlp import Mlp
14
+ from .patch_embed import PatchEmbed
15
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
16
+ # ********** Modified by Zexin He in 2023-2024 **********
17
+ # Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
18
+ from .block import Block, BlockWithModulation
19
+ # ********************************************************
20
+ from .attention import MemEffAttention
core/encoders/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ def forward(self, x: Tensor) -> Tensor:
57
+ B, N, C = x.shape
58
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError("xFormers is required for using nested tensors")
77
+ return super().forward(x)
78
+
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81
+
82
+ q, k, v = unbind(qkv, 2)
83
+
84
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85
+ x = x.reshape([B, N, C])
86
+
87
+ x = self.proj(x)
88
+ x = self.proj_drop(x)
89
+ return x
core/encoders/dinov2/layers/block.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ # ******************************************************************************
11
+ # Code modified by Zexin He in 2023-2024.
12
+ # Modifications are marked with clearly visible comments
13
+ # licensed under the Apache License, Version 2.0.
14
+ # ******************************************************************************
15
+
16
+ import logging
17
+ import os
18
+ from typing import Callable, List, Any, Tuple, Dict
19
+ import warnings
20
+
21
+ import torch
22
+ from torch import nn, Tensor
23
+
24
+ from .attention import Attention, MemEffAttention
25
+ from .drop_path import DropPath
26
+ from .layer_scale import LayerScale
27
+ from .mlp import Mlp
28
+
29
+
30
+ logger = logging.getLogger("dinov2")
31
+
32
+
33
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
34
+ try:
35
+ if XFORMERS_ENABLED:
36
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
37
+
38
+ XFORMERS_AVAILABLE = True
39
+ warnings.warn("xFormers is available (Block)")
40
+ else:
41
+ warnings.warn("xFormers is disabled (Block)")
42
+ raise ImportError
43
+ except ImportError:
44
+ XFORMERS_AVAILABLE = False
45
+
46
+ warnings.warn("xFormers is not available (Block)")
47
+
48
+
49
+ class Block(nn.Module):
50
+ def __init__(
51
+ self,
52
+ dim: int,
53
+ num_heads: int,
54
+ mlp_ratio: float = 4.0,
55
+ qkv_bias: bool = False,
56
+ proj_bias: bool = True,
57
+ ffn_bias: bool = True,
58
+ drop: float = 0.0,
59
+ attn_drop: float = 0.0,
60
+ init_values=None,
61
+ drop_path: float = 0.0,
62
+ act_layer: Callable[..., nn.Module] = nn.GELU,
63
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
64
+ attn_class: Callable[..., nn.Module] = Attention,
65
+ ffn_layer: Callable[..., nn.Module] = Mlp,
66
+ ) -> None:
67
+ super().__init__()
68
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
69
+ self.norm1 = norm_layer(dim)
70
+ self.attn = attn_class(
71
+ dim,
72
+ num_heads=num_heads,
73
+ qkv_bias=qkv_bias,
74
+ proj_bias=proj_bias,
75
+ attn_drop=attn_drop,
76
+ proj_drop=drop,
77
+ )
78
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
79
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
80
+
81
+ self.norm2 = norm_layer(dim)
82
+ mlp_hidden_dim = int(dim * mlp_ratio)
83
+ self.mlp = ffn_layer(
84
+ in_features=dim,
85
+ hidden_features=mlp_hidden_dim,
86
+ act_layer=act_layer,
87
+ drop=drop,
88
+ bias=ffn_bias,
89
+ )
90
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
91
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
92
+
93
+ self.sample_drop_ratio = drop_path
94
+
95
+ def forward(self, x: Tensor) -> Tensor:
96
+ def attn_residual_func(x: Tensor) -> Tensor:
97
+ return self.ls1(self.attn(self.norm1(x)))
98
+
99
+ def ffn_residual_func(x: Tensor) -> Tensor:
100
+ return self.ls2(self.mlp(self.norm2(x)))
101
+
102
+ if self.training and self.sample_drop_ratio > 0.1:
103
+ # the overhead is compensated only for a drop path rate larger than 0.1
104
+ x = drop_add_residual_stochastic_depth(
105
+ x,
106
+ residual_func=attn_residual_func,
107
+ sample_drop_ratio=self.sample_drop_ratio,
108
+ )
109
+ x = drop_add_residual_stochastic_depth(
110
+ x,
111
+ residual_func=ffn_residual_func,
112
+ sample_drop_ratio=self.sample_drop_ratio,
113
+ )
114
+ elif self.training and self.sample_drop_ratio > 0.0:
115
+ x = x + self.drop_path1(attn_residual_func(x))
116
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
117
+ else:
118
+ x = x + attn_residual_func(x)
119
+ x = x + ffn_residual_func(x)
120
+ return x
121
+
122
+
123
+ # ********** Modified by Zexin He in 2023-2024 **********
124
+ # Override forward with modulation input
125
+ class BlockWithModulation(Block):
126
+ def __init__(self, *args, **kwargs) -> None:
127
+ super().__init__(*args, **kwargs)
128
+
129
+ def forward(self, x: Tensor, mod: Tensor) -> Tensor:
130
+ def attn_residual_func(x: Tensor, mod: Tensor) -> Tensor:
131
+ return self.ls1(self.attn(self.norm1(x, mod)))
132
+
133
+ def ffn_residual_func(x: Tensor, mod: Tensor) -> Tensor:
134
+ return self.ls2(self.mlp(self.norm2(x, mod)))
135
+
136
+ if self.training and self.sample_drop_ratio > 0.1:
137
+ raise NotImplementedError("Modulation with drop path ratio larger than 0.1 is not supported yet")
138
+ elif self.training and self.sample_drop_ratio > 0.0:
139
+ x = x + self.drop_path1(attn_residual_func(x, mod))
140
+ x = x + self.drop_path1(ffn_residual_func(x, mod)) # FIXME: drop_path2
141
+ else:
142
+ x = x + attn_residual_func(x, mod)
143
+ x = x + ffn_residual_func(x, mod)
144
+ return x
145
+ # ********************************************************
146
+
147
+
148
+ def drop_add_residual_stochastic_depth(
149
+ x: Tensor,
150
+ residual_func: Callable[[Tensor], Tensor],
151
+ sample_drop_ratio: float = 0.0,
152
+ ) -> Tensor:
153
+ # 1) extract subset using permutation
154
+ b, n, d = x.shape
155
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
156
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
157
+ x_subset = x[brange]
158
+
159
+ # 2) apply residual_func to get residual
160
+ residual = residual_func(x_subset)
161
+
162
+ x_flat = x.flatten(1)
163
+ residual = residual.flatten(1)
164
+
165
+ residual_scale_factor = b / sample_subset_size
166
+
167
+ # 3) add the residual
168
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
169
+ return x_plus_residual.view_as(x)
170
+
171
+
172
+ def get_branges_scales(x, sample_drop_ratio=0.0):
173
+ b, n, d = x.shape
174
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
175
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
176
+ residual_scale_factor = b / sample_subset_size
177
+ return brange, residual_scale_factor
178
+
179
+
180
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
181
+ if scaling_vector is None:
182
+ x_flat = x.flatten(1)
183
+ residual = residual.flatten(1)
184
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
185
+ else:
186
+ x_plus_residual = scaled_index_add(
187
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
188
+ )
189
+ return x_plus_residual
190
+
191
+
192
+ attn_bias_cache: Dict[Tuple, Any] = {}
193
+
194
+
195
+ def get_attn_bias_and_cat(x_list, branges=None):
196
+ """
197
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
198
+ """
199
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
200
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
201
+ if all_shapes not in attn_bias_cache.keys():
202
+ seqlens = []
203
+ for b, x in zip(batch_sizes, x_list):
204
+ for _ in range(b):
205
+ seqlens.append(x.shape[1])
206
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
207
+ attn_bias._batch_sizes = batch_sizes
208
+ attn_bias_cache[all_shapes] = attn_bias
209
+
210
+ if branges is not None:
211
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
212
+ else:
213
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
214
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
215
+
216
+ return attn_bias_cache[all_shapes], cat_tensors
217
+
218
+
219
+ def drop_add_residual_stochastic_depth_list(
220
+ x_list: List[Tensor],
221
+ residual_func: Callable[[Tensor, Any], Tensor],
222
+ sample_drop_ratio: float = 0.0,
223
+ scaling_vector=None,
224
+ ) -> Tensor:
225
+ # 1) generate random set of indices for dropping samples in the batch
226
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
227
+ branges = [s[0] for s in branges_scales]
228
+ residual_scale_factors = [s[1] for s in branges_scales]
229
+
230
+ # 2) get attention bias and index+concat the tensors
231
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
232
+
233
+ # 3) apply residual_func to get residual, and split the result
234
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
235
+
236
+ outputs = []
237
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
238
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
239
+ return outputs
240
+
241
+
242
+ class NestedTensorBlock(Block):
243
+
244
+ # ********** Modified by Zexin He in 2023-2024 **********
245
+ warnings.warn("NestedTensorBlock is deprecated for now!", DeprecationWarning)
246
+ # ********************************************************
247
+
248
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
249
+ """
250
+ x_list contains a list of tensors to nest together and run
251
+ """
252
+ assert isinstance(self.attn, MemEffAttention)
253
+
254
+ if self.training and self.sample_drop_ratio > 0.0:
255
+
256
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
257
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
258
+
259
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
260
+ return self.mlp(self.norm2(x))
261
+
262
+ x_list = drop_add_residual_stochastic_depth_list(
263
+ x_list,
264
+ residual_func=attn_residual_func,
265
+ sample_drop_ratio=self.sample_drop_ratio,
266
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
267
+ )
268
+ x_list = drop_add_residual_stochastic_depth_list(
269
+ x_list,
270
+ residual_func=ffn_residual_func,
271
+ sample_drop_ratio=self.sample_drop_ratio,
272
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
273
+ )
274
+ return x_list
275
+ else:
276
+
277
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
278
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
279
+
280
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
281
+ return self.ls2(self.mlp(self.norm2(x)))
282
+
283
+ attn_bias, x = get_attn_bias_and_cat(x_list)
284
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
285
+ x = x + ffn_residual_func(x)
286
+ return attn_bias.split(x)
287
+
288
+ def forward(self, x_or_x_list):
289
+ if isinstance(x_or_x_list, Tensor):
290
+ return super().forward(x_or_x_list)
291
+ elif isinstance(x_or_x_list, list):
292
+ if not XFORMERS_AVAILABLE:
293
+ raise AssertionError("xFormers is required for using nested tensors")
294
+ return self.forward_nested(x_or_x_list)
295
+ else:
296
+ raise AssertionError
core/encoders/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26
+ self.apply(self._init_weights)
27
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28
+ self.last_layer.weight_g.data.fill_(1)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=0.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.mlp(x)
38
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40
+ x = self.last_layer(x)
41
+ return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
core/encoders/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
core/encoders/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
core/encoders/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
core/encoders/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ def flops(self) -> float:
84
+ Ho, Wo = self.patches_resolution
85
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ if self.norm is not None:
87
+ flops += Ho * Wo * self.embed_dim
88
+ return flops
core/encoders/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ try:
39
+ if XFORMERS_ENABLED:
40
+ from xformers.ops import SwiGLU
41
+
42
+ XFORMERS_AVAILABLE = True
43
+ warnings.warn("xFormers is available (SwiGLU)")
44
+ else:
45
+ warnings.warn("xFormers is disabled (SwiGLU)")
46
+ raise ImportError
47
+ except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
core/encoders/dinov2/models/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from . import vision_transformer as vits
9
+
10
+
11
+ logger = logging.getLogger("dinov2")
12
+
13
+
14
+ def build_model(args, only_teacher=False, img_size=224):
15
+ args.arch = args.arch.removesuffix("_memeff")
16
+ if "vit" in args.arch:
17
+ vit_kwargs = dict(
18
+ img_size=img_size,
19
+ patch_size=args.patch_size,
20
+ init_values=args.layerscale,
21
+ ffn_layer=args.ffn_layer,
22
+ block_chunks=args.block_chunks,
23
+ qkv_bias=args.qkv_bias,
24
+ proj_bias=args.proj_bias,
25
+ ffn_bias=args.ffn_bias,
26
+ num_register_tokens=args.num_register_tokens,
27
+ interpolate_offset=args.interpolate_offset,
28
+ interpolate_antialias=args.interpolate_antialias,
29
+ )
30
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
31
+ if only_teacher:
32
+ return teacher, teacher.embed_dim
33
+ student = vits.__dict__[args.arch](
34
+ **vit_kwargs,
35
+ drop_path_rate=args.drop_path_rate,
36
+ drop_path_uniform=args.drop_path_uniform,
37
+ )
38
+ embed_dim = student.embed_dim
39
+ return student, teacher, embed_dim
40
+
41
+
42
+ def build_model_from_cfg(cfg, only_teacher=False):
43
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
core/encoders/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ # ******************************************************************************
11
+ # Code modified by Zexin He in 2023-2024.
12
+ # Modifications are marked with clearly visible comments
13
+ # licensed under the Apache License, Version 2.0.
14
+ # ******************************************************************************
15
+
16
+ from functools import partial
17
+ import math
18
+ import logging
19
+ from typing import Sequence, Tuple, Union, Callable
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.utils.checkpoint
24
+ from torch.nn.init import trunc_normal_
25
+
26
+ # ********** Modified by Zexin He in 2023-2024 **********
27
+ # Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
28
+ from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, Block, BlockWithModulation
29
+ # ********************************************************
30
+
31
+
32
+ logger = logging.getLogger("dinov2")
33
+
34
+
35
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
36
+ if not depth_first and include_root:
37
+ fn(module=module, name=name)
38
+ for child_name, child_module in module.named_children():
39
+ child_name = ".".join((name, child_name)) if name else child_name
40
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
41
+ if depth_first and include_root:
42
+ fn(module=module, name=name)
43
+ return module
44
+
45
+
46
+ class BlockChunk(nn.ModuleList):
47
+ def forward(self, x):
48
+ for b in self:
49
+ x = b(x)
50
+ return x
51
+
52
+
53
+ class DinoVisionTransformer(nn.Module):
54
+ def __init__(
55
+ self,
56
+ img_size=224,
57
+ patch_size=16,
58
+ in_chans=3,
59
+ embed_dim=768,
60
+ depth=12,
61
+ num_heads=12,
62
+ mlp_ratio=4.0,
63
+ qkv_bias=True,
64
+ ffn_bias=True,
65
+ proj_bias=True,
66
+ drop_path_rate=0.0,
67
+ drop_path_uniform=False,
68
+ init_values=None, # for layerscale: None or 0 => no layerscale
69
+ embed_layer=PatchEmbed,
70
+ act_layer=nn.GELU,
71
+ block_fn=Block,
72
+ # ********** Modified by Zexin He in 2023-2024 **********
73
+ modulation_dim: int = None,
74
+ # ********************************************************
75
+ ffn_layer="mlp",
76
+ block_chunks=1,
77
+ num_register_tokens=0,
78
+ interpolate_antialias=False,
79
+ interpolate_offset=0.1,
80
+ ):
81
+ """
82
+ Args:
83
+ img_size (int, tuple): input image size
84
+ patch_size (int, tuple): patch size
85
+ in_chans (int): number of input channels
86
+ embed_dim (int): embedding dimension
87
+ depth (int): depth of transformer
88
+ num_heads (int): number of attention heads
89
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
90
+ qkv_bias (bool): enable bias for qkv if True
91
+ proj_bias (bool): enable bias for proj in attn if True
92
+ ffn_bias (bool): enable bias for ffn if True
93
+ drop_path_rate (float): stochastic depth rate
94
+ drop_path_uniform (bool): apply uniform drop rate across blocks
95
+ weight_init (str): weight init scheme
96
+ init_values (float): layer-scale init values
97
+ embed_layer (nn.Module): patch embedding layer
98
+ act_layer (nn.Module): MLP activation layer
99
+ block_fn (nn.Module): transformer block class
100
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
101
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
102
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
103
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
104
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
105
+ """
106
+ super().__init__()
107
+
108
+ # ********** Modified by Zexin He in 2023-2024 **********
109
+ block_norm_layer = None
110
+ if modulation_dim is not None:
111
+ from ....modulate import ModLN
112
+ block_norm_layer = partial(ModLN, mod_dim=modulation_dim)
113
+ else:
114
+ block_norm_layer = nn.LayerNorm
115
+ block_norm_layer = partial(block_norm_layer, eps=1e-6)
116
+ # ********************************************************
117
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
118
+
119
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
120
+ self.num_tokens = 1
121
+ self.n_blocks = depth
122
+ self.num_heads = num_heads
123
+ self.patch_size = patch_size
124
+ self.num_register_tokens = num_register_tokens
125
+ self.interpolate_antialias = interpolate_antialias
126
+ self.interpolate_offset = interpolate_offset
127
+
128
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
129
+ num_patches = self.patch_embed.num_patches
130
+
131
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
132
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
133
+ assert num_register_tokens >= 0
134
+ self.register_tokens = (
135
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
136
+ )
137
+
138
+ if drop_path_uniform is True:
139
+ dpr = [drop_path_rate] * depth
140
+ else:
141
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
142
+
143
+ if ffn_layer == "mlp":
144
+ logger.info("using MLP layer as FFN")
145
+ ffn_layer = Mlp
146
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
147
+ logger.info("using SwiGLU layer as FFN")
148
+ ffn_layer = SwiGLUFFNFused
149
+ elif ffn_layer == "identity":
150
+ logger.info("using Identity layer as FFN")
151
+
152
+ def f(*args, **kwargs):
153
+ return nn.Identity()
154
+
155
+ ffn_layer = f
156
+ else:
157
+ raise NotImplementedError
158
+
159
+ blocks_list = [
160
+ block_fn(
161
+ dim=embed_dim,
162
+ num_heads=num_heads,
163
+ mlp_ratio=mlp_ratio,
164
+ qkv_bias=qkv_bias,
165
+ proj_bias=proj_bias,
166
+ ffn_bias=ffn_bias,
167
+ drop_path=dpr[i],
168
+ # ********** Modified by Zexin He in 2023-2024 **********
169
+ norm_layer=block_norm_layer,
170
+ # ********************************************************
171
+ act_layer=act_layer,
172
+ ffn_layer=ffn_layer,
173
+ init_values=init_values,
174
+ )
175
+ for i in range(depth)
176
+ ]
177
+ if block_chunks > 0:
178
+ self.chunked_blocks = True
179
+ chunked_blocks = []
180
+ chunksize = depth // block_chunks
181
+ for i in range(0, depth, chunksize):
182
+ # this is to keep the block index consistent if we chunk the block list
183
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
184
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
185
+ else:
186
+ self.chunked_blocks = False
187
+ self.blocks = nn.ModuleList(blocks_list)
188
+
189
+ self.norm = norm_layer(embed_dim)
190
+ self.head = nn.Identity()
191
+
192
+ # ********** Modified by Zexin He in 2023-2024 **********
193
+ # hacking unused mask_token for better DDP
194
+ # self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
195
+ # ********************************************************
196
+
197
+ self.init_weights()
198
+
199
+ def init_weights(self):
200
+ trunc_normal_(self.pos_embed, std=0.02)
201
+ nn.init.normal_(self.cls_token, std=1e-6)
202
+ if self.register_tokens is not None:
203
+ nn.init.normal_(self.register_tokens, std=1e-6)
204
+ named_apply(init_weights_vit_timm, self)
205
+
206
+ def interpolate_pos_encoding(self, x, w, h):
207
+ previous_dtype = x.dtype
208
+ npatch = x.shape[1] - 1
209
+ N = self.pos_embed.shape[1] - 1
210
+ if npatch == N and w == h:
211
+ return self.pos_embed
212
+ pos_embed = self.pos_embed.float()
213
+ class_pos_embed = pos_embed[:, 0]
214
+ patch_pos_embed = pos_embed[:, 1:]
215
+ dim = x.shape[-1]
216
+ w0 = w // self.patch_size
217
+ h0 = h // self.patch_size
218
+ # we add a small number to avoid floating point error in the interpolation
219
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
220
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
221
+
222
+ sqrt_N = math.sqrt(N)
223
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
224
+ patch_pos_embed = nn.functional.interpolate(
225
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
226
+ scale_factor=(sx, sy),
227
+ mode="bicubic",
228
+ antialias=self.interpolate_antialias,
229
+ )
230
+
231
+ assert int(w0) == patch_pos_embed.shape[-2]
232
+ assert int(h0) == patch_pos_embed.shape[-1]
233
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
234
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
235
+
236
+ def prepare_tokens_with_masks(self, x, masks=None):
237
+ B, nc, w, h = x.shape
238
+ x = self.patch_embed(x)
239
+ if masks is not None:
240
+ # ********** Modified by Zexin He in 2023-2024 **********
241
+ raise NotImplementedError("Masking is not supported in hacked DINOv2")
242
+ # x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
243
+ # ********************************************************
244
+
245
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
246
+ x = x + self.interpolate_pos_encoding(x, w, h)
247
+
248
+ if self.register_tokens is not None:
249
+ x = torch.cat(
250
+ (
251
+ x[:, :1],
252
+ self.register_tokens.expand(x.shape[0], -1, -1),
253
+ x[:, 1:],
254
+ ),
255
+ dim=1,
256
+ )
257
+
258
+ return x
259
+
260
+ def forward_features_list(self, x_list, masks_list):
261
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
262
+ for blk in self.blocks:
263
+ x = blk(x)
264
+
265
+ all_x = x
266
+ output = []
267
+ for x, masks in zip(all_x, masks_list):
268
+ x_norm = self.norm(x)
269
+ output.append(
270
+ {
271
+ "x_norm_clstoken": x_norm[:, 0],
272
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
273
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
274
+ "x_prenorm": x,
275
+ "masks": masks,
276
+ }
277
+ )
278
+ return output
279
+
280
+ # ********** Modified by Zexin He in 2023-2024 **********
281
+ def forward_features(self, x, masks=None, mod=None):
282
+ if isinstance(x, list):
283
+ raise DeprecationWarning("forward_features_list is deprecated, use forward_features")
284
+ return self.forward_features_list(x, masks)
285
+
286
+ x = self.prepare_tokens_with_masks(x, masks)
287
+
288
+ if mod is None:
289
+ for blk in self.blocks:
290
+ x = blk(x)
291
+ else:
292
+ for blk in self.blocks:
293
+ x = blk(x, mod)
294
+
295
+ x_norm = self.norm(x)
296
+ return {
297
+ "x_norm_clstoken": x_norm[:, 0],
298
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
299
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
300
+ "x_prenorm": x,
301
+ "masks": masks,
302
+ }
303
+ # ********************************************************
304
+
305
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
306
+ x = self.prepare_tokens_with_masks(x)
307
+ # If n is an int, take the n last blocks. If it's a list, take them
308
+ output, total_block_len = [], len(self.blocks)
309
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
310
+ for i, blk in enumerate(self.blocks):
311
+ x = blk(x)
312
+ if i in blocks_to_take:
313
+ output.append(x)
314
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
315
+ return output
316
+
317
+ def _get_intermediate_layers_chunked(self, x, n=1):
318
+ x = self.prepare_tokens_with_masks(x)
319
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
320
+ # If n is an int, take the n last blocks. If it's a list, take them
321
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
322
+ for block_chunk in self.blocks:
323
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
324
+ x = blk(x)
325
+ if i in blocks_to_take:
326
+ output.append(x)
327
+ i += 1
328
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
329
+ return output
330
+
331
+ def get_intermediate_layers(
332
+ self,
333
+ x: torch.Tensor,
334
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
335
+ reshape: bool = False,
336
+ return_class_token: bool = False,
337
+ norm=True,
338
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
339
+ if self.chunked_blocks:
340
+ outputs = self._get_intermediate_layers_chunked(x, n)
341
+ else:
342
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
343
+ if norm:
344
+ outputs = [self.norm(out) for out in outputs]
345
+ class_tokens = [out[:, 0] for out in outputs]
346
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
347
+ if reshape:
348
+ B, _, w, h = x.shape
349
+ outputs = [
350
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
351
+ for out in outputs
352
+ ]
353
+ if return_class_token:
354
+ return tuple(zip(outputs, class_tokens))
355
+ return tuple(outputs)
356
+
357
+ def forward(self, *args, is_training=False, **kwargs):
358
+ ret = self.forward_features(*args, **kwargs)
359
+ if is_training:
360
+ return ret
361
+ else:
362
+ return self.head(ret["x_norm_clstoken"])
363
+
364
+
365
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
366
+ """ViT weight initialization, original timm impl (for reproducibility)"""
367
+ if isinstance(module, nn.Linear):
368
+ trunc_normal_(module.weight, std=0.02)
369
+ if module.bias is not None:
370
+ nn.init.zeros_(module.bias)
371
+
372
+
373
+ # ********** Modified by Zexin He in 2023-2024 **********
374
+ # block class selected from Block and BlockWithModulation
375
+
376
+ def _block_cls(**kwargs):
377
+ modulation_dim = kwargs.get("modulation_dim", None)
378
+ if modulation_dim is None:
379
+ block_cls = Block
380
+ else:
381
+ block_cls = BlockWithModulation
382
+ return block_cls
383
+
384
+
385
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
386
+ model = DinoVisionTransformer(
387
+ patch_size=patch_size,
388
+ embed_dim=384,
389
+ depth=12,
390
+ num_heads=6,
391
+ mlp_ratio=4,
392
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
393
+ num_register_tokens=num_register_tokens,
394
+ **kwargs,
395
+ )
396
+ return model
397
+
398
+
399
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
400
+ model = DinoVisionTransformer(
401
+ patch_size=patch_size,
402
+ embed_dim=768,
403
+ depth=12,
404
+ num_heads=12,
405
+ mlp_ratio=4,
406
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
407
+ num_register_tokens=num_register_tokens,
408
+ **kwargs,
409
+ )
410
+ return model
411
+
412
+
413
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
414
+ model = DinoVisionTransformer(
415
+ patch_size=patch_size,
416
+ embed_dim=1024,
417
+ depth=24,
418
+ num_heads=16,
419
+ mlp_ratio=4,
420
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
421
+ num_register_tokens=num_register_tokens,
422
+ **kwargs,
423
+ )
424
+ return model
425
+
426
+
427
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
428
+ """
429
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
430
+ """
431
+ model = DinoVisionTransformer(
432
+ patch_size=patch_size,
433
+ embed_dim=1536,
434
+ depth=40,
435
+ num_heads=24,
436
+ mlp_ratio=4,
437
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
438
+ num_register_tokens=num_register_tokens,
439
+ **kwargs,
440
+ )
441
+ return model
442
+
443
+ # ********************************************************
core/encoders/dinov2_wrapper.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ # from accelerate.logging import get_logger
19
+
20
+
21
+ # logger = get_logger(__name__)
22
+
23
+
24
+ class Dinov2Wrapper(nn.Module):
25
+ """
26
+ Dino v2 wrapper using original implementation, hacked with modulation.
27
+ """
28
+ def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True):
29
+ super().__init__()
30
+ self.modulation_dim = modulation_dim
31
+ self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
32
+ if freeze:
33
+ if modulation_dim is not None:
34
+ raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.")
35
+ self._freeze()
36
+
37
+ def _freeze(self):
38
+ #logger.warning(f"======== Freezing Dinov2Wrapper ========")
39
+ self.model.eval()
40
+ for name, param in self.model.named_parameters():
41
+ param.requires_grad = False
42
+
43
+ @staticmethod
44
+ def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
45
+ from importlib import import_module
46
+ dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
47
+ model_fn = getattr(dinov2_hub, model_name)
48
+ #logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
49
+ model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
50
+ return model
51
+
52
+ #@torch.compile
53
+ def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
54
+ # image: [N, C, H, W]
55
+ # mod: [N, D] or None
56
+ # RGB image with [0,1] scale and properly sized
57
+ if self.modulation_dim is None:
58
+ assert mod is None, "Unexpected modulation input in dinov2 forward."
59
+ outs = self.model(image, is_training=True)
60
+ else:
61
+ assert mod is not None, "Modulation input is required in modulated dinov2 forward."
62
+ outs = self.model(image, mod=mod, is_training=True)
63
+ ret = torch.cat([
64
+ outs["x_norm_clstoken"].unsqueeze(dim=1),
65
+ outs["x_norm_patchtokens"],
66
+ ], dim=1)
67
+ return ret
core/geometry/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
core/geometry/camera/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ class Camera(nn.Module):
14
+ def __init__(self):
15
+ super(Camera, self).__init__()
16
+ pass
core/geometry/camera/perspective_camera.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from . import Camera
11
+ import numpy as np
12
+
13
+
14
+
15
+ def projection(fovy, n=1.0, f=50.0, near_plane=None):
16
+ focal = np.tan(fovy / 180.0 * np.pi * 0.5)
17
+ if near_plane is None:
18
+ near_plane = n
19
+ return np.array(
20
+ [[n / focal, 0, 0, 0],
21
+ [0, n / -focal, 0, 0],
22
+ [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
23
+ [0, 0, -1, 0]]).astype(np.float32)
24
+
25
+ def projection_2(opt):
26
+ zfar= opt.zfar
27
+ znear= opt.znear
28
+ tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
29
+ proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
30
+ proj_matrix[0, 0] = 1 / tan_half_fov
31
+ proj_matrix[1, 1] = 1 / tan_half_fov
32
+ proj_matrix[2, 2] = (zfar + znear) / (zfar - znear)
33
+ proj_matrix[3, 2] = - (zfar * znear) / (zfar - znear)
34
+ proj_matrix[2, 3] = 1
35
+
36
+ return proj_matrix
37
+
38
+
39
+ class PerspectiveCamera(Camera):
40
+ def __init__(self, opt, device='cuda'):
41
+ super(PerspectiveCamera, self).__init__()
42
+ self.device = device
43
+ self.proj_mtx = torch.from_numpy(projection(opt.fovy, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)
44
+ #self.proj_mtx= projection_2(opt).to(self.device).unsqueeze(dim=0)
45
+
46
+
47
+ def project(self, points_bxnx4):
48
+ out = torch.matmul(
49
+ points_bxnx4,
50
+ torch.transpose(self.proj_mtx, 1, 2))
51
+ return out
core/geometry/render/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class Renderer():
4
+ def __init__(self):
5
+ pass
6
+
7
+ def forward(self):
8
+ pass
core/geometry/render/neural_render.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import nvdiffrast.torch as dr
12
+ from . import Renderer
13
+
14
+ _FG_LUT = None
15
+
16
+
17
+ def interpolate(attr, rast, attr_idx, rast_db=None):
18
+ return dr.interpolate(
19
+ attr.contiguous(), rast, attr_idx, rast_db=rast_db,
20
+ diff_attrs=None if rast_db is None else 'all')
21
+
22
+
23
+ def xfm_points(points, matrix, use_python=True):
24
+ '''Transform points.
25
+ Args:
26
+ points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
27
+ matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
28
+ use_python: Use PyTorch's torch.matmul (for validation)
29
+ Returns:
30
+ Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
31
+ '''
32
+ out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
33
+ if torch.is_anomaly_enabled():
34
+ assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
35
+ return out
36
+
37
+
38
+ def dot(x, y):
39
+ return torch.sum(x * y, -1, keepdim=True)
40
+
41
+
42
+ def compute_vertex_normal(v_pos, t_pos_idx):
43
+ i0 = t_pos_idx[:, 0]
44
+ i1 = t_pos_idx[:, 1]
45
+ i2 = t_pos_idx[:, 2]
46
+
47
+ v0 = v_pos[i0, :]
48
+ v1 = v_pos[i1, :]
49
+ v2 = v_pos[i2, :]
50
+
51
+ face_normals = torch.cross(v1 - v0, v2 - v0)
52
+
53
+ # Splat face normals to vertices
54
+ v_nrm = torch.zeros_like(v_pos)
55
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
56
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
57
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
58
+
59
+ # Normalize, replace zero (degenerated) normals with some default value
60
+ v_nrm = torch.where(
61
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
62
+ )
63
+ v_nrm = F.normalize(v_nrm, dim=1)
64
+ assert torch.all(torch.isfinite(v_nrm))
65
+
66
+ return v_nrm
67
+
68
+
69
+ class NeuralRender(Renderer):
70
+ def __init__(self, device='cuda', camera_model=None):
71
+ super(NeuralRender, self).__init__()
72
+ self.device = device
73
+ self.ctx = dr.RasterizeCudaContext(device=device)
74
+ self.projection_mtx = None
75
+ self.camera = camera_model
76
+
77
+ def render_mesh(
78
+ self,
79
+ mesh_v_pos_bxnx3,
80
+ mesh_t_pos_idx_fx3,
81
+ camera_mv_bx4x4,
82
+ mesh_v_feat_bxnxd,
83
+ resolution=256,
84
+ spp=1,
85
+ device='cuda',
86
+ hierarchical_mask=False
87
+ ):
88
+ assert not hierarchical_mask
89
+
90
+ mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
91
+ v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
92
+ v_pos_clip = self.camera.project(v_pos) # Projection in the camera
93
+
94
+ v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates
95
+
96
+ # Render the image,
97
+ # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
98
+ num_layers = 1
99
+ mask_pyramid = None
100
+ assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
101
+ mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos
102
+
103
+ with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
104
+ for _ in range(num_layers):
105
+ rast, db = peeler.rasterize_next_layer()
106
+ gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
107
+
108
+ hard_mask = torch.clamp(rast[..., -1:], 0, 1)
109
+ antialias_mask = dr.antialias(
110
+ hard_mask.clone().contiguous(), rast, v_pos_clip,
111
+ mesh_t_pos_idx_fx3)
112
+
113
+ depth = gb_feat[..., -2:-1]
114
+ ori_mesh_feature = gb_feat[..., :-4]
115
+
116
+ normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
117
+ normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
118
+ normal = F.normalize(normal, dim=-1)
119
+ normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
120
+
121
+ return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
core/geometry/rep_3d/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import numpy as np
11
+
12
+
13
+ class Geometry():
14
+ def __init__(self):
15
+ pass
16
+
17
+ def forward(self):
18
+ pass
core/geometry/rep_3d/dmtet.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import numpy as np
11
+ import os
12
+ from . import Geometry
13
+ from .dmtet_utils import get_center_boundary_index
14
+ import torch.nn.functional as F
15
+
16
+
17
+ ###############################################################################
18
+ # DMTet utility functions
19
+ ###############################################################################
20
+ def create_mt_variable(device):
21
+ triangle_table = torch.tensor(
22
+ [
23
+ [-1, -1, -1, -1, -1, -1],
24
+ [1, 0, 2, -1, -1, -1],
25
+ [4, 0, 3, -1, -1, -1],
26
+ [1, 4, 2, 1, 3, 4],
27
+ [3, 1, 5, -1, -1, -1],
28
+ [2, 3, 0, 2, 5, 3],
29
+ [1, 4, 0, 1, 5, 4],
30
+ [4, 2, 5, -1, -1, -1],
31
+ [4, 5, 2, -1, -1, -1],
32
+ [4, 1, 0, 4, 5, 1],
33
+ [3, 2, 0, 3, 5, 2],
34
+ [1, 3, 5, -1, -1, -1],
35
+ [4, 1, 2, 4, 3, 1],
36
+ [3, 0, 4, -1, -1, -1],
37
+ [2, 0, 1, -1, -1, -1],
38
+ [-1, -1, -1, -1, -1, -1]
39
+ ], dtype=torch.long, device=device)
40
+
41
+ num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
42
+ base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
43
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
44
+ return triangle_table, num_triangles_table, base_tet_edges, v_id
45
+
46
+
47
+ def sort_edges(edges_ex2):
48
+ with torch.no_grad():
49
+ order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
50
+ order = order.unsqueeze(dim=1)
51
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
52
+ b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
53
+ return torch.stack([a, b], -1)
54
+
55
+
56
+ ###############################################################################
57
+ # marching tetrahedrons (differentiable)
58
+ ###############################################################################
59
+
60
+ def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id):
61
+ with torch.no_grad():
62
+ occ_n = sdf_n > 0
63
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
64
+ occ_sum = torch.sum(occ_fx4, -1)
65
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
66
+ occ_sum = occ_sum[valid_tets]
67
+
68
+ # find all vertices
69
+ all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
70
+ all_edges = sort_edges(all_edges)
71
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
72
+
73
+ unique_edges = unique_edges.long()
74
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
75
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
76
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device)
77
+ idx_map = mapping[idx_map] # map edges to verts
78
+
79
+ interp_v = unique_edges[mask_edges] # .long()
80
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
81
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
82
+ edges_to_interp_sdf[:, -1] *= -1
83
+
84
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
85
+
86
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
87
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
88
+
89
+ idx_map = idx_map.reshape(-1, 6)
90
+
91
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
92
+ num_triangles = num_triangles_table[tetindex]
93
+
94
+ # Generate triangle indices
95
+ faces = torch.cat(
96
+ (
97
+ torch.gather(
98
+ input=idx_map[num_triangles == 1], dim=1,
99
+ index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
100
+ torch.gather(
101
+ input=idx_map[num_triangles == 2], dim=1,
102
+ index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
103
+ ), dim=0)
104
+ return verts, faces
105
+
106
+
107
+ def create_tetmesh_variables(device='cuda'):
108
+ tet_table = torch.tensor(
109
+ [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
110
+ [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1],
111
+ [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1],
112
+ [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8],
113
+ [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1],
114
+ [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9],
115
+ [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9],
116
+ [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9],
117
+ [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1],
118
+ [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9],
119
+ [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9],
120
+ [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9],
121
+ [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8],
122
+ [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8],
123
+ [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6],
124
+ [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device)
125
+ num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device)
126
+ return tet_table, num_tets_table
127
+
128
+
129
+ def marching_tets_tetmesh(
130
+ pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
131
+ return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
132
+ with torch.no_grad():
133
+ occ_n = sdf_n > 0
134
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
135
+ occ_sum = torch.sum(occ_fx4, -1)
136
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
137
+ occ_sum = occ_sum[valid_tets]
138
+
139
+ # find all vertices
140
+ all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
141
+ all_edges = sort_edges(all_edges)
142
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
143
+
144
+ unique_edges = unique_edges.long()
145
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
146
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
147
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device)
148
+ idx_map = mapping[idx_map] # map edges to verts
149
+
150
+ interp_v = unique_edges[mask_edges] # .long()
151
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
152
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
153
+ edges_to_interp_sdf[:, -1] *= -1
154
+
155
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
156
+
157
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
158
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
159
+
160
+ idx_map = idx_map.reshape(-1, 6)
161
+
162
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
163
+ num_triangles = num_triangles_table[tetindex]
164
+
165
+ # Generate triangle indices
166
+ faces = torch.cat(
167
+ (
168
+ torch.gather(
169
+ input=idx_map[num_triangles == 1], dim=1,
170
+ index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
171
+ torch.gather(
172
+ input=idx_map[num_triangles == 2], dim=1,
173
+ index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
174
+ ), dim=0)
175
+ if not return_tet_mesh:
176
+ return verts, faces
177
+ occupied_verts = ori_v[occ_n]
178
+ mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1
179
+ mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda")
180
+ tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4))
181
+
182
+ idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10
183
+ tet_verts = torch.cat([verts, occupied_verts], 0)
184
+ num_tets = num_tets_table[tetindex]
185
+
186
+ tets = torch.cat(
187
+ (
188
+ torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape(
189
+ -1,
190
+ 4),
191
+ torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape(
192
+ -1,
193
+ 4),
194
+ ), dim=0)
195
+ # add fully occupied tets
196
+ fully_occupied = occ_fx4.sum(-1) == 4
197
+ tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0]
198
+ tets = torch.cat([tets, tet_fully_occupied])
199
+
200
+ return verts, faces, tet_verts, tets
201
+
202
+
203
+ ###############################################################################
204
+ # Compact tet grid
205
+ ###############################################################################
206
+
207
+ def compact_tets(pos_nx3, sdf_n, tet_fx4):
208
+ with torch.no_grad():
209
+ # Find surface tets
210
+ occ_n = sdf_n > 0
211
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
212
+ occ_sum = torch.sum(occ_fx4, -1)
213
+ valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets
214
+
215
+ valid_vtx = tet_fx4[valid_tets].reshape(-1)
216
+ unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True)
217
+ new_pos = pos_nx3[unique_vtx]
218
+ new_sdf = sdf_n[unique_vtx]
219
+ new_tets = idx_map.reshape(-1, 4)
220
+ return new_pos, new_sdf, new_tets
221
+
222
+
223
+ ###############################################################################
224
+ # Subdivide volume
225
+ ###############################################################################
226
+
227
+ def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
228
+ device = tet_pos_bxnx3.device
229
+ # get new verts
230
+ tet_fx4 = tet_bxfx4[0]
231
+ edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]
232
+ all_edges = tet_fx4[:, edges].reshape(-1, 2)
233
+ all_edges = sort_edges(all_edges)
234
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
235
+ idx_map = idx_map + tet_pos_bxnx3.shape[1]
236
+ all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1)
237
+ mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape(
238
+ all_values.shape[0], -1, 2,
239
+ all_values.shape[-1]).mean(2)
240
+ new_v = torch.cat([all_values, mid_points_pos], 1)
241
+ new_v, new_sdf = new_v[..., :3], new_v[..., 3]
242
+
243
+ # get new tets
244
+
245
+ idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3]
246
+ idx_ab = idx_map[0::6]
247
+ idx_ac = idx_map[1::6]
248
+ idx_ad = idx_map[2::6]
249
+ idx_bc = idx_map[3::6]
250
+ idx_bd = idx_map[4::6]
251
+ idx_cd = idx_map[5::6]
252
+
253
+ tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1)
254
+ tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1)
255
+ tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1)
256
+ tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1)
257
+ tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1)
258
+ tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1)
259
+ tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1)
260
+ tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1)
261
+
262
+ tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0)
263
+ tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1)
264
+ tet = tet_np.long().to(device)
265
+
266
+ return new_v, tet, new_sdf
267
+
268
+
269
+ ###############################################################################
270
+ # Adjacency
271
+ ###############################################################################
272
+ def tet_to_tet_adj_sparse(tet_tx4):
273
+ # include self connection!!!!!!!!!!!!!!!!!!!
274
+ with torch.no_grad():
275
+ t = tet_tx4.shape[0]
276
+ device = tet_tx4.device
277
+ idx_array = torch.LongTensor(
278
+ [0, 1, 2,
279
+ 1, 0, 3,
280
+ 2, 3, 0,
281
+ 3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3)
282
+
283
+ # get all faces
284
+ all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape(
285
+ -1,
286
+ 3) # (tx4, 3)
287
+ all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1)
288
+ # sort and group
289
+ all_faces_sorted, _ = torch.sort(all_faces, dim=1)
290
+
291
+ all_faces_unique, inverse_indices, counts = torch.unique(
292
+ all_faces_sorted, dim=0, return_counts=True,
293
+ return_inverse=True)
294
+ tet_face_fx3 = all_faces_unique[counts == 2]
295
+ counts = counts[inverse_indices] # tx4
296
+ valid = (counts == 2)
297
+
298
+ group = inverse_indices[valid]
299
+ # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape)
300
+ _, indices = torch.sort(group)
301
+ all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices]
302
+ tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1)
303
+
304
+ tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])])
305
+ adj_self = torch.arange(t, device=tet_tx4.device)
306
+ adj_self = torch.stack([adj_self, adj_self], -1)
307
+ tet_adj_idx = torch.cat([tet_adj_idx, adj_self])
308
+
309
+ tet_adj_idx = torch.unique(tet_adj_idx, dim=0)
310
+ values = torch.ones(
311
+ tet_adj_idx.shape[0], device=tet_tx4.device).float()
312
+ adj_sparse = torch.sparse.FloatTensor(
313
+ tet_adj_idx.t(), values, torch.Size([t, t]))
314
+
315
+ # normalization
316
+ neighbor_num = 1.0 / torch.sparse.sum(
317
+ adj_sparse, dim=1).to_dense()
318
+ values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0])
319
+ adj_sparse = torch.sparse.FloatTensor(
320
+ tet_adj_idx.t(), values, torch.Size([t, t]))
321
+ return adj_sparse
322
+
323
+
324
+ ###############################################################################
325
+ # Compact grid
326
+ ###############################################################################
327
+
328
+ def get_tet_bxfx4x3(bxnxz, bxfx4):
329
+ n_batch, z = bxnxz.shape[0], bxnxz.shape[2]
330
+ gather_input = bxnxz.unsqueeze(2).expand(
331
+ n_batch, bxnxz.shape[1], 4, z)
332
+ gather_index = bxfx4.unsqueeze(-1).expand(
333
+ n_batch, bxfx4.shape[1], 4, z).long()
334
+ tet_bxfx4xz = torch.gather(
335
+ input=gather_input, dim=1, index=gather_index)
336
+
337
+ return tet_bxfx4xz
338
+
339
+
340
+ def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
341
+ with torch.no_grad():
342
+ assert tet_pos_bxnx3.shape[0] == 1
343
+
344
+ occ = grid_sdf[0] > 0
345
+ occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1)
346
+ mask = (occ_sum > 0) & (occ_sum < 4)
347
+
348
+ # build connectivity graph
349
+ adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0])
350
+ mask = mask.float().unsqueeze(-1)
351
+
352
+ # Include a one ring of neighbors
353
+ for i in range(1):
354
+ mask = torch.sparse.mm(adj_matrix, mask)
355
+ mask = mask.squeeze(-1) > 0
356
+
357
+ mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long)
358
+ new_tet_bxfx4 = tet_bxfx4[:, mask].long()
359
+ selected_verts_idx = torch.unique(new_tet_bxfx4)
360
+ new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx]
361
+ mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device)
362
+ new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape)
363
+ new_grid_sdf = grid_sdf[:, selected_verts_idx]
364
+ return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf
365
+
366
+
367
+ ###############################################################################
368
+ # Regularizer
369
+ ###############################################################################
370
+
371
+ def sdf_reg_loss(sdf, all_edges):
372
+ sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2)
373
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
374
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
375
+ sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(
376
+ sdf_f1x6x2[..., 0],
377
+ (sdf_f1x6x2[..., 1] > 0).float()) + \
378
+ torch.nn.functional.binary_cross_entropy_with_logits(
379
+ sdf_f1x6x2[..., 1],
380
+ (sdf_f1x6x2[..., 0] > 0).float())
381
+ return sdf_diff
382
+
383
+
384
+ def sdf_reg_loss_batch(sdf, all_edges):
385
+ sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
386
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
387
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
388
+ sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
389
+ torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
390
+ return sdf_diff
391
+
392
+
393
+ ###############################################################################
394
+ # Geometry interface
395
+ ###############################################################################
396
+ class DMTetGeometry(Geometry):
397
+ def __init__(
398
+ self, grid_res=64, scale=2.0, device='cuda', renderer=None,
399
+ render_type='neural_render', args=None):
400
+ super(DMTetGeometry, self).__init__()
401
+ self.grid_res = grid_res
402
+ self.device = device
403
+ self.args = args
404
+ tets = np.load('data/tets/%d_compress.npz' % (grid_res))
405
+ self.verts = torch.from_numpy(tets['vertices']).float().to(self.device)
406
+ # Make sure the tet is zero-centered and length is equal to 1
407
+ length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0]
408
+ length = length.max()
409
+ mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0
410
+ self.verts = (self.verts - mid.unsqueeze(dim=0)) / length
411
+ if isinstance(scale, list):
412
+ self.verts[:, 0] = self.verts[:, 0] * scale[0]
413
+ self.verts[:, 1] = self.verts[:, 1] * scale[1]
414
+ self.verts[:, 2] = self.verts[:, 2] * scale[1]
415
+ else:
416
+ self.verts = self.verts * scale
417
+ self.indices = torch.from_numpy(tets['tets']).long().to(self.device)
418
+ self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device)
419
+ self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device)
420
+ # Parameters for regularization computation
421
+ edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device)
422
+ all_edges = self.indices[:, edges].reshape(-1, 2)
423
+ all_edges_sorted = torch.sort(all_edges, dim=1)[0]
424
+ self.all_edges = torch.unique(all_edges_sorted, dim=0)
425
+
426
+ # Parameters used for fix boundary sdf
427
+ self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts)
428
+ self.renderer = renderer
429
+ self.render_type = render_type
430
+
431
+ def getAABB(self):
432
+ return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
433
+
434
+ def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
435
+ if indices is None:
436
+ indices = self.indices
437
+ verts, faces = marching_tets(
438
+ v_deformed_nx3, sdf_n, indices, self.triangle_table,
439
+ self.num_triangles_table, self.base_tet_edges, self.v_id)
440
+ faces = torch.cat(
441
+ [faces[:, 0:1],
442
+ faces[:, 2:3],
443
+ faces[:, 1:2], ], dim=-1)
444
+ return verts, faces
445
+
446
+ def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
447
+ if indices is None:
448
+ indices = self.indices
449
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(
450
+ v_deformed_nx3, sdf_n, indices, self.triangle_table,
451
+ self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True,
452
+ num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3)
453
+ faces = torch.cat(
454
+ [faces[:, 0:1],
455
+ faces[:, 2:3],
456
+ faces[:, 1:2], ], dim=-1)
457
+ return verts, faces, tet_verts, tets
458
+
459
+ def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
460
+ return_value = dict()
461
+ if self.render_type == 'neural_render':
462
+ tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
463
+ mesh_v_nx3.unsqueeze(dim=0),
464
+ mesh_f_fx3.int(),
465
+ camera_mv_bx4x4,
466
+ mesh_v_nx3.unsqueeze(dim=0),
467
+ resolution=resolution,
468
+ device=self.device,
469
+ hierarchical_mask=hierarchical_mask
470
+ )
471
+
472
+ return_value['tex_pos'] = tex_pos
473
+ return_value['mask'] = mask
474
+ return_value['hard_mask'] = hard_mask
475
+ return_value['rast'] = rast
476
+ return_value['v_pos_clip'] = v_pos_clip
477
+ return_value['mask_pyramid'] = mask_pyramid
478
+ return_value['depth'] = depth
479
+ else:
480
+ raise NotImplementedError
481
+
482
+ return return_value
483
+
484
+ def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
485
+ # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
486
+ v_list = []
487
+ f_list = []
488
+ n_batch = v_deformed_bxnx3.shape[0]
489
+ all_render_output = []
490
+ for i_batch in range(n_batch):
491
+ verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
492
+ v_list.append(verts_nx3)
493
+ f_list.append(faces_fx3)
494
+ render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
495
+ all_render_output.append(render_output)
496
+
497
+ # Concatenate all render output
498
+ return_keys = all_render_output[0].keys()
499
+ return_value = dict()
500
+ for k in return_keys:
501
+ value = [v[k] for v in all_render_output]
502
+ return_value[k] = value
503
+ # We can do concatenation outside of the render
504
+ return return_value
core/geometry/rep_3d/dmtet_utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+
11
+
12
+ def get_center_boundary_index(verts):
13
+ length_ = torch.sum(verts ** 2, dim=-1)
14
+ center_idx = torch.argmin(length_)
15
+ boundary_neg = verts == verts.max()
16
+ boundary_pos = verts == verts.min()
17
+ boundary = torch.bitwise_or(boundary_pos, boundary_neg)
18
+ boundary = torch.sum(boundary.float(), dim=-1)
19
+ boundary_idx = torch.nonzero(boundary)
20
+ return center_idx, boundary_idx.squeeze(dim=-1)
core/geometry/rep_3d/extract_texture_map.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import xatlas
11
+ import numpy as np
12
+ import nvdiffrast.torch as dr
13
+
14
+
15
+ # ==============================================================================================
16
+ def interpolate(attr, rast, attr_idx, rast_db=None):
17
+ return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
18
+
19
+
20
+ def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
21
+ vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy())
22
+
23
+ # Convert to tensors
24
+ indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
25
+
26
+ uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
27
+ mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
28
+ # mesh_v_tex. ture
29
+ uv_clip = uvs[None, ...] * 2.0 - 1.0
30
+
31
+ # pad to four component coordinate
32
+ uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
33
+
34
+ # rasterize
35
+ rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
36
+
37
+ # Interpolate world space position
38
+ gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
39
+ mask = rast[..., 3:4] > 0
40
+ return uvs, mesh_tex_idx, gb_pos, mask
core/geometry/rep_3d/flexicubes.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ import torch
9
+ from .tables import *
10
+
11
+ __all__ = [
12
+ 'FlexiCubes'
13
+ ]
14
+
15
+
16
+ class FlexiCubes:
17
+ """
18
+ This class implements the FlexiCubes method for extracting meshes from scalar fields.
19
+ It maintains a series of lookup tables and indices to support the mesh extraction process.
20
+ FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
21
+ the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
22
+ the surface representation through gradient-based optimization.
23
+
24
+ During instantiation, the class loads DMC tables from a file and transforms them into
25
+ PyTorch tensors on the specified device.
26
+
27
+ Attributes:
28
+ device (str): Specifies the computational device (default is "cuda").
29
+ dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
30
+ associated with each dual vertex in 256 Marching Cubes (MC) configurations.
31
+ num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
32
+ the 256 MC configurations.
33
+ check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
34
+ of the DMC configurations.
35
+ tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
36
+ quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
37
+ along one diagonal.
38
+ quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
39
+ two triangles along the other diagonal.
40
+ quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
41
+ during training by connecting all edges to their midpoints.
42
+ cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
43
+ eight corners in 3D space, ordered starting from the origin (0,0,0),
44
+ moving along the x-axis, then y-axis, and finally z-axis.
45
+ Used as a blueprint for generating a voxel grid.
46
+ cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
47
+ to retrieve the case id.
48
+ cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
49
+ Used to retrieve edge vertices in DMC.
50
+ edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
51
+ their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
52
+ first edge is oriented along the x-axis.
53
+ dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
54
+ across four adjacent cubes to the shared faces of these cubes. For instance,
55
+ dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
56
+ the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
57
+ This tensor is only utilized during isosurface tetrahedralization.
58
+ adj_pairs (torch.Tensor):
59
+ A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
60
+ qef_reg_scale (float):
61
+ The scaling factor applied to the regularization loss to prevent issues with singularity
62
+ when solving the QEF. This parameter is only used when a 'grad_func' is specified.
63
+ weight_scale (float):
64
+ The scale of weights in FlexiCubes. Should be between 0 and 1.
65
+ """
66
+
67
+ def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
68
+
69
+ self.device = device
70
+ self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
71
+ self.num_vd_table = torch.tensor(num_vd_table,
72
+ dtype=torch.long, device=device, requires_grad=False)
73
+ self.check_table = torch.tensor(
74
+ check_table,
75
+ dtype=torch.long, device=device, requires_grad=False)
76
+
77
+ self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
78
+ self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
79
+ self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
80
+ self.quad_split_train = torch.tensor(
81
+ [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
82
+
83
+ self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
84
+ 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
85
+ self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
86
+ self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
87
+ 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
88
+
89
+ self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
90
+ dtype=torch.long, device=device)
91
+ self.dir_faces_table = torch.tensor([
92
+ [[5, 4], [3, 2], [4, 5], [2, 3]],
93
+ [[5, 4], [1, 0], [4, 5], [0, 1]],
94
+ [[3, 2], [1, 0], [2, 3], [0, 1]]
95
+ ], dtype=torch.long, device=device)
96
+ self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
97
+ self.qef_reg_scale = qef_reg_scale
98
+ self.weight_scale = weight_scale
99
+
100
+ def construct_voxel_grid(self, res):
101
+ """
102
+ Generates a voxel grid based on the specified resolution.
103
+
104
+ Args:
105
+ res (int or list[int]): The resolution of the voxel grid. If an integer
106
+ is provided, it is used for all three dimensions. If a list or tuple
107
+ of 3 integers is provided, they define the resolution for the x,
108
+ y, and z dimensions respectively.
109
+
110
+ Returns:
111
+ (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
112
+ cube corners (index into vertices) of the constructed voxel grid.
113
+ The vertices are centered at the origin, with the length of each
114
+ dimension in the grid being one.
115
+ """
116
+ base_cube_f = torch.arange(8).to(self.device)
117
+ if isinstance(res, int):
118
+ res = (res, res, res)
119
+ voxel_grid_template = torch.ones(res, device=self.device)
120
+
121
+ res = torch.tensor([res], dtype=torch.float, device=self.device)
122
+ coords = torch.nonzero(voxel_grid_template).float() / res # N, 3
123
+ verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
124
+ cubes = (base_cube_f.unsqueeze(0) +
125
+ torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
126
+
127
+ verts_rounded = torch.round(verts * 10**5) / (10**5)
128
+ verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
129
+ cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
130
+
131
+ return verts_unique - 0.5, cubes
132
+
133
+ def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
134
+ gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
135
+ r"""
136
+ Main function for mesh extraction from scalar field using FlexiCubes. This function converts
137
+ discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
138
+ to triangle or tetrahedral meshes using a differentiable operation as described in
139
+ `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
140
+ mesh quality and geometric fidelity by adjusting the surface representation based on gradient
141
+ optimization. The output surface is differentiable with respect to the input vertex positions,
142
+ scalar field values, and weight parameters.
143
+
144
+ If you intend to extract a surface mesh from a fixed Signed Distance Field without the
145
+ optimization of parameters, it is suggested to provide the "grad_func" which should
146
+ return the surface gradient at any given 3D position. When grad_func is provided, the process
147
+ to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
148
+ described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy.
149
+ Please note, this approach is non-differentiable.
150
+
151
+ For more details and example usage in optimization, refer to the
152
+ `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
153
+
154
+ Args:
155
+ x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
156
+ s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
157
+ denote that the corresponding vertex resides inside the isosurface. This affects
158
+ the directions of the extracted triangle faces and volume to be tetrahedralized.
159
+ cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
160
+ res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
161
+ is used for all three dimensions. If a list or tuple of 3 integers is provided, they
162
+ specify the resolution for the x, y, and z dimensions respectively.
163
+ beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
164
+ vertices positioning. Defaults to uniform value for all edges.
165
+ alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
166
+ vertices positioning. Defaults to uniform value for all vertices.
167
+ gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
168
+ quadrilaterals into triangles. Defaults to uniform value for all cubes.
169
+ training (bool, optional): If set to True, applies differentiable quad splitting for
170
+ training. Defaults to False.
171
+ output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
172
+ outputs a triangular mesh. Defaults to False.
173
+ grad_func (callable, optional): A function to compute the surface gradient at specified
174
+ 3D positions (input: Nx3 positions). The function should return gradients as an Nx3
175
+ tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
176
+
177
+ Returns:
178
+ (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
179
+ - Vertices for the extracted triangular/tetrahedral mesh.
180
+ - Faces for the extracted triangular/tetrahedral mesh.
181
+ - Regularizer L_dev, computed per dual vertex.
182
+
183
+ .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
184
+ https://research.nvidia.com/labs/toronto-ai/flexicubes/
185
+ .. _Manifold Dual Contouring:
186
+ https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
187
+ """
188
+
189
+ surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
190
+ if surf_cubes.sum() == 0:
191
+ return torch.zeros(
192
+ (0, 3),
193
+ device=self.device), torch.zeros(
194
+ (0, 4),
195
+ dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
196
+ (0, 3),
197
+ dtype=torch.long, device=self.device), torch.zeros(
198
+ (0),
199
+ device=self.device)
200
+ beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
201
+
202
+ case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
203
+
204
+ surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
205
+
206
+ vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
207
+ x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
208
+ vertices, faces, s_edges, edge_indices = self._triangulate(
209
+ s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
210
+ if not output_tetmesh:
211
+ return vertices, faces, L_dev
212
+ else:
213
+ vertices, tets = self._tetrahedralize(
214
+ x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
215
+ surf_cubes, training)
216
+ return vertices, tets, L_dev
217
+
218
+ def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
219
+ """
220
+ Regularizer L_dev as in Equation 8
221
+ """
222
+ dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
223
+ mean_l2 = torch.zeros_like(vd[:, 0])
224
+ mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
225
+ mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
226
+ return mad
227
+
228
+ def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
229
+ """
230
+ Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
231
+ """
232
+ n_cubes = surf_cubes.shape[0]
233
+
234
+ if beta_fx12 is not None:
235
+ beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
236
+ else:
237
+ beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
238
+
239
+ if alpha_fx8 is not None:
240
+ alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
241
+ else:
242
+ alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
243
+
244
+ if gamma_f is not None:
245
+ gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
246
+ else:
247
+ gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
248
+
249
+ return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
250
+
251
+ @torch.no_grad()
252
+ def _get_case_id(self, occ_fx8, surf_cubes, res):
253
+ """
254
+ Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
255
+ ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
256
+ supplementary material. It should be noted that this function assumes a regular grid.
257
+ """
258
+ case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
259
+
260
+ problem_config = self.check_table.to(self.device)[case_ids]
261
+ to_check = problem_config[..., 0] == 1
262
+ problem_config = problem_config[to_check]
263
+ if not isinstance(res, (list, tuple)):
264
+ res = [res, res, res]
265
+
266
+ # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
267
+ # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
268
+ # This allows efficient checking on adjacent cubes.
269
+ problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
270
+ vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
271
+ vol_idx_problem = vol_idx[surf_cubes][to_check]
272
+ problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
273
+ vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
274
+
275
+ within_range = (
276
+ vol_idx_problem_adj[..., 0] >= 0) & (
277
+ vol_idx_problem_adj[..., 0] < res[0]) & (
278
+ vol_idx_problem_adj[..., 1] >= 0) & (
279
+ vol_idx_problem_adj[..., 1] < res[1]) & (
280
+ vol_idx_problem_adj[..., 2] >= 0) & (
281
+ vol_idx_problem_adj[..., 2] < res[2])
282
+
283
+ vol_idx_problem = vol_idx_problem[within_range]
284
+ vol_idx_problem_adj = vol_idx_problem_adj[within_range]
285
+ problem_config = problem_config[within_range]
286
+ problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
287
+ vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
288
+ # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
289
+ to_invert = (problem_config_adj[..., 0] == 1)
290
+ idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
291
+ case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
292
+ return case_ids
293
+
294
+ @torch.no_grad()
295
+ def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
296
+ """
297
+ Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
298
+ can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
299
+ and marks the cube edges with this index.
300
+ """
301
+ occ_n = s_n < 0
302
+ all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
303
+ unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
304
+
305
+ unique_edges = unique_edges.long()
306
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
307
+
308
+ surf_edges_mask = mask_edges[_idx_map]
309
+ counts = counts[_idx_map]
310
+
311
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
312
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
313
+ # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
314
+ # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
315
+ idx_map = mapping[_idx_map]
316
+ surf_edges = unique_edges[mask_edges]
317
+ return surf_edges, idx_map, counts, surf_edges_mask
318
+
319
+ @torch.no_grad()
320
+ def _identify_surf_cubes(self, s_n, cube_fx8):
321
+ """
322
+ Identifies grid cubes that intersect with the underlying surface by checking if the signs at
323
+ all corners are not identical.
324
+ """
325
+ occ_n = s_n < 0
326
+ occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
327
+ _occ_sum = torch.sum(occ_fx8, -1)
328
+ surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
329
+ return surf_cubes, occ_fx8
330
+
331
+ def _linear_interp(self, edges_weight, edges_x):
332
+ """
333
+ Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
334
+ """
335
+ edge_dim = edges_weight.dim() - 2
336
+ assert edges_weight.shape[edge_dim] == 2
337
+ edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
338
+ torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
339
+ denominator = edges_weight.sum(edge_dim)
340
+ ue = (edges_x * edges_weight).sum(edge_dim) / denominator
341
+ return ue
342
+
343
+ def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
344
+ p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
345
+ norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
346
+ c_bx3 = c_bx3.reshape(-1, 3)
347
+ A = norm_bxnx3
348
+ B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
349
+
350
+ A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
351
+ B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
352
+ A = torch.cat([A, A_reg], 1)
353
+ B = torch.cat([B, B_reg], 1)
354
+ dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
355
+ return dual_verts
356
+
357
+ def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
358
+ """
359
+ Computes the location of dual vertices as described in Section 4.2
360
+ """
361
+ alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
362
+ surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
363
+ surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
364
+ zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
365
+
366
+ idx_map = idx_map.reshape(-1, 12)
367
+ num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
368
+ edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
369
+
370
+ total_num_vd = 0
371
+ vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
372
+ if grad_func is not None:
373
+ normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
374
+ vd = []
375
+ for num in torch.unique(num_vd):
376
+ cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
377
+ curr_num_vd = cur_cubes.sum() * num
378
+ curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
379
+ curr_edge_group_to_vd = torch.arange(
380
+ curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
381
+ total_num_vd += curr_num_vd
382
+ curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
383
+ cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
384
+
385
+ curr_mask = (curr_edge_group != -1)
386
+ edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
387
+ edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
388
+ edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
389
+ vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
390
+ vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
391
+
392
+ if grad_func is not None:
393
+ with torch.no_grad():
394
+ cube_e_verts_idx = idx_map[cur_cubes]
395
+ curr_edge_group[~curr_mask] = 0
396
+
397
+ verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
398
+ verts_group_idx[verts_group_idx == -1] = 0
399
+ verts_group_pos = torch.index_select(
400
+ input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
401
+ v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
402
+ curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
403
+ verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
404
+
405
+ normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
406
+ -1, num.item(), 7,
407
+ 3)
408
+ curr_mask = curr_mask.squeeze(2)
409
+ vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
410
+ verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
411
+ edge_group = torch.cat(edge_group)
412
+ edge_group_to_vd = torch.cat(edge_group_to_vd)
413
+ edge_group_to_cube = torch.cat(edge_group_to_cube)
414
+ vd_num_edges = torch.cat(vd_num_edges)
415
+ vd_gamma = torch.cat(vd_gamma)
416
+
417
+ if grad_func is not None:
418
+ vd = torch.cat(vd)
419
+ L_dev = torch.zeros([1], device=self.device)
420
+ else:
421
+ vd = torch.zeros((total_num_vd, 3), device=self.device)
422
+ beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
423
+
424
+ idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
425
+
426
+ x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
427
+ s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
428
+
429
+ zero_crossing_group = torch.index_select(
430
+ input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
431
+
432
+ alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
433
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
434
+ ue_group = self._linear_interp(s_group * alpha_group, x_group)
435
+
436
+ beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
437
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
438
+ beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
439
+ vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
440
+ L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
441
+
442
+ v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
443
+
444
+ vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
445
+ 12 + edge_group, src=v_idx[edge_group_to_vd])
446
+
447
+ return vd, L_dev, vd_gamma, vd_idx_map
448
+
449
+ def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
450
+ """
451
+ Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
452
+ triangles based on the gamma parameter, as described in Section 4.3.
453
+ """
454
+ with torch.no_grad():
455
+ group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
456
+ group = idx_map.reshape(-1)[group_mask]
457
+ vd_idx = vd_idx_map[group_mask]
458
+ edge_indices, indices = torch.sort(group, stable=True)
459
+ quad_vd_idx = vd_idx[indices].reshape(-1, 4)
460
+
461
+ # Ensure all face directions point towards the positive SDF to maintain consistent winding.
462
+ s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
463
+ flip_mask = s_edges[:, 0] > 0
464
+ quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
465
+ quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
466
+ if grad_func is not None:
467
+ # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
468
+ with torch.no_grad():
469
+ vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
470
+ quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
471
+ gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
472
+ gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
473
+ else:
474
+ quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
475
+ gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
476
+ 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
477
+ gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
478
+ 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
479
+ if not training:
480
+ mask = (gamma_02 > gamma_13).squeeze(1)
481
+ faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
482
+ faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
483
+ faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
484
+ faces = faces.reshape(-1, 3)
485
+ else:
486
+ vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
487
+ vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
488
+ torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
489
+ vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
490
+ torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
491
+ weight_sum = (gamma_02 + gamma_13) + 1e-8
492
+ vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
493
+ weight_sum.unsqueeze(-1)).squeeze(1)
494
+ vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
495
+ vd = torch.cat([vd, vd_center])
496
+ faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
497
+ faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
498
+ return vd, faces, s_edges, edge_indices
499
+
500
+ def _tetrahedralize(
501
+ self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
502
+ surf_cubes, training):
503
+ """
504
+ Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
505
+ """
506
+ occ_n = s_n < 0
507
+ occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
508
+ occ_sum = torch.sum(occ_fx8, -1)
509
+
510
+ inside_verts = x_nx3[occ_n]
511
+ mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
512
+ mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
513
+ """
514
+ For each grid edge connecting two grid vertices with different
515
+ signs, we first form a four-sided pyramid by connecting one
516
+ of the grid vertices with four mesh vertices that correspond
517
+ to the grid edge and then subdivide the pyramid into two tetrahedra
518
+ """
519
+ inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
520
+ s_edges < 0]]
521
+ if not training:
522
+ inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
523
+ else:
524
+ inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
525
+
526
+ tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
527
+ """
528
+ For each grid edge connecting two grid vertices with the
529
+ same sign, the tetrahedron is formed by the two grid vertices
530
+ and two vertices in consecutive adjacent cells
531
+ """
532
+ inside_cubes = (occ_sum == 8)
533
+ inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
534
+ inside_cubes_center_idx = torch.arange(
535
+ inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
536
+
537
+ surface_n_inside_cubes = surf_cubes | inside_cubes
538
+ edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
539
+ dtype=torch.long, device=x_nx3.device) * -1
540
+ surf_cubes = surf_cubes[surface_n_inside_cubes]
541
+ inside_cubes = inside_cubes[surface_n_inside_cubes]
542
+ edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
543
+ edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
544
+
545
+ all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
546
+ unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
547
+ unique_edges = unique_edges.long()
548
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
549
+ mask = mask_edges[_idx_map]
550
+ counts = counts[_idx_map]
551
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
552
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
553
+ idx_map = mapping[_idx_map]
554
+
555
+ group_mask = (counts == 4) & mask
556
+ group = idx_map.reshape(-1)[group_mask]
557
+ edge_indices, indices = torch.sort(group)
558
+ cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
559
+ device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
560
+ edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
561
+ 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
562
+ # Identify the face shared by the adjacent cells.
563
+ cube_idx_4 = cube_idx[indices].reshape(-1, 4)
564
+ edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
565
+ shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
566
+ cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
567
+ # Identify an edge of the face with different signs and
568
+ # select the mesh vertex corresponding to the identified edge.
569
+ case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
570
+ case_ids_expand[surf_cubes] = case_ids
571
+ cases = case_ids_expand[cube_idx_4x2]
572
+ quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
573
+ mask = (quad_edge == -1).sum(-1) == 0
574
+ inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
575
+ tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
576
+
577
+ tets = torch.cat([tets_surface, tets_inside])
578
+ vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
579
+ return vertices, tets
core/geometry/rep_3d/flexicubes_geometry.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import numpy as np
11
+ import os
12
+ from . import Geometry
13
+ from .flexicubes import FlexiCubes # replace later
14
+ from .dmtet import sdf_reg_loss_batch
15
+ import torch.nn.functional as F
16
+
17
+ def get_center_boundary_index(grid_res, device):
18
+ v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
19
+ v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
20
+ center_indices = torch.nonzero(v.reshape(-1))
21
+
22
+ v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
23
+ v[:2, ...] = True
24
+ v[-2:, ...] = True
25
+ v[:, :2, ...] = True
26
+ v[:, -2:, ...] = True
27
+ v[:, :, :2] = True
28
+ v[:, :, -2:] = True
29
+ boundary_indices = torch.nonzero(v.reshape(-1))
30
+ return center_indices, boundary_indices
31
+
32
+ ###############################################################################
33
+ # Geometry interface
34
+ ###############################################################################
35
+ class FlexiCubesGeometry(Geometry):
36
+ def __init__(
37
+ self, grid_res=64, scale=2.0, device='cuda', renderer=None,
38
+ render_type='neural_render', args=None):
39
+ super(FlexiCubesGeometry, self).__init__()
40
+ self.grid_res = grid_res
41
+ self.device = device
42
+ self.args = args
43
+ self.fc = FlexiCubes(device, weight_scale=0.5)
44
+ self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
45
+ if isinstance(scale, list):
46
+ self.verts[:, 0] = self.verts[:, 0] * scale[0]
47
+ self.verts[:, 1] = self.verts[:, 1] * scale[1]
48
+ self.verts[:, 2] = self.verts[:, 2] * scale[1]
49
+ else:
50
+ self.verts = self.verts * scale
51
+
52
+ all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
53
+ self.all_edges = torch.unique(all_edges, dim=0)
54
+
55
+ # Parameters used for fix boundary sdf
56
+ self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
57
+ self.renderer = renderer
58
+ self.render_type = render_type
59
+
60
+ def getAABB(self):
61
+ return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
62
+
63
+ def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
64
+ if indices is None:
65
+ indices = self.indices
66
+
67
+ verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
68
+ beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
69
+ gamma_f=weight_n[:, 20], training=is_training
70
+ )
71
+ return verts, faces, v_reg_loss
72
+
73
+
74
+ def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
75
+ return_value = dict()
76
+ if self.render_type == 'neural_render':
77
+ tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh(
78
+ mesh_v_nx3.unsqueeze(dim=0),
79
+ mesh_f_fx3.int(),
80
+ camera_mv_bx4x4,
81
+ mesh_v_nx3.unsqueeze(dim=0),
82
+ resolution=resolution,
83
+ device=self.device,
84
+ hierarchical_mask=hierarchical_mask
85
+ )
86
+
87
+ return_value['tex_pos'] = tex_pos
88
+ return_value['mask'] = mask
89
+ return_value['hard_mask'] = hard_mask
90
+ return_value['rast'] = rast
91
+ return_value['v_pos_clip'] = v_pos_clip
92
+ return_value['mask_pyramid'] = mask_pyramid
93
+ return_value['depth'] = depth
94
+ return_value['normal'] = normal
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ return return_value
99
+
100
+ def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
101
+ # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
102
+ v_list = []
103
+ f_list = []
104
+ n_batch = v_deformed_bxnx3.shape[0]
105
+ all_render_output = []
106
+ for i_batch in range(n_batch):
107
+ verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
108
+ v_list.append(verts_nx3)
109
+ f_list.append(faces_fx3)
110
+ render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
111
+ all_render_output.append(render_output)
112
+
113
+ # Concatenate all render output
114
+ return_keys = all_render_output[0].keys()
115
+ return_value = dict()
116
+ for k in return_keys:
117
+ value = [v[k] for v in all_render_output]
118
+ return_value[k] = value
119
+ # We can do concatenation outside of the render
120
+ return return_value
core/geometry/rep_3d/tables.py ADDED
@@ -0,0 +1,791 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+ dmc_table = [
9
+ [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
10
+ [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
11
+ [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
12
+ [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
13
+ [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
14
+ [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
15
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
16
+ [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
17
+ [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
18
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
19
+ [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
20
+ [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
21
+ [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
22
+ [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
23
+ [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
24
+ [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
25
+ [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
26
+ [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
27
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
28
+ [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
29
+ [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
30
+ [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
31
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
32
+ [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
33
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
34
+ [[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
35
+ [[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
36
+ [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
37
+ [[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
38
+ [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
39
+ [[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
40
+ [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
41
+ [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
42
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
43
+ [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
44
+ [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
45
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
46
+ [[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
47
+ [[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
48
+ [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
49
+ [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
50
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
51
+ [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
52
+ [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
53
+ [[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
54
+ [[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
55
+ [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
56
+ [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
57
+ [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
58
+ [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
59
+ [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
60
+ [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
61
+ [[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
62
+ [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
63
+ [[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
64
+ [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
65
+ [[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
66
+ [[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
67
+ [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
68
+ [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
69
+ [[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
70
+ [[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
71
+ [[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
72
+ [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
73
+ [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
74
+ [[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
75
+ [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
76
+ [[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
77
+ [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
78
+ [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
79
+ [[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
80
+ [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
81
+ [[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
82
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
83
+ [[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
84
+ [[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
85
+ [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
86
+ [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
87
+ [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
88
+ [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
89
+ [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
90
+ [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
91
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
92
+ [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
93
+ [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
94
+ [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
95
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
96
+ [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
97
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
98
+ [[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
99
+ [[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
100
+ [[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
101
+ [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
102
+ [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
103
+ [[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
104
+ [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
105
+ [[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
106
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
107
+ [[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
108
+ [[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
109
+ [[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
110
+ [[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
111
+ [[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
112
+ [[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
113
+ [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
114
+ [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
115
+ [[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
116
+ [[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
117
+ [[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
118
+ [[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
119
+ [[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
120
+ [[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
121
+ [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
122
+ [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
123
+ [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
124
+ [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
125
+ [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
126
+ [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
127
+ [[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
128
+ [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
129
+ [[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
130
+ [[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
131
+ [[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
132
+ [[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
133
+ [[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
134
+ [[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
135
+ [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
136
+ [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
137
+ [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
138
+ [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
139
+ [[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
140
+ [[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
141
+ [[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
142
+ [[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
143
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
144
+ [[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
145
+ [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
146
+ [[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
147
+ [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
148
+ [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
149
+ [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
150
+ [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
151
+ [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
152
+ [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
153
+ [[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
154
+ [[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
155
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
156
+ [[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
157
+ [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
158
+ [[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
159
+ [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
160
+ [[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
161
+ [[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
162
+ [[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
163
+ [[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
164
+ [[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
165
+ [[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
166
+ [[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
167
+ [[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
168
+ [[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
169
+ [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
170
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
171
+ [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
172
+ [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
173
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
174
+ [[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
175
+ [[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
176
+ [[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
177
+ [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
178
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
179
+ [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
180
+ [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
181
+ [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
182
+ [[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
183
+ [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
184
+ [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
185
+ [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
186
+ [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
187
+ [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
188
+ [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
189
+ [[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
190
+ [[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
191
+ [[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
192
+ [[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
193
+ [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
194
+ [[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
195
+ [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
196
+ [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
197
+ [[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
198
+ [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
199
+ [[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
200
+ [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
201
+ [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
202
+ [[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
203
+ [[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
204
+ [[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
205
+ [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
206
+ [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
207
+ [[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
208
+ [[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
209
+ [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
210
+ [[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
211
+ [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
212
+ [[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
213
+ [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
214
+ [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
215
+ [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
216
+ [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
217
+ [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
218
+ [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
219
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
220
+ [[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
221
+ [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
222
+ [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
223
+ [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
224
+ [[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
225
+ [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
226
+ [[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
227
+ [[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
228
+ [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
229
+ [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
230
+ [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
231
+ [[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
232
+ [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
233
+ [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
234
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
235
+ [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
236
+ [[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
237
+ [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
238
+ [[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
239
+ [[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
240
+ [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
241
+ [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
242
+ [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
243
+ [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
244
+ [[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
245
+ [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
246
+ [[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
247
+ [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
248
+ [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
249
+ [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
250
+ [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
251
+ [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
252
+ [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
253
+ [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
254
+ [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
255
+ [[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
256
+ [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
257
+ [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
258
+ [[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
259
+ [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
260
+ [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
261
+ [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
262
+ [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
263
+ [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
264
+ [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
265
+ ]
266
+ num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
267
+ 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
268
+ 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
269
+ 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
270
+ 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
271
+ 3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
272
+ 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
273
+ 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
274
+ 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
275
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
276
+ check_table = [
277
+ [0, 0, 0, 0, 0],
278
+ [0, 0, 0, 0, 0],
279
+ [0, 0, 0, 0, 0],
280
+ [0, 0, 0, 0, 0],
281
+ [0, 0, 0, 0, 0],
282
+ [0, 0, 0, 0, 0],
283
+ [0, 0, 0, 0, 0],
284
+ [0, 0, 0, 0, 0],
285
+ [0, 0, 0, 0, 0],
286
+ [0, 0, 0, 0, 0],
287
+ [0, 0, 0, 0, 0],
288
+ [0, 0, 0, 0, 0],
289
+ [0, 0, 0, 0, 0],
290
+ [0, 0, 0, 0, 0],
291
+ [0, 0, 0, 0, 0],
292
+ [0, 0, 0, 0, 0],
293
+ [0, 0, 0, 0, 0],
294
+ [0, 0, 0, 0, 0],
295
+ [0, 0, 0, 0, 0],
296
+ [0, 0, 0, 0, 0],
297
+ [0, 0, 0, 0, 0],
298
+ [0, 0, 0, 0, 0],
299
+ [0, 0, 0, 0, 0],
300
+ [0, 0, 0, 0, 0],
301
+ [0, 0, 0, 0, 0],
302
+ [0, 0, 0, 0, 0],
303
+ [0, 0, 0, 0, 0],
304
+ [0, 0, 0, 0, 0],
305
+ [0, 0, 0, 0, 0],
306
+ [0, 0, 0, 0, 0],
307
+ [0, 0, 0, 0, 0],
308
+ [0, 0, 0, 0, 0],
309
+ [0, 0, 0, 0, 0],
310
+ [0, 0, 0, 0, 0],
311
+ [0, 0, 0, 0, 0],
312
+ [0, 0, 0, 0, 0],
313
+ [0, 0, 0, 0, 0],
314
+ [0, 0, 0, 0, 0],
315
+ [0, 0, 0, 0, 0],
316
+ [0, 0, 0, 0, 0],
317
+ [0, 0, 0, 0, 0],
318
+ [0, 0, 0, 0, 0],
319
+ [0, 0, 0, 0, 0],
320
+ [0, 0, 0, 0, 0],
321
+ [0, 0, 0, 0, 0],
322
+ [0, 0, 0, 0, 0],
323
+ [0, 0, 0, 0, 0],
324
+ [0, 0, 0, 0, 0],
325
+ [0, 0, 0, 0, 0],
326
+ [0, 0, 0, 0, 0],
327
+ [0, 0, 0, 0, 0],
328
+ [0, 0, 0, 0, 0],
329
+ [0, 0, 0, 0, 0],
330
+ [0, 0, 0, 0, 0],
331
+ [0, 0, 0, 0, 0],
332
+ [0, 0, 0, 0, 0],
333
+ [0, 0, 0, 0, 0],
334
+ [0, 0, 0, 0, 0],
335
+ [0, 0, 0, 0, 0],
336
+ [0, 0, 0, 0, 0],
337
+ [0, 0, 0, 0, 0],
338
+ [1, 1, 0, 0, 194],
339
+ [1, -1, 0, 0, 193],
340
+ [0, 0, 0, 0, 0],
341
+ [0, 0, 0, 0, 0],
342
+ [0, 0, 0, 0, 0],
343
+ [0, 0, 0, 0, 0],
344
+ [0, 0, 0, 0, 0],
345
+ [0, 0, 0, 0, 0],
346
+ [0, 0, 0, 0, 0],
347
+ [0, 0, 0, 0, 0],
348
+ [0, 0, 0, 0, 0],
349
+ [0, 0, 0, 0, 0],
350
+ [0, 0, 0, 0, 0],
351
+ [0, 0, 0, 0, 0],
352
+ [0, 0, 0, 0, 0],
353
+ [0, 0, 0, 0, 0],
354
+ [0, 0, 0, 0, 0],
355
+ [0, 0, 0, 0, 0],
356
+ [0, 0, 0, 0, 0],
357
+ [0, 0, 0, 0, 0],
358
+ [0, 0, 0, 0, 0],
359
+ [0, 0, 0, 0, 0],
360
+ [0, 0, 0, 0, 0],
361
+ [0, 0, 0, 0, 0],
362
+ [0, 0, 0, 0, 0],
363
+ [0, 0, 0, 0, 0],
364
+ [0, 0, 0, 0, 0],
365
+ [0, 0, 0, 0, 0],
366
+ [0, 0, 0, 0, 0],
367
+ [0, 0, 0, 0, 0],
368
+ [1, 0, 1, 0, 164],
369
+ [0, 0, 0, 0, 0],
370
+ [0, 0, 0, 0, 0],
371
+ [1, 0, -1, 0, 161],
372
+ [0, 0, 0, 0, 0],
373
+ [0, 0, 0, 0, 0],
374
+ [0, 0, 0, 0, 0],
375
+ [0, 0, 0, 0, 0],
376
+ [0, 0, 0, 0, 0],
377
+ [0, 0, 0, 0, 0],
378
+ [0, 0, 0, 0, 0],
379
+ [0, 0, 0, 0, 0],
380
+ [1, 0, 0, 1, 152],
381
+ [0, 0, 0, 0, 0],
382
+ [0, 0, 0, 0, 0],
383
+ [0, 0, 0, 0, 0],
384
+ [0, 0, 0, 0, 0],
385
+ [0, 0, 0, 0, 0],
386
+ [0, 0, 0, 0, 0],
387
+ [1, 0, 0, 1, 145],
388
+ [1, 0, 0, 1, 144],
389
+ [0, 0, 0, 0, 0],
390
+ [0, 0, 0, 0, 0],
391
+ [0, 0, 0, 0, 0],
392
+ [0, 0, 0, 0, 0],
393
+ [0, 0, 0, 0, 0],
394
+ [0, 0, 0, 0, 0],
395
+ [1, 0, 0, -1, 137],
396
+ [0, 0, 0, 0, 0],
397
+ [0, 0, 0, 0, 0],
398
+ [0, 0, 0, 0, 0],
399
+ [1, 0, 1, 0, 133],
400
+ [1, 0, 1, 0, 132],
401
+ [1, 1, 0, 0, 131],
402
+ [1, 1, 0, 0, 130],
403
+ [0, 0, 0, 0, 0],
404
+ [0, 0, 0, 0, 0],
405
+ [0, 0, 0, 0, 0],
406
+ [0, 0, 0, 0, 0],
407
+ [0, 0, 0, 0, 0],
408
+ [0, 0, 0, 0, 0],
409
+ [0, 0, 0, 0, 0],
410
+ [0, 0, 0, 0, 0],
411
+ [0, 0, 0, 0, 0],
412
+ [0, 0, 0, 0, 0],
413
+ [0, 0, 0, 0, 0],
414
+ [0, 0, 0, 0, 0],
415
+ [0, 0, 0, 0, 0],
416
+ [0, 0, 0, 0, 0],
417
+ [0, 0, 0, 0, 0],
418
+ [0, 0, 0, 0, 0],
419
+ [0, 0, 0, 0, 0],
420
+ [0, 0, 0, 0, 0],
421
+ [0, 0, 0, 0, 0],
422
+ [0, 0, 0, 0, 0],
423
+ [0, 0, 0, 0, 0],
424
+ [0, 0, 0, 0, 0],
425
+ [0, 0, 0, 0, 0],
426
+ [0, 0, 0, 0, 0],
427
+ [0, 0, 0, 0, 0],
428
+ [0, 0, 0, 0, 0],
429
+ [0, 0, 0, 0, 0],
430
+ [0, 0, 0, 0, 0],
431
+ [0, 0, 0, 0, 0],
432
+ [1, 0, 0, 1, 100],
433
+ [0, 0, 0, 0, 0],
434
+ [1, 0, 0, 1, 98],
435
+ [0, 0, 0, 0, 0],
436
+ [1, 0, 0, 1, 96],
437
+ [0, 0, 0, 0, 0],
438
+ [0, 0, 0, 0, 0],
439
+ [0, 0, 0, 0, 0],
440
+ [0, 0, 0, 0, 0],
441
+ [0, 0, 0, 0, 0],
442
+ [0, 0, 0, 0, 0],
443
+ [0, 0, 0, 0, 0],
444
+ [1, 0, 1, 0, 88],
445
+ [0, 0, 0, 0, 0],
446
+ [0, 0, 0, 0, 0],
447
+ [0, 0, 0, 0, 0],
448
+ [0, 0, 0, 0, 0],
449
+ [0, 0, 0, 0, 0],
450
+ [1, 0, -1, 0, 82],
451
+ [0, 0, 0, 0, 0],
452
+ [0, 0, 0, 0, 0],
453
+ [0, 0, 0, 0, 0],
454
+ [0, 0, 0, 0, 0],
455
+ [0, 0, 0, 0, 0],
456
+ [0, 0, 0, 0, 0],
457
+ [0, 0, 0, 0, 0],
458
+ [1, 0, 1, 0, 74],
459
+ [0, 0, 0, 0, 0],
460
+ [1, 0, 1, 0, 72],
461
+ [0, 0, 0, 0, 0],
462
+ [1, 0, 0, -1, 70],
463
+ [0, 0, 0, 0, 0],
464
+ [0, 0, 0, 0, 0],
465
+ [1, -1, 0, 0, 67],
466
+ [0, 0, 0, 0, 0],
467
+ [1, -1, 0, 0, 65],
468
+ [0, 0, 0, 0, 0],
469
+ [0, 0, 0, 0, 0],
470
+ [0, 0, 0, 0, 0],
471
+ [0, 0, 0, 0, 0],
472
+ [0, 0, 0, 0, 0],
473
+ [0, 0, 0, 0, 0],
474
+ [0, 0, 0, 0, 0],
475
+ [0, 0, 0, 0, 0],
476
+ [1, 1, 0, 0, 56],
477
+ [0, 0, 0, 0, 0],
478
+ [0, 0, 0, 0, 0],
479
+ [0, 0, 0, 0, 0],
480
+ [1, -1, 0, 0, 52],
481
+ [0, 0, 0, 0, 0],
482
+ [0, 0, 0, 0, 0],
483
+ [0, 0, 0, 0, 0],
484
+ [0, 0, 0, 0, 0],
485
+ [0, 0, 0, 0, 0],
486
+ [0, 0, 0, 0, 0],
487
+ [0, 0, 0, 0, 0],
488
+ [1, 1, 0, 0, 44],
489
+ [0, 0, 0, 0, 0],
490
+ [0, 0, 0, 0, 0],
491
+ [0, 0, 0, 0, 0],
492
+ [1, 1, 0, 0, 40],
493
+ [0, 0, 0, 0, 0],
494
+ [1, 0, 0, -1, 38],
495
+ [1, 0, -1, 0, 37],
496
+ [0, 0, 0, 0, 0],
497
+ [0, 0, 0, 0, 0],
498
+ [0, 0, 0, 0, 0],
499
+ [1, 0, -1, 0, 33],
500
+ [0, 0, 0, 0, 0],
501
+ [0, 0, 0, 0, 0],
502
+ [0, 0, 0, 0, 0],
503
+ [0, 0, 0, 0, 0],
504
+ [1, -1, 0, 0, 28],
505
+ [0, 0, 0, 0, 0],
506
+ [1, 0, -1, 0, 26],
507
+ [1, 0, 0, -1, 25],
508
+ [0, 0, 0, 0, 0],
509
+ [0, 0, 0, 0, 0],
510
+ [0, 0, 0, 0, 0],
511
+ [0, 0, 0, 0, 0],
512
+ [1, -1, 0, 0, 20],
513
+ [0, 0, 0, 0, 0],
514
+ [1, 0, -1, 0, 18],
515
+ [0, 0, 0, 0, 0],
516
+ [0, 0, 0, 0, 0],
517
+ [0, 0, 0, 0, 0],
518
+ [0, 0, 0, 0, 0],
519
+ [0, 0, 0, 0, 0],
520
+ [0, 0, 0, 0, 0],
521
+ [0, 0, 0, 0, 0],
522
+ [0, 0, 0, 0, 0],
523
+ [1, 0, 0, -1, 9],
524
+ [0, 0, 0, 0, 0],
525
+ [0, 0, 0, 0, 0],
526
+ [1, 0, 0, -1, 6],
527
+ [0, 0, 0, 0, 0],
528
+ [0, 0, 0, 0, 0],
529
+ [0, 0, 0, 0, 0],
530
+ [0, 0, 0, 0, 0],
531
+ [0, 0, 0, 0, 0],
532
+ [0, 0, 0, 0, 0]
533
+ ]
534
+ tet_table = [
535
+ [-1, -1, -1, -1, -1, -1],
536
+ [0, 0, 0, 0, 0, 0],
537
+ [0, 0, 0, 0, 0, 0],
538
+ [1, 1, 1, 1, 1, 1],
539
+ [4, 4, 4, 4, 4, 4],
540
+ [0, 0, 0, 0, 0, 0],
541
+ [4, 0, 0, 4, 4, -1],
542
+ [1, 1, 1, 1, 1, 1],
543
+ [4, 4, 4, 4, 4, 4],
544
+ [0, 4, 0, 4, 4, -1],
545
+ [0, 0, 0, 0, 0, 0],
546
+ [1, 1, 1, 1, 1, 1],
547
+ [5, 5, 5, 5, 5, 5],
548
+ [0, 0, 0, 0, 0, 0],
549
+ [0, 0, 0, 0, 0, 0],
550
+ [1, 1, 1, 1, 1, 1],
551
+ [2, 2, 2, 2, 2, 2],
552
+ [0, 0, 0, 0, 0, 0],
553
+ [2, 0, 2, -1, 0, 2],
554
+ [1, 1, 1, 1, 1, 1],
555
+ [2, -1, 2, 4, 4, 2],
556
+ [0, 0, 0, 0, 0, 0],
557
+ [2, 0, 2, 4, 4, 2],
558
+ [1, 1, 1, 1, 1, 1],
559
+ [2, 4, 2, 4, 4, 2],
560
+ [0, 4, 0, 4, 4, 0],
561
+ [2, 0, 2, 0, 0, 2],
562
+ [1, 1, 1, 1, 1, 1],
563
+ [2, 5, 2, 5, 5, 2],
564
+ [0, 0, 0, 0, 0, 0],
565
+ [2, 0, 2, 0, 0, 2],
566
+ [1, 1, 1, 1, 1, 1],
567
+ [1, 1, 1, 1, 1, 1],
568
+ [0, 1, 1, -1, 0, 1],
569
+ [0, 0, 0, 0, 0, 0],
570
+ [2, 2, 2, 2, 2, 2],
571
+ [4, 1, 1, 4, 4, 1],
572
+ [0, 1, 1, 0, 0, 1],
573
+ [4, 0, 0, 4, 4, 0],
574
+ [2, 2, 2, 2, 2, 2],
575
+ [-1, 1, 1, 4, 4, 1],
576
+ [0, 1, 1, 4, 4, 1],
577
+ [0, 0, 0, 0, 0, 0],
578
+ [2, 2, 2, 2, 2, 2],
579
+ [5, 1, 1, 5, 5, 1],
580
+ [0, 1, 1, 0, 0, 1],
581
+ [0, 0, 0, 0, 0, 0],
582
+ [2, 2, 2, 2, 2, 2],
583
+ [1, 1, 1, 1, 1, 1],
584
+ [0, 0, 0, 0, 0, 0],
585
+ [0, 0, 0, 0, 0, 0],
586
+ [8, 8, 8, 8, 8, 8],
587
+ [1, 1, 1, 4, 4, 1],
588
+ [0, 0, 0, 0, 0, 0],
589
+ [4, 0, 0, 4, 4, 0],
590
+ [4, 4, 4, 4, 4, 4],
591
+ [1, 1, 1, 4, 4, 1],
592
+ [0, 4, 0, 4, 4, 0],
593
+ [0, 0, 0, 0, 0, 0],
594
+ [4, 4, 4, 4, 4, 4],
595
+ [1, 1, 1, 5, 5, 1],
596
+ [0, 0, 0, 0, 0, 0],
597
+ [0, 0, 0, 0, 0, 0],
598
+ [5, 5, 5, 5, 5, 5],
599
+ [6, 6, 6, 6, 6, 6],
600
+ [6, -1, 0, 6, 0, 6],
601
+ [6, 0, 0, 6, 0, 6],
602
+ [6, 1, 1, 6, 1, 6],
603
+ [4, 4, 4, 4, 4, 4],
604
+ [0, 0, 0, 0, 0, 0],
605
+ [4, 0, 0, 4, 4, 4],
606
+ [1, 1, 1, 1, 1, 1],
607
+ [6, 4, -1, 6, 4, 6],
608
+ [6, 4, 0, 6, 4, 6],
609
+ [6, 0, 0, 6, 0, 6],
610
+ [6, 1, 1, 6, 1, 6],
611
+ [5, 5, 5, 5, 5, 5],
612
+ [0, 0, 0, 0, 0, 0],
613
+ [0, 0, 0, 0, 0, 0],
614
+ [1, 1, 1, 1, 1, 1],
615
+ [2, 2, 2, 2, 2, 2],
616
+ [0, 0, 0, 0, 0, 0],
617
+ [2, 0, 2, 2, 0, 2],
618
+ [1, 1, 1, 1, 1, 1],
619
+ [2, 2, 2, 2, 2, 2],
620
+ [0, 0, 0, 0, 0, 0],
621
+ [2, 0, 2, 2, 2, 2],
622
+ [1, 1, 1, 1, 1, 1],
623
+ [2, 4, 2, 2, 4, 2],
624
+ [0, 4, 0, 4, 4, 0],
625
+ [2, 0, 2, 2, 0, 2],
626
+ [1, 1, 1, 1, 1, 1],
627
+ [2, 2, 2, 2, 2, 2],
628
+ [0, 0, 0, 0, 0, 0],
629
+ [0, 0, 0, 0, 0, 0],
630
+ [1, 1, 1, 1, 1, 1],
631
+ [6, 1, 1, 6, -1, 6],
632
+ [6, 1, 1, 6, 0, 6],
633
+ [6, 0, 0, 6, 0, 6],
634
+ [6, 2, 2, 6, 2, 6],
635
+ [4, 1, 1, 4, 4, 1],
636
+ [0, 1, 1, 0, 0, 1],
637
+ [4, 0, 0, 4, 4, 4],
638
+ [2, 2, 2, 2, 2, 2],
639
+ [6, 1, 1, 6, 4, 6],
640
+ [6, 1, 1, 6, 4, 6],
641
+ [6, 0, 0, 6, 0, 6],
642
+ [6, 2, 2, 6, 2, 6],
643
+ [5, 1, 1, 5, 5, 1],
644
+ [0, 1, 1, 0, 0, 1],
645
+ [0, 0, 0, 0, 0, 0],
646
+ [2, 2, 2, 2, 2, 2],
647
+ [1, 1, 1, 1, 1, 1],
648
+ [0, 0, 0, 0, 0, 0],
649
+ [0, 0, 0, 0, 0, 0],
650
+ [6, 6, 6, 6, 6, 6],
651
+ [1, 1, 1, 1, 1, 1],
652
+ [0, 0, 0, 0, 0, 0],
653
+ [0, 0, 0, 0, 0, 0],
654
+ [4, 4, 4, 4, 4, 4],
655
+ [1, 1, 1, 1, 4, 1],
656
+ [0, 4, 0, 4, 4, 0],
657
+ [0, 0, 0, 0, 0, 0],
658
+ [4, 4, 4, 4, 4, 4],
659
+ [1, 1, 1, 1, 1, 1],
660
+ [0, 0, 0, 0, 0, 0],
661
+ [0, 5, 0, 5, 0, 5],
662
+ [5, 5, 5, 5, 5, 5],
663
+ [5, 5, 5, 5, 5, 5],
664
+ [0, 5, 0, 5, 0, 5],
665
+ [-1, 5, 0, 5, 0, 5],
666
+ [1, 5, 1, 5, 1, 5],
667
+ [4, 5, -1, 5, 4, 5],
668
+ [0, 5, 0, 5, 0, 5],
669
+ [4, 5, 0, 5, 4, 5],
670
+ [1, 5, 1, 5, 1, 5],
671
+ [4, 4, 4, 4, 4, 4],
672
+ [0, 4, 0, 4, 4, 4],
673
+ [0, 0, 0, 0, 0, 0],
674
+ [1, 1, 1, 1, 1, 1],
675
+ [6, 6, 6, 6, 6, 6],
676
+ [0, 0, 0, 0, 0, 0],
677
+ [0, 0, 0, 0, 0, 0],
678
+ [1, 1, 1, 1, 1, 1],
679
+ [2, 5, 2, 5, -1, 5],
680
+ [0, 5, 0, 5, 0, 5],
681
+ [2, 5, 2, 5, 0, 5],
682
+ [1, 5, 1, 5, 1, 5],
683
+ [2, 5, 2, 5, 4, 5],
684
+ [0, 5, 0, 5, 0, 5],
685
+ [2, 5, 2, 5, 4, 5],
686
+ [1, 5, 1, 5, 1, 5],
687
+ [2, 4, 2, 4, 4, 2],
688
+ [0, 4, 0, 4, 4, 4],
689
+ [2, 0, 2, 0, 0, 2],
690
+ [1, 1, 1, 1, 1, 1],
691
+ [2, 6, 2, 6, 6, 2],
692
+ [0, 0, 0, 0, 0, 0],
693
+ [2, 0, 2, 0, 0, 2],
694
+ [1, 1, 1, 1, 1, 1],
695
+ [1, 1, 1, 1, 1, 1],
696
+ [0, 1, 1, 1, 0, 1],
697
+ [0, 0, 0, 0, 0, 0],
698
+ [2, 2, 2, 2, 2, 2],
699
+ [4, 1, 1, 1, 4, 1],
700
+ [0, 1, 1, 1, 0, 1],
701
+ [4, 0, 0, 4, 4, 0],
702
+ [2, 2, 2, 2, 2, 2],
703
+ [1, 1, 1, 1, 1, 1],
704
+ [0, 1, 1, 1, 1, 1],
705
+ [0, 0, 0, 0, 0, 0],
706
+ [2, 2, 2, 2, 2, 2],
707
+ [1, 1, 1, 1, 1, 1],
708
+ [0, 0, 0, 0, 0, 0],
709
+ [0, 0, 0, 0, 0, 0],
710
+ [2, 2, 2, 2, 2, 2],
711
+ [1, 1, 1, 1, 1, 1],
712
+ [0, 0, 0, 0, 0, 0],
713
+ [0, 0, 0, 0, 0, 0],
714
+ [5, 5, 5, 5, 5, 5],
715
+ [1, 1, 1, 1, 4, 1],
716
+ [0, 0, 0, 0, 0, 0],
717
+ [4, 0, 0, 4, 4, 0],
718
+ [4, 4, 4, 4, 4, 4],
719
+ [1, 1, 1, 1, 1, 1],
720
+ [0, 0, 0, 0, 0, 0],
721
+ [0, 0, 0, 0, 0, 0],
722
+ [4, 4, 4, 4, 4, 4],
723
+ [1, 1, 1, 1, 1, 1],
724
+ [6, 0, 0, 6, 0, 6],
725
+ [0, 0, 0, 0, 0, 0],
726
+ [6, 6, 6, 6, 6, 6],
727
+ [5, 5, 5, 5, 5, 5],
728
+ [5, 5, 0, 5, 0, 5],
729
+ [5, 5, 0, 5, 0, 5],
730
+ [5, 5, 1, 5, 1, 5],
731
+ [4, 4, 4, 4, 4, 4],
732
+ [0, 0, 0, 0, 0, 0],
733
+ [4, 4, 0, 4, 4, 4],
734
+ [1, 1, 1, 1, 1, 1],
735
+ [4, 4, 4, 4, 4, 4],
736
+ [4, 4, 0, 4, 4, 4],
737
+ [0, 0, 0, 0, 0, 0],
738
+ [1, 1, 1, 1, 1, 1],
739
+ [8, 8, 8, 8, 8, 8],
740
+ [0, 0, 0, 0, 0, 0],
741
+ [0, 0, 0, 0, 0, 0],
742
+ [1, 1, 1, 1, 1, 1],
743
+ [2, 2, 2, 2, 2, 2],
744
+ [0, 0, 0, 0, 0, 0],
745
+ [2, 2, 2, 2, 0, 2],
746
+ [1, 1, 1, 1, 1, 1],
747
+ [2, 2, 2, 2, 2, 2],
748
+ [0, 0, 0, 0, 0, 0],
749
+ [2, 2, 2, 2, 2, 2],
750
+ [1, 1, 1, 1, 1, 1],
751
+ [2, 2, 2, 2, 2, 2],
752
+ [0, 0, 0, 0, 0, 0],
753
+ [0, 0, 0, 0, 0, 0],
754
+ [4, 1, 1, 4, 4, 1],
755
+ [2, 2, 2, 2, 2, 2],
756
+ [0, 0, 0, 0, 0, 0],
757
+ [0, 0, 0, 0, 0, 0],
758
+ [1, 1, 1, 1, 1, 1],
759
+ [1, 1, 1, 1, 1, 1],
760
+ [1, 1, 1, 1, 0, 1],
761
+ [0, 0, 0, 0, 0, 0],
762
+ [2, 2, 2, 2, 2, 2],
763
+ [1, 1, 1, 1, 1, 1],
764
+ [0, 0, 0, 0, 0, 0],
765
+ [0, 0, 0, 0, 0, 0],
766
+ [2, 4, 2, 4, 4, 2],
767
+ [1, 1, 1, 1, 1, 1],
768
+ [1, 1, 1, 1, 1, 1],
769
+ [0, 0, 0, 0, 0, 0],
770
+ [2, 2, 2, 2, 2, 2],
771
+ [1, 1, 1, 1, 1, 1],
772
+ [0, 0, 0, 0, 0, 0],
773
+ [0, 0, 0, 0, 0, 0],
774
+ [2, 2, 2, 2, 2, 2],
775
+ [1, 1, 1, 1, 1, 1],
776
+ [0, 0, 0, 0, 0, 0],
777
+ [0, 0, 0, 0, 0, 0],
778
+ [5, 5, 5, 5, 5, 5],
779
+ [1, 1, 1, 1, 1, 1],
780
+ [0, 0, 0, 0, 0, 0],
781
+ [0, 0, 0, 0, 0, 0],
782
+ [4, 4, 4, 4, 4, 4],
783
+ [1, 1, 1, 1, 1, 1],
784
+ [0, 0, 0, 0, 0, 0],
785
+ [0, 0, 0, 0, 0, 0],
786
+ [4, 4, 4, 4, 4, 4],
787
+ [1, 1, 1, 1, 1, 1],
788
+ [0, 0, 0, 0, 0, 0],
789
+ [0, 0, 0, 0, 0, 0],
790
+ [12, 12, 12, 12, 12, 12]
791
+ ]
core/instant_utils/__init__.py ADDED
File without changes
core/instant_utils/camera_util.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+
5
+
6
+ def pad_camera_extrinsics_4x4(extrinsics):
7
+ if extrinsics.shape[-2] == 4:
8
+ return extrinsics
9
+ padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
10
+ if extrinsics.ndim == 3:
11
+ padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
12
+ extrinsics = torch.cat([extrinsics, padding], dim=-2)
13
+ return extrinsics
14
+
15
+
16
+ def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
17
+ """
18
+ Create OpenGL camera extrinsics from camera locations and look-at position.
19
+
20
+ camera_position: (M, 3) or (3,)
21
+ look_at: (3)
22
+ up_world: (3)
23
+ return: (M, 3, 4) or (3, 4)
24
+ """
25
+ # by default, looking at the origin and world up is z-axis
26
+ if look_at is None:
27
+ look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
28
+ if up_world is None:
29
+ up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
30
+ if camera_position.ndim == 2:
31
+ look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
32
+ up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
33
+
34
+ # OpenGL camera: z-backward, x-right, y-up
35
+ z_axis = camera_position - look_at
36
+ z_axis = F.normalize(z_axis, dim=-1).float()
37
+ x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
38
+ x_axis = F.normalize(x_axis, dim=-1).float()
39
+ y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
40
+ y_axis = F.normalize(y_axis, dim=-1).float()
41
+
42
+ extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
43
+ extrinsics = pad_camera_extrinsics_4x4(extrinsics)
44
+ return extrinsics
45
+
46
+
47
+ def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
48
+ azimuths = np.deg2rad(azimuths)
49
+ elevations = np.deg2rad(elevations)
50
+
51
+ xs = radius * np.cos(elevations) * np.cos(azimuths)
52
+ ys = radius * np.cos(elevations) * np.sin(azimuths)
53
+ zs = radius * np.sin(elevations)
54
+
55
+ cam_locations = np.stack([xs, ys, zs], axis=-1)
56
+ cam_locations = torch.from_numpy(cam_locations).float()
57
+
58
+ c2ws = center_looking_at_camera_pose(cam_locations)
59
+ return c2ws
60
+
61
+
62
+ def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0):
63
+ # M: number of circular views
64
+ # radius: camera dist to center
65
+ # elevation: elevation degrees of the camera
66
+ # return: (M, 4, 4)
67
+ assert M > 0 and radius > 0
68
+
69
+ elevation = np.deg2rad(elevation)
70
+
71
+ camera_positions = []
72
+ for i in range(M):
73
+ azimuth = 2 * np.pi * i / M
74
+ x = radius * np.cos(elevation) * np.cos(azimuth)
75
+ y = radius * np.cos(elevation) * np.sin(azimuth)
76
+ z = radius * np.sin(elevation)
77
+ camera_positions.append([x, y, z])
78
+ camera_positions = np.array(camera_positions)
79
+ camera_positions = torch.from_numpy(camera_positions).float()
80
+ extrinsics = center_looking_at_camera_pose(camera_positions)
81
+ return extrinsics
82
+
83
+
84
+ def FOV_to_intrinsics(fov, device='cpu'):
85
+ """
86
+ Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
87
+ Note the intrinsics are returned as normalized by image size, rather than in pixel units.
88
+ Assumes principal point is at image center.
89
+ """
90
+ focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5)
91
+ intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
92
+ return intrinsics
93
+
94
+
95
+ def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
96
+ """
97
+ Get the input camera parameters.
98
+ """
99
+ azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
100
+ elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
101
+
102
+ c2ws = spherical_camera_pose(azimuths, elevations, radius)
103
+ c2ws = c2ws.float().flatten(-2)
104
+
105
+ Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)
106
+
107
+ extrinsics = c2ws[:, :12]
108
+ intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1)
109
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
110
+
111
+ return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
core/instant_utils/infer_util.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import rembg
4
+ import torch
5
+ import numpy as np
6
+ import PIL.Image
7
+ from PIL import Image
8
+ from typing import Any
9
+
10
+
11
+ def remove_background(image: PIL.Image.Image,
12
+ rembg_session: Any = None,
13
+ force: bool = False,
14
+ **rembg_kwargs,
15
+ ) -> PIL.Image.Image:
16
+ do_remove = True
17
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
18
+ do_remove = False
19
+ do_remove = do_remove or force
20
+ if do_remove:
21
+ image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
22
+ return image
23
+
24
+
25
+ def resize_foreground(
26
+ image: PIL.Image.Image,
27
+ ratio: float,
28
+ ) -> PIL.Image.Image:
29
+ image = np.array(image)
30
+ assert image.shape[-1] == 4
31
+ alpha = np.where(image[..., 3] > 0)
32
+ y1, y2, x1, x2 = (
33
+ alpha[0].min(),
34
+ alpha[0].max(),
35
+ alpha[1].min(),
36
+ alpha[1].max(),
37
+ )
38
+ # crop the foreground
39
+ fg = image[y1:y2, x1:x2]
40
+ # pad to square
41
+ size = max(fg.shape[0], fg.shape[1])
42
+ ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
43
+ ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
44
+ new_image = np.pad(
45
+ fg,
46
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
47
+ mode="constant",
48
+ constant_values=((0, 0), (0, 0), (0, 0)),
49
+ )
50
+
51
+ # compute padding according to the ratio
52
+ new_size = int(new_image.shape[0] / ratio)
53
+ # pad to size, double side
54
+ ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
55
+ ph1, pw1 = new_size - size - ph0, new_size - size - pw0
56
+ new_image = np.pad(
57
+ new_image,
58
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
59
+ mode="constant",
60
+ constant_values=((0, 0), (0, 0), (0, 0)),
61
+ )
62
+ new_image = PIL.Image.fromarray(new_image)
63
+ return new_image
64
+
65
+
66
+ def images_to_video(
67
+ images: torch.Tensor,
68
+ output_path: str,
69
+ fps: int = 30,
70
+ ) -> None:
71
+ # images: (N, C, H, W)
72
+ video_dir = os.path.dirname(output_path)
73
+ video_name = os.path.basename(output_path)
74
+ os.makedirs(video_dir, exist_ok=True)
75
+
76
+ frames = []
77
+ for i in range(len(images)):
78
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
79
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
80
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
81
+ assert frame.min() >= 0 and frame.max() <= 255, \
82
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
83
+ frames.append(frame)
84
+ imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
85
+
86
+
87
+ def save_video(
88
+ frames: torch.Tensor,
89
+ output_path: str,
90
+ fps: int = 30,
91
+ ) -> None:
92
+ # images: (N, C, H, W)
93
+ frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames]
94
+ writer = imageio.get_writer(output_path, fps=fps)
95
+ for frame in frames:
96
+ writer.append_data(frame)
97
+ writer.close()
core/instant_utils/mesh_util.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import xatlas
11
+ import trimesh
12
+ import cv2
13
+ import numpy as np
14
+ import nvdiffrast.torch as dr
15
+ from PIL import Image
16
+
17
+
18
+ def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath):
19
+
20
+ pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
21
+ #facenp_fx3 = facenp_fx3[:, [2, 1, 0]]
22
+
23
+ mesh = trimesh.Trimesh(
24
+ vertices=pointnp_px3,
25
+ faces=facenp_fx3,
26
+ vertex_colors=colornp_px3,
27
+ )
28
+ mesh.export(fpath, 'obj')
29
+
30
+
31
+ def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath):
32
+
33
+ pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]])
34
+
35
+ mesh = trimesh.Trimesh(
36
+ vertices=pointnp_px3,
37
+ faces=facenp_fx3,
38
+ vertex_colors=colornp_px3,
39
+ )
40
+ mesh.export(fpath, 'glb')
41
+
42
+
43
+ def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname):
44
+ import os
45
+ fol, na = os.path.split(fname)
46
+ na, _ = os.path.splitext(na)
47
+
48
+ matname = '%s/%s.mtl' % (fol, na)
49
+ fid = open(matname, 'w')
50
+ fid.write('newmtl material_0\n')
51
+ fid.write('Kd 1 1 1\n')
52
+ fid.write('Ka 0 0 0\n')
53
+ fid.write('Ks 0.4 0.4 0.4\n')
54
+ fid.write('Ns 10\n')
55
+ fid.write('illum 2\n')
56
+ fid.write('map_Kd %s.png\n' % na)
57
+ fid.close()
58
+ ####
59
+
60
+ fid = open(fname, 'w')
61
+ fid.write('mtllib %s.mtl\n' % na)
62
+
63
+ for pidx, p in enumerate(pointnp_px3):
64
+ pp = p
65
+ fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2]))
66
+
67
+ for pidx, p in enumerate(tcoords_px2):
68
+ pp = p
69
+ fid.write('vt %f %f\n' % (pp[0], pp[1]))
70
+
71
+ fid.write('usemtl material_0\n')
72
+ for i, f in enumerate(facenp_fx3):
73
+ f1 = f + 1
74
+ f2 = facetex_fx3[i] + 1
75
+ fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
76
+ fid.close()
77
+
78
+ # save texture map
79
+ lo, hi = 0, 1
80
+ img = np.asarray(texmap_hxwx3, dtype=np.float32)
81
+ img = (img - lo) * (255 / (hi - lo))
82
+ img = img.clip(0, 255)
83
+ mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
84
+ mask = (mask <= 3.0).astype(np.float32)
85
+ kernel = np.ones((3, 3), 'uint8')
86
+ dilate_img = cv2.dilate(img, kernel, iterations=1)
87
+ img = img * (1 - mask) + dilate_img * mask
88
+ img = img.clip(0, 255).astype(np.uint8)
89
+ Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png')
90
+
91
+
92
+ def loadobj(meshfile):
93
+ v = []
94
+ f = []
95
+ meshfp = open(meshfile, 'r')
96
+ for line in meshfp.readlines():
97
+ data = line.strip().split(' ')
98
+ data = [da for da in data if len(da) > 0]
99
+ if len(data) != 4:
100
+ continue
101
+ if data[0] == 'v':
102
+ v.append([float(d) for d in data[1:]])
103
+ if data[0] == 'f':
104
+ data = [da.split('/')[0] for da in data]
105
+ f.append([int(d) for d in data[1:]])
106
+ meshfp.close()
107
+
108
+ # torch need int64
109
+ facenp_fx3 = np.array(f, dtype=np.int64) - 1
110
+ pointnp_px3 = np.array(v, dtype=np.float32)
111
+ return pointnp_px3, facenp_fx3
112
+
113
+
114
+ def loadobjtex(meshfile):
115
+ v = []
116
+ vt = []
117
+ f = []
118
+ ft = []
119
+ meshfp = open(meshfile, 'r')
120
+ for line in meshfp.readlines():
121
+ data = line.strip().split(' ')
122
+ data = [da for da in data if len(da) > 0]
123
+ if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)):
124
+ continue
125
+ if data[0] == 'v':
126
+ assert len(data) == 4
127
+
128
+ v.append([float(d) for d in data[1:]])
129
+ if data[0] == 'vt':
130
+ if len(data) == 3 or len(data) == 4:
131
+ vt.append([float(d) for d in data[1:3]])
132
+ if data[0] == 'f':
133
+ data = [da.split('/') for da in data]
134
+ if len(data) == 4:
135
+ f.append([int(d[0]) for d in data[1:]])
136
+ ft.append([int(d[1]) for d in data[1:]])
137
+ elif len(data) == 5:
138
+ idx1 = [1, 2, 3]
139
+ data1 = [data[i] for i in idx1]
140
+ f.append([int(d[0]) for d in data1])
141
+ ft.append([int(d[1]) for d in data1])
142
+ idx2 = [1, 3, 4]
143
+ data2 = [data[i] for i in idx2]
144
+ f.append([int(d[0]) for d in data2])
145
+ ft.append([int(d[1]) for d in data2])
146
+ meshfp.close()
147
+
148
+ # torch need int64
149
+ facenp_fx3 = np.array(f, dtype=np.int64) - 1
150
+ ftnp_fx3 = np.array(ft, dtype=np.int64) - 1
151
+ pointnp_px3 = np.array(v, dtype=np.float32)
152
+ uvs = np.array(vt, dtype=np.float32)
153
+ return pointnp_px3, facenp_fx3, uvs, ftnp_fx3
154
+
155
+
156
+ # ==============================================================================================
157
+ def interpolate(attr, rast, attr_idx, rast_db=None):
158
+ return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
159
+
160
+
161
+ def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
162
+ vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy())
163
+
164
+ # Convert to tensors
165
+ indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
166
+
167
+ uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
168
+ mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
169
+ # mesh_v_tex. ture
170
+ uv_clip = uvs[None, ...] * 2.0 - 1.0
171
+
172
+ # pad to four component coordinate
173
+ uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
174
+
175
+ # rasterize
176
+ rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
177
+
178
+ # Interpolate world space position
179
+ gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
180
+ mask = rast[..., 3:4] > 0
181
+ return uvs, mesh_tex_idx, gb_pos, mask
core/instant_utils/train_util.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+
4
+ def count_params(model, verbose=False):
5
+ total_params = sum(p.numel() for p in model.parameters())
6
+ if verbose:
7
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
8
+ return total_params
9
+
10
+
11
+ def instantiate_from_config(config):
12
+ if not "target" in config:
13
+ if config == '__is_first_stage__':
14
+ return None
15
+ elif config == "__is_unconditional__":
16
+ return None
17
+ raise KeyError("Expected key `target` to instantiate.")
18
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
19
+
20
+
21
+ def get_obj_from_str(string, reload=False):
22
+ module, cls = string.rsplit(".", 1)
23
+ if reload:
24
+ module_imp = importlib.import_module(module)
25
+ importlib.reload(module_imp)
26
+ return getattr(importlib.import_module(module, package=None), cls)
core/lrm_reconstructor.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import numpy as np
6
+ from typing import Tuple, Literal
7
+ from functools import partial
8
+
9
+ import itertools
10
+
11
+
12
+ # LRM
13
+ from .embedder import CameraEmbedder
14
+ from .transformer import TransformerDecoder
15
+ # from accelerate.logging import get_logger
16
+
17
+ # logger = get_logger(__name__)
18
+
19
+
20
+ class LRM_VSD_Mesh_Net(nn.Module):
21
+ """
22
+ predict VSD using transformer
23
+ """
24
+ def __init__(self, camera_embed_dim: int,
25
+ transformer_dim: int, transformer_layers: int, transformer_heads: int,
26
+ triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
27
+ encoder_freeze: bool = True, encoder_type: str = 'dino',
28
+ encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768, app_dim = 27, density_dim = 8, app_n_comp=24,
29
+ density_n_comp=8):
30
+ super().__init__()
31
+
32
+ # attributes
33
+ self.encoder_feat_dim = encoder_feat_dim
34
+ self.camera_embed_dim = camera_embed_dim
35
+ self.triplane_low_res = triplane_low_res
36
+ self.triplane_high_res = triplane_high_res
37
+ self.triplane_dim = triplane_dim
38
+ self.transformer_dim=transformer_dim
39
+
40
+ # modules
41
+ self.encoder = self._encoder_fn(encoder_type)(
42
+ model_name=encoder_model_name,
43
+ modulation_dim=self.camera_embed_dim, #mod camera vector
44
+ freeze=encoder_freeze,
45
+ )
46
+ self.camera_embedder = CameraEmbedder(
47
+ raw_dim=12+4, embed_dim=camera_embed_dim,
48
+ )
49
+
50
+ self.n_comp=app_n_comp+density_n_comp
51
+ self.app_dim=app_dim
52
+ self.density_dim=density_dim
53
+ self.app_n_comp=app_n_comp
54
+ self.density_n_comp=density_n_comp
55
+
56
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*(triplane_low_res**2)+3*triplane_low_res, transformer_dim) * (1. / transformer_dim) ** 0.5)
57
+ self.transformer = TransformerDecoder(
58
+ block_type='cond',
59
+ num_layers=transformer_layers, num_heads=transformer_heads,
60
+ inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=None,
61
+ )
62
+ # for plane
63
+ self.upsampler = nn.ConvTranspose2d(transformer_dim, self.n_comp, kernel_size=2, stride=2, padding=0)
64
+ self.dim_map = nn.Linear(transformer_dim,self.n_comp)
65
+ self.up_line = nn.Linear(triplane_low_res,triplane_low_res*2)
66
+
67
+
68
+ @staticmethod
69
+ def _encoder_fn(encoder_type: str):
70
+ encoder_type = encoder_type.lower()
71
+ assert encoder_type in ['dino', 'dinov2'], "Unsupported encoder type"
72
+ if encoder_type == 'dino':
73
+ from .encoders.dino_wrapper import DinoWrapper
74
+ #logger.info("Using DINO as the encoder")
75
+ return DinoWrapper
76
+ elif encoder_type == 'dinov2':
77
+ from .encoders.dinov2_wrapper import Dinov2Wrapper
78
+ #logger.info("Using DINOv2 as the encoder")
79
+ return Dinov2Wrapper
80
+
81
+ def forward_transformer(self, image_feats, camera_embeddings=None):
82
+ N = image_feats.shape[0]
83
+ x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
84
+ x = self.transformer(
85
+ x,
86
+ cond=image_feats,
87
+ mod=camera_embeddings,
88
+ )
89
+ return x
90
+ def reshape_upsample(self, tokens):
91
+ #B,_,3*ncomp
92
+ N = tokens.shape[0]
93
+ H = W = self.triplane_low_res
94
+ P=self.n_comp
95
+
96
+ offset=3*H*W
97
+
98
+ # planes
99
+ plane_tokens= tokens[:,:3*H*W,:].view(N,H,W,3,self.transformer_dim)
100
+ plane_tokens = torch.einsum('nhwip->inphw', plane_tokens) # [3, N, P, H, W]
101
+ plane_tokens = plane_tokens.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
102
+ plane_tokens = self.upsampler(plane_tokens) # [3*N, P, H', W']
103
+ plane_tokens = plane_tokens.view(3, N, *plane_tokens.shape[-3:]) # [3, N, P, H', W']
104
+ plane_tokens = torch.einsum('inphw->niphw', plane_tokens) # [N, 3, P, H', W']
105
+ plane_tokens = plane_tokens.reshape(N, 3*P, *plane_tokens.shape[-2:]) # # [N, 3*P, H', W']
106
+ plane_tokens = plane_tokens.contiguous()
107
+
108
+ #lines
109
+ line_tokens= tokens[:,3*H*W:3*H*W+3*H,:].view(N,H,3,self.transformer_dim)
110
+ line_tokens= self.dim_map(line_tokens)
111
+ line_tokens = torch.einsum('nhip->npih', line_tokens) # [ N, P, 3, H]
112
+ line_tokens=self.up_line(line_tokens)
113
+ line_tokens = torch.einsum('npih->niph', line_tokens) # [ N, 3, P, H]
114
+ line_tokens=line_tokens.reshape(N,3*P,line_tokens.shape[-1],1)
115
+ line_tokens = line_tokens.contiguous()
116
+
117
+ mat_tokens=None
118
+
119
+ d_mat_tokens=None
120
+
121
+ return plane_tokens[:,:self.app_n_comp*3,:,:],line_tokens[:,:self.app_n_comp*3,:,:],mat_tokens,d_mat_tokens,plane_tokens[:,self.app_n_comp*3:,:,:],line_tokens[:,self.app_n_comp*3:,:,:]
122
+
123
+ def forward_planes(self, image, camera):
124
+ # image: [N, V, C_img, H_img, W_img]
125
+ # camera: [N,V, D_cam_raw]
126
+ N,V,_,H,W = image.shape
127
+ image=image.reshape(N*V,3,H,W)
128
+ camera=camera.reshape(N*V,-1)
129
+
130
+
131
+ # embed camera
132
+ camera_embeddings = self.camera_embedder(camera)
133
+ assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
134
+ f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
135
+
136
+ # encode image
137
+ image_feats = self.encoder(image, camera_embeddings)
138
+ assert image_feats.shape[-1] == self.encoder_feat_dim, \
139
+ f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"
140
+
141
+ image_feats=image_feats.reshape(N,V*image_feats.shape[-2],image_feats.shape[-1])
142
+
143
+ # transformer generating planes
144
+ tokens = self.forward_transformer(image_feats)
145
+
146
+ app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines = self.reshape_upsample(tokens)
147
+
148
+ return app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines
149
+
150
+ def forward(self, image,source_camera):
151
+ # image: [N,V, C_img, H_img, W_img]
152
+ # source_camera: [N, V, D_cam_raw]
153
+
154
+ assert image.shape[0] == source_camera.shape[0], "Batch size mismatch for image and source_camera"
155
+ planes = self.forward_planes(image, source_camera)
156
+
157
+ #B,3,dim,H,W
158
+ return planes
core/models.py ADDED
@@ -0,0 +1,783 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import mcubes
6
+
7
+ import kiui
8
+ from kiui.lpips import LPIPS
9
+
10
+ from core.lrm_reconstructor import LRM_VSD_Mesh_Net
11
+ from core.options import Options
12
+ from core.tensoRF import TensorVMSplit_Mesh,TensorVMSplit_NeRF
13
+ from torchvision.transforms import v2
14
+ from core.geometry.camera.perspective_camera import PerspectiveCamera
15
+ from core.geometry.render.neural_render import NeuralRender
16
+ from core.geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry
17
+ import nvdiffrast.torch as dr
18
+ from core.instant_utils.mesh_util import xatlas_uvmap
19
+
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+
22
+ #tensorSDF + transformer + volume_rendering
23
+ class LTRFM_NeRF(nn.Module):
24
+ def __init__(
25
+ self,
26
+ opt: Options,
27
+ ):
28
+ super().__init__()
29
+
30
+ self.opt = opt
31
+
32
+
33
+ #predict svd using transformer
34
+ self.vsd_net = LRM_VSD_Mesh_Net(
35
+ camera_embed_dim=opt.camera_embed_dim,
36
+ transformer_dim=opt.transformer_dim,
37
+ transformer_layers=opt.transformer_layers,
38
+ transformer_heads=opt.transformer_heads,
39
+ triplane_low_res=opt.triplane_low_res,
40
+ triplane_high_res=opt.triplane_high_res,
41
+ triplane_dim=opt.triplane_dim,
42
+ encoder_freeze=opt.encoder_freeze,
43
+ encoder_type=opt.encoder_type,
44
+ encoder_model_name=opt.encoder_model_name,
45
+ encoder_feat_dim=opt.encoder_feat_dim,
46
+ app_dim=opt.app_dim,
47
+ density_dim=opt.density_dim,
48
+ app_n_comp=opt.app_n_comp,
49
+ density_n_comp=opt.density_n_comp,
50
+ )
51
+
52
+ aabb = torch.tensor([[-1, -1, -1], [1, 1, 1]]).to(device)
53
+ grid_size = torch.tensor([opt.splat_size, opt.splat_size, opt.splat_size]).to(device)
54
+ near_far =torch.tensor([opt.znear, opt.zfar]).to(device)
55
+ # tensorf Renderer
56
+ self.tensorRF = TensorVMSplit_NeRF(aabb, grid_size, density_n_comp=opt.density_n_comp,appearance_n_comp=opt.app_n_comp,app_dim=opt.app_dim,\
57
+ density_dim=opt.density_dim,near_far=near_far, shadingMode=opt.shadingMode, pos_pe=opt.pos_pe, view_pe=opt.view_pe, fea_pe=opt.fea_pe)
58
+
59
+ # LPIPS loss
60
+ if self.opt.lambda_lpips > 0:
61
+ self.lpips_loss = LPIPS(net='vgg')
62
+ self.lpips_loss.requires_grad_(False)
63
+
64
+
65
+ def state_dict(self, **kwargs):
66
+ # remove lpips_loss
67
+ state_dict = super().state_dict(**kwargs)
68
+ for k in list(state_dict.keys()):
69
+ if 'lpips_loss' in k:
70
+ del state_dict[k]
71
+ return state_dict
72
+
73
+ def set_beta(self,t):
74
+ self.tensorRF.lap_density.set_beta(t)
75
+
76
+
77
+
78
+ # predict svd_volume
79
+ def forward_svd_volume(self, images, data):
80
+ # images: [B, 4, 9, H, W]
81
+ # return: Gaussians: [B, dim_t]
82
+ B, V, C, H, W = images.shape
83
+
84
+
85
+ source_camera=data['source_camera']
86
+ images_vit=data['input_vit'] # for transformer
87
+ source_camera=source_camera.reshape(B,V,-1) # [B*V, 16]
88
+ app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines = self.vsd_net(images_vit,source_camera)
89
+
90
+
91
+ app_planes=app_planes.view(B,3,self.opt.app_n_comp,self.opt.splat_size,self.opt.splat_size)
92
+ app_lines=app_lines.view(B,3,self.opt.app_n_comp,self.opt.splat_size,1)
93
+ density_planes=density_planes.view(B,3,self.opt.density_n_comp,self.opt.splat_size,self.opt.splat_size)
94
+ density_lines=density_lines.view(B,3,self.opt.density_n_comp,self.opt.splat_size,1)
95
+
96
+ results = {
97
+ 'app_planes': app_planes,
98
+ 'app_lines': app_lines,
99
+ 'basis_mat':basis_mat,
100
+ 'd_basis_mat':d_basis_mat,
101
+ 'density_planes':density_planes,
102
+ 'density_lines':density_lines
103
+ }
104
+
105
+ return results
106
+
107
+ def extract_mesh(self,
108
+ planes: torch.Tensor,
109
+ mesh_resolution: int = 256,
110
+ mesh_threshold: int = 0.009,
111
+ use_texture_map: bool = False,
112
+ texture_resolution: int = 1024,):
113
+
114
+ device = planes['app_planes'].device
115
+
116
+ grid_size = mesh_resolution
117
+ points = torch.linspace(-1, 1, steps=grid_size).half()
118
+
119
+ x, y, z = torch.meshgrid(points, points, points)
120
+
121
+ xyz_samples = torch.stack((x, y, z), dim=0).unsqueeze(0).to(device)
122
+ xyz_samples=xyz_samples.permute(0,2,3,4,1)
123
+ xyz_samples=xyz_samples.view(1,-1,1,3)
124
+
125
+
126
+ grid_out = self.tensorRF.predict_sdf(planes,xyz_samples)
127
+ grid_out['sigma']=grid_out['sigma'].view(grid_size,grid_size,grid_size).float()
128
+
129
+ vertices, faces = mcubes.marching_cubes(
130
+ grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
131
+ mesh_threshold,
132
+ )
133
+ vertices = vertices / (mesh_resolution - 1) * 2 - 1
134
+
135
+ if not use_texture_map:
136
+ # query vertex colors
137
+ vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)
138
+ rgb_colors = self.tensorRF.predict_color(
139
+ planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy()
140
+ rgb_colors = (rgb_colors * 255).astype(np.uint8)
141
+
142
+ albedob_colors = self.tensorRF.predict_color(
143
+ planes, vertices_tensor)['albedo'].squeeze(0).cpu().numpy()
144
+ albedob_colors = (albedob_colors * 255).astype(np.uint8)
145
+
146
+ shading_colors = self.tensorRF.predict_color(
147
+ planes, vertices_tensor)['shading'].squeeze(0).cpu().numpy()
148
+ shading_colors = (shading_colors * 255).astype(np.uint8)
149
+
150
+ return vertices, faces, [rgb_colors,albedob_colors,shading_colors]
151
+
152
+ # use x-atlas to get uv mapping for the mesh
153
+ vertices = torch.tensor(vertices, dtype=torch.float32, device=device)
154
+ faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device)
155
+
156
+ ctx = dr.RasterizeCudaContext(device=device)
157
+ uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
158
+ ctx, vertices, faces, resolution=texture_resolution)
159
+ tex_hard_mask = tex_hard_mask.float().cpu()
160
+
161
+ # query the texture field to get the RGB color for texture map
162
+ #TBD here
163
+ query_vertices=gb_pos.view(1,texture_resolution*texture_resolution,3)
164
+
165
+ vertices_colors = self.tensorRF.predict_color(
166
+ planes, query_vertices)['rgb'].squeeze(0).cpu()
167
+
168
+ vertices_colors=vertices_colors.reshape(1,texture_resolution,texture_resolution,3)
169
+
170
+ background_feature = torch.zeros_like(vertices_colors)
171
+ img_feat = torch.lerp(background_feature, vertices_colors, tex_hard_mask.half())
172
+ texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
173
+ #albedo
174
+ vertices_colors_albedo = self.tensorRF.predict_color(
175
+ planes, query_vertices)['albedo'].squeeze(0).cpu()
176
+
177
+ vertices_colors_albedo=vertices_colors_albedo.reshape(1,texture_resolution,texture_resolution,3)
178
+
179
+ background_feature = torch.zeros_like(vertices_colors_albedo)
180
+ img_feat = torch.lerp(background_feature, vertices_colors_albedo, tex_hard_mask.half())
181
+ texture_map_albedo = img_feat.permute(0, 3, 1, 2).squeeze(0)
182
+
183
+ return vertices, faces, uvs, mesh_tex_idx, [texture_map,texture_map_albedo]
184
+
185
+
186
+ def forward(self, data, step_ratio=1):
187
+ # data: output of the dataloader
188
+ # return: loss
189
+ #self.set_beta(data['t'])
190
+ results = {}
191
+ loss = 0
192
+
193
+ images = data['input'] # [B, 4, 9, h, W], input features
194
+
195
+ # use the first view to predict gaussians
196
+ svd_volume = self.forward_svd_volume(images,data) # [B, N, 14]
197
+
198
+ results['svd_volume'] = svd_volume
199
+
200
+ # always use white bg
201
+ bg_color = torch.ones(3, dtype=torch.float32).to(device)
202
+
203
+ # use the other views for rendering and supervision
204
+ results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'],is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample)
205
+ pred_shading = results['image'] # [B, V, C, output_size, output_size]
206
+ pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size]
207
+ pred_albedos = results['albedo'] # [B, V, C, output_size, output_size]
208
+
209
+ pred_images = pred_shading*pred_albedos
210
+
211
+ results['images_pred'] = pred_images
212
+ results['alphas_pred'] = pred_alphas
213
+ results['pred_albedos'] = pred_albedos
214
+ results['pred_shading'] = pred_shading
215
+
216
+
217
+ gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views
218
+ gt_albedos = data['albedos_output'] # [B, V, 3, output_size, output_size], ground-truth novel views
219
+ gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks
220
+
221
+ gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
222
+ gt_albedos = gt_albedos * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
223
+
224
+ loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) + F.mse_loss(pred_albedos, gt_albedos)
225
+ loss = loss + loss_mse
226
+
227
+ # eikonal_loss = ((results['eik_grads'].norm(2, dim=1) - 1) ** 2).mean()
228
+ # loss = loss+ 0.1*eikonal_loss
229
+
230
+ if self.opt.lambda_lpips > 0:
231
+ loss_lpips = self.lpips_loss(
232
+ F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
233
+ F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
234
+ ).mean()
235
+ results['loss_lpips'] = loss_lpips
236
+ loss = loss + self.opt.lambda_lpips * loss_lpips
237
+
238
+ results['loss'] = loss
239
+
240
+ # metric
241
+ with torch.no_grad():
242
+ psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
243
+ results['psnr'] = psnr
244
+
245
+ return results
246
+
247
+
248
+ def render_frame(self, data):
249
+ # data: output of the dataloader
250
+ # return: loss
251
+ #self.set_beta(data['t'])
252
+ results = {}
253
+ loss = 0
254
+
255
+ images = data['input_vit']
256
+
257
+ # use the first view to predict gaussians
258
+ svd_volume = self.forward_svd_volume(images,data) # [B, N, 14]
259
+
260
+ results['svd_volume'] = svd_volume
261
+
262
+ # always use white bg
263
+ bg_color = torch.ones(3, dtype=torch.float32).to(device)
264
+
265
+ # use the other views for rendering and supervision
266
+ results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'],is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample)
267
+ pred_shading = results['image'] # [B, V, C, output_size, output_size]
268
+ pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size]
269
+ pred_albedos = results['albedo'] # [B, V, C, output_size, output_size]
270
+
271
+ pred_images = pred_shading*pred_albedos
272
+
273
+ results['images_pred'] = pred_images
274
+ results['alphas_pred'] = pred_alphas
275
+ results['pred_albedos'] = pred_albedos
276
+ results['pred_shading'] = pred_shading
277
+
278
+
279
+ return results
280
+
281
+
282
+
283
+
284
+
285
+ #tensorSDF + transformer + SDF + Mesh
286
+ class LTRFM_Mesh(nn.Module):
287
+ def __init__(
288
+ self,
289
+ opt: Options,
290
+ ):
291
+ super().__init__()
292
+
293
+ self.opt = opt
294
+
295
+ # attributes
296
+ self.grid_res = 128 #grid_res
297
+ self.grid_scale = 2.0 #grid_scale
298
+ self.deformation_multiplier = 4.0
299
+
300
+
301
+ self.init_flexicubes_geometry(device, self.opt)
302
+
303
+ #predict svd using transformer
304
+ self.vsd_net = LRM_VSD_Mesh_Net(
305
+ camera_embed_dim=opt.camera_embed_dim,
306
+ transformer_dim=opt.transformer_dim,
307
+ transformer_layers=opt.transformer_layers,
308
+ transformer_heads=opt.transformer_heads,
309
+ triplane_low_res=opt.triplane_low_res,
310
+ triplane_high_res=opt.triplane_high_res,
311
+ triplane_dim=opt.triplane_dim,
312
+ encoder_freeze=opt.encoder_freeze,
313
+ encoder_type=opt.encoder_type,
314
+ encoder_model_name=opt.encoder_model_name,
315
+ encoder_feat_dim=opt.encoder_feat_dim,
316
+ app_dim=opt.app_dim,
317
+ density_dim=opt.density_dim,
318
+ app_n_comp=opt.app_n_comp,
319
+ density_n_comp=opt.density_n_comp,
320
+ )
321
+
322
+ aabb = torch.tensor([[-1, -1, -1], [1, 1, 1]]).to(device)
323
+ grid_size = torch.tensor([opt.splat_size, opt.splat_size, opt.splat_size]).to(device)
324
+ near_far =torch.tensor([opt.znear, opt.zfar]).to(device)
325
+ # tensorf Renderer
326
+ self.tensorRF = TensorVMSplit_Mesh(aabb, grid_size, density_n_comp=opt.density_n_comp,appearance_n_comp=opt.app_n_comp,app_dim=opt.app_dim,\
327
+ density_dim=opt.density_dim, near_far=near_far, shadingMode=opt.shadingMode, pos_pe=opt.pos_pe, view_pe=opt.view_pe, fea_pe=opt.fea_pe)
328
+
329
+ # LPIPS loss
330
+ if self.opt.lambda_lpips > 0:
331
+ self.lpips_loss = LPIPS(net='vgg')
332
+ self.lpips_loss.requires_grad_(False)
333
+
334
+
335
+ # load ckpt
336
+ if opt.ckpt_nerf is not None:
337
+ sd = torch.load(opt.ckpt_nerf, map_location='cpu')['model']
338
+ #sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
339
+ sd_fc = {}
340
+ for k, v in sd.items():
341
+ k=k.replace('module.', '')
342
+ if k.startswith('vsd.renderModule.'):
343
+ continue
344
+ else:
345
+ sd_fc[k] = v
346
+ sd_fc = {k.replace('vsd_net.', ''): v for k, v in sd_fc.items()}
347
+ sd_fc = {k.replace('tensorRF.', ''): v for k, v in sd_fc.items()}
348
+ # missing `net_deformation` and `net_weight` parameters
349
+ self.vsd_net.load_state_dict(sd_fc, strict=False)
350
+ self.tensorRF.load_state_dict(sd_fc, strict=False)
351
+ print(f'Loaded weights from {opt.ckpt_nerf}')
352
+
353
+
354
+ def state_dict(self, **kwargs):
355
+ # remove lpips_loss
356
+ state_dict = super().state_dict(**kwargs)
357
+ for k in list(state_dict.keys()):
358
+ if 'lpips_loss' in k:
359
+ del state_dict[k]
360
+ return state_dict
361
+
362
+
363
+ # predict svd_volume
364
+ def forward_svd_volume(self, images, data):
365
+ # images: [B, 4, 9, H, W]
366
+ # return: Gaussians: [B, dim_t]
367
+ B, V, C, H, W = images.shape
368
+
369
+ source_camera=data['source_camera']
370
+ images_vit=data['input_vit'] # for transformer
371
+ source_camera=source_camera.reshape(B,V,-1) # [B*V, 16]
372
+ app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines = self.vsd_net(images_vit,source_camera)
373
+
374
+
375
+ app_planes=app_planes.view(B,3,self.opt.app_n_comp,self.opt.splat_size,self.opt.splat_size)
376
+ app_lines=app_lines.view(B,3,self.opt.app_n_comp,self.opt.splat_size,1)
377
+ density_planes=density_planes.view(B,3,self.opt.density_n_comp,self.opt.splat_size,self.opt.splat_size)
378
+ density_lines=density_lines.view(B,3,self.opt.density_n_comp,self.opt.splat_size,1)
379
+
380
+ results = {
381
+ 'app_planes': app_planes,
382
+ 'app_lines': app_lines,
383
+ 'basis_mat':basis_mat,
384
+ 'd_basis_mat':d_basis_mat,
385
+ 'density_planes':density_planes,
386
+ 'density_lines':density_lines
387
+ }
388
+
389
+ return results
390
+
391
+
392
+ def init_flexicubes_geometry(self, device, opt):
393
+ camera = PerspectiveCamera(opt, device=device)
394
+ renderer = NeuralRender(device, camera_model=camera)
395
+ self.geometry = FlexiCubesGeometry(
396
+ grid_res=self.grid_res,
397
+ scale=self.grid_scale,
398
+ renderer=renderer,
399
+ render_type='neural_render',
400
+ device=device,
401
+ )
402
+
403
+
404
+ # query vsd for sdf weight and ...
405
+ def get_sdf_deformation_prediction(self, planes):
406
+ '''
407
+ Predict SDF and deformation for tetrahedron vertices
408
+ :param planes: triplane feature map for the geometry
409
+ '''
410
+ B = planes['app_lines'].shape[0]
411
+ init_position = self.geometry.verts.unsqueeze(0).expand(B, -1, -1)
412
+
413
+
414
+ sdf, deformation, weight = self.tensorRF.get_geometry_prediction(planes,init_position,self.geometry.indices)
415
+
416
+ deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation)
417
+ sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)
418
+
419
+ sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1))
420
+ sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
421
+ pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1)
422
+ neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1)
423
+ zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
424
+ if torch.sum(zero_surface).item() > 0:
425
+ update_sdf = torch.zeros_like(sdf[0:1])
426
+ max_sdf = sdf.max()
427
+ min_sdf = sdf.min()
428
+ update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero
429
+ update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero
430
+ new_sdf = torch.zeros_like(sdf)
431
+ for i_batch in range(zero_surface.shape[0]):
432
+ if zero_surface[i_batch]:
433
+ new_sdf[i_batch:i_batch + 1] += update_sdf
434
+ update_mask = (new_sdf == 0).float()
435
+ # Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative)
436
+ sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
437
+ sdf_reg_loss = sdf_reg_loss * zero_surface.float()
438
+ sdf = sdf * update_mask + new_sdf * (1 - update_mask)
439
+
440
+ final_sdf = []
441
+ final_def = []
442
+ for i_batch in range(zero_surface.shape[0]):
443
+ if zero_surface[i_batch]:
444
+ final_sdf.append(sdf[i_batch: i_batch + 1].detach())
445
+ final_def.append(deformation[i_batch: i_batch + 1].detach())
446
+ else:
447
+ final_sdf.append(sdf[i_batch: i_batch + 1])
448
+ final_def.append(deformation[i_batch: i_batch + 1])
449
+ sdf = torch.cat(final_sdf, dim=0)
450
+ deformation = torch.cat(final_def, dim=0)
451
+ return sdf, deformation, sdf_reg_loss, weight
452
+
453
+ def get_geometry_prediction(self, planes=None):
454
+ '''
455
+ Function to generate mesh with give triplanes
456
+ :param planes: triplane features
457
+ '''
458
+
459
+ sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes)
460
+
461
+
462
+ v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation
463
+ tets = self.geometry.indices
464
+ n_batch = planes['app_planes'].shape[0]
465
+ v_list = []
466
+ f_list = []
467
+ flexicubes_surface_reg_list = []
468
+
469
+
470
+ for i_batch in range(n_batch):
471
+ verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
472
+ v_deformed[i_batch],
473
+ sdf[i_batch].squeeze(dim=-1),
474
+ with_uv=False,
475
+ indices=tets,
476
+ weight_n=weight[i_batch].squeeze(dim=-1),
477
+ is_training=self.training,
478
+ )
479
+ flexicubes_surface_reg_list.append(flexicubes_surface_reg)
480
+ v_list.append(verts)
481
+ f_list.append(faces)
482
+
483
+ flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean()
484
+ flexicubes_weight_reg = (weight ** 2).mean()
485
+
486
+ return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
487
+
488
+ def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
489
+ '''
490
+ Predict Texture given triplanes
491
+ :param planes: the triplane feature map
492
+ :param tex_pos: Position we want to query the texture field
493
+ :param hard_mask: 2D silhoueete of the rendered image
494
+ '''
495
+ B = planes['app_planes'].shape[0]
496
+ tex_pos = torch.cat(tex_pos, dim=0)
497
+ if not hard_mask is None:
498
+ tex_pos = tex_pos * hard_mask.float()
499
+ batch_size = tex_pos.shape[0]
500
+ tex_pos = tex_pos.reshape(batch_size, -1, 3)
501
+ ###################
502
+ # We use mask to get the texture location (to save the memory)
503
+ if hard_mask is not None:
504
+ n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
505
+ sample_tex_pose_list = []
506
+ max_point = n_point_list.max()
507
+ if max_point==0: # xrg: hard mask may filter all points, and don not left any point
508
+ max_point=max_point+1
509
+ expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
510
+ for i in range(tex_pos.shape[0]):
511
+ tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
512
+ if tex_pos_one_shape.shape[1] < max_point:
513
+ tex_pos_one_shape = torch.cat(
514
+ [tex_pos_one_shape, torch.zeros(
515
+ 1, max_point - tex_pos_one_shape.shape[1], 3,
516
+ device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
517
+ sample_tex_pose_list.append(tex_pos_one_shape)
518
+ tex_pos = torch.cat(sample_tex_pose_list, dim=0)
519
+
520
+
521
+ #return texture rgb
522
+ tex_feat = self.tensorRF.get_texture_prediction(tex_pos,vsd_vome=planes)
523
+
524
+ if hard_mask is not None:
525
+ final_tex_feat = torch.zeros(
526
+ B, hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
527
+ expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
528
+ for i in range(B):
529
+ final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
530
+ tex_feat = final_tex_feat
531
+
532
+ return tex_feat.reshape(B, hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
533
+
534
+ def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
535
+ '''
536
+ Function to render a generated mesh with nvdiffrast
537
+ :param mesh_v: List of vertices for the mesh
538
+ :param mesh_f: List of faces for the mesh
539
+ :param cam_mv: 4x4 rotation matrix
540
+ :return:
541
+ '''
542
+ return_value_list = []
543
+ for i_mesh in range(len(mesh_v)):
544
+ return_value = self.geometry.render_mesh(
545
+ mesh_v[i_mesh],
546
+ mesh_f[i_mesh].int(),
547
+ cam_mv[i_mesh],
548
+ resolution=render_size,
549
+ hierarchical_mask=False
550
+ )
551
+ return_value_list.append(return_value)
552
+
553
+ return_keys = return_value_list[0].keys()
554
+ return_value = dict()
555
+ for k in return_keys:
556
+ value = [v[k] for v in return_value_list]
557
+ return_value[k] = value
558
+
559
+ mask = torch.cat(return_value['mask'], dim=0)
560
+ hard_mask = torch.cat(return_value['hard_mask'], dim=0)
561
+ tex_pos = return_value['tex_pos']
562
+ depth = torch.cat(return_value['depth'], dim=0)
563
+ normal = torch.cat(return_value['normal'], dim=0)
564
+ return mask, hard_mask, tex_pos, depth, normal
565
+
566
+ def forward_geometry(self, planes, render_cameras, render_size=256):
567
+ '''
568
+ Main function of our Generator. It first generate 3D mesh, then render it into 2D image
569
+ with given `render_cameras`.
570
+ :param planes: triplane features
571
+ :param render_cameras: cameras to render generated 3D shape, a w2c matrix
572
+ '''
573
+ B, NV = render_cameras.shape[:2]
574
+
575
+ # Generate 3D mesh first
576
+ mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
577
+
578
+ # Render the mesh into 2D image (get 3d position of each image plane) continue for here
579
+ cam_mv = render_cameras
580
+ run_n_view = cam_mv.shape[1]
581
+ antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size)
582
+
583
+ tex_hard_mask = hard_mask
584
+ tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos]
585
+ tex_hard_mask = torch.cat(
586
+ [torch.cat(
587
+ [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1]
588
+ for i_view in range(run_n_view)], dim=2)
589
+ for i in range(B)], dim=0)
590
+
591
+ # Querying the texture field to predict the texture feature for each pixel on the image
592
+ tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
593
+ background_feature = torch.ones_like(tex_feat) # white background
594
+
595
+ # Merge them together
596
+ img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)
597
+
598
+ # We should split it back to the original image shape
599
+ img_feat = torch.cat(
600
+ [torch.cat(
601
+ [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)]
602
+ for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0)
603
+
604
+ img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV))
605
+
606
+ albedo=img[:,:,3:6,:,:]
607
+ img=img[:,:,0:3,:,:]
608
+
609
+ antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV))
610
+ depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive
611
+ normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV))
612
+
613
+ out = {
614
+ 'image': img,
615
+ 'albedo': albedo,
616
+ 'mask': antilias_mask,
617
+ 'depth': depth,
618
+ 'normal': normal,
619
+ 'sdf': sdf,
620
+ 'mesh_v': mesh_v,
621
+ 'mesh_f': mesh_f,
622
+ 'sdf_reg_loss': sdf_reg_loss,
623
+ }
624
+ return out
625
+
626
+ def forward(self, data, step_ratio=1):
627
+ # data: output of the dataloader
628
+ # return: loss
629
+
630
+ results = {}
631
+ loss = 0
632
+
633
+ images = data['input'] # [B, 4, 9, h, W], input features
634
+
635
+ # use the first view to predict gaussians
636
+ svd_volume = self.forward_svd_volume(images,data) # [B, N, 14]
637
+
638
+ results['svd_volume'] = svd_volume
639
+
640
+ # return the rendered images
641
+ results = self.forward_geometry(svd_volume, data['w2c'], self.opt.output_size)
642
+
643
+
644
+ # always use white bg
645
+ bg_color = torch.ones(3, dtype=torch.float32).to(device)
646
+
647
+ # use the other views for rendering and supervision
648
+ #results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'],is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample)
649
+
650
+
651
+ pred_shading = results['image'] # [B, V, C, output_size, output_size]
652
+ pred_alphas = results['mask'] # [B, V, 1, output_size, output_size]
653
+ pred_albedos = results['albedo'] # [B, V, C, output_size, output_size]
654
+
655
+ pred_images=pred_shading*pred_albedos
656
+
657
+ results['images_pred'] = pred_images
658
+ results['alphas_pred'] = pred_alphas
659
+ results['pred_albedos'] = pred_albedos
660
+ results['pred_shading'] = pred_shading
661
+
662
+
663
+ gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views
664
+ gt_albedos = data['albedos_output'] # [B, V, 3, output_size, output_size], ground-truth novel views
665
+ gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks
666
+
667
+ gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
668
+ gt_albedos = gt_albedos * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
669
+
670
+ loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) + F.mse_loss(pred_albedos, gt_albedos)
671
+ loss = loss + loss_mse
672
+
673
+ if self.opt.lambda_lpips > 0:
674
+ loss_lpips = self.lpips_loss(
675
+ F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
676
+ F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
677
+ ).mean()
678
+ results['loss_lpips'] = loss_lpips
679
+ loss = loss + self.opt.lambda_lpips * loss_lpips
680
+
681
+ results['loss'] = loss
682
+
683
+ # metric
684
+ with torch.no_grad():
685
+ psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
686
+ results['psnr'] = psnr
687
+
688
+ return results
689
+
690
+
691
+ def render_frame(self, data):
692
+ # data: output of the dataloader
693
+ # return: loss
694
+
695
+ results = {}
696
+
697
+ images = data['input_vit'] # [B, 4, 9, h, W], input features
698
+
699
+ # use the first view to predict gaussians
700
+ svd_volume = self.forward_svd_volume(images,data) # [B, N, 14]
701
+
702
+ results['svd_volume'] = svd_volume
703
+
704
+ # return the rendered images
705
+ results = self.forward_geometry(svd_volume, data['w2c'], self.opt.infer_render_size)
706
+
707
+
708
+ # always use white bg
709
+ bg_color = torch.ones(3, dtype=torch.float32).to(device)
710
+
711
+
712
+ pred_shading = results['image'] # [B, V, C, output_size, output_size]
713
+ pred_alphas = results['mask'] # [B, V, 1, output_size, output_size]
714
+ pred_albedos = results['albedo'] # [B, V, C, output_size, output_size]
715
+
716
+ pred_images=pred_shading*pred_albedos
717
+
718
+ results['images_pred'] = pred_images
719
+ results['alphas_pred'] = pred_alphas
720
+ results['pred_albedos'] = pred_albedos
721
+ results['pred_shading'] = pred_shading
722
+
723
+ return results
724
+
725
+ def extract_mesh(
726
+ self,
727
+ planes: torch.Tensor,
728
+ use_texture_map: bool = False,
729
+ texture_resolution: int = 1024,
730
+ **kwargs,
731
+ ):
732
+ '''
733
+ Extract a 3D mesh from FlexiCubes. Only support batch_size 1.
734
+ :param planes: triplane features
735
+ :param use_texture_map: use texture map or vertex color
736
+ :param texture_resolution: the resolution of texure map
737
+ '''
738
+ assert planes['app_planes'].shape[0] == 1
739
+ device = planes['app_planes'].device
740
+
741
+
742
+ # predict geometry first
743
+ mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
744
+ vertices, faces = mesh_v[0], mesh_f[0]
745
+
746
+ if not use_texture_map:
747
+ # query vertex colors
748
+ vertices_tensor = vertices.unsqueeze(0)
749
+ rgb_colors = self.tensorRF.predict_color(planes, vertices_tensor)['rgb'].clamp(0, 1).squeeze(0).cpu().numpy()
750
+ rgb_colors = (rgb_colors * 255).astype(np.uint8)
751
+
752
+ albedob_colors = self.tensorRF.predict_color(planes, vertices_tensor)['albedo'].clamp(0, 1).squeeze(0).cpu().numpy()
753
+ albedob_colors = (albedob_colors * 255).astype(np.uint8)
754
+
755
+ shading_colors = self.tensorRF.predict_color(planes, vertices_tensor)['shading'].clamp(0, 1).squeeze(0).cpu().numpy()
756
+ shading_colors = (shading_colors * 255).astype(np.uint8)
757
+
758
+
759
+ return vertices.cpu().numpy(), faces.cpu().numpy(), [rgb_colors,albedob_colors,shading_colors]
760
+
761
+ # use x-atlas to get uv mapping for the mesh
762
+ ctx = dr.RasterizeCudaContext(device=device)
763
+ uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
764
+ self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution)
765
+
766
+ tex_hard_mask = tex_hard_mask.float().cpu()
767
+
768
+ # query the texture field to get the RGB color for texture map
769
+ #TBD here
770
+ query_vertices=gb_pos.view(1,texture_resolution*texture_resolution,3)
771
+
772
+ vertices_colors = self.tensorRF.predict_color(
773
+ planes, query_vertices)['rgb'].squeeze(0).cpu()
774
+
775
+ vertices_colors=vertices_colors.reshape(1,texture_resolution,texture_resolution,3)
776
+
777
+ background_feature = torch.zeros_like(vertices_colors)
778
+ img_feat = torch.lerp(background_feature, vertices_colors, tex_hard_mask)
779
+ texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
780
+
781
+ return vertices, faces, uvs, mesh_tex_idx, texture_map
782
+
783
+
core/modulate.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ class ModLN(nn.Module):
21
+ """
22
+ Modulation with adaLN.
23
+
24
+ References:
25
+ DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
26
+ """
27
+ def __init__(self, inner_dim: int, mod_dim: int, eps: float):
28
+ super().__init__()
29
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
30
+ self.mlp = nn.Sequential(
31
+ nn.SiLU(),
32
+ nn.Linear(mod_dim, inner_dim * 2),
33
+ )
34
+
35
+ @staticmethod
36
+ def modulate(x, shift, scale):
37
+ # x: [N, L, D]
38
+ # shift, scale: [N, D]
39
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
40
+
41
+ def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
42
+ shift, scale = self.mlp(mod).chunk(2, dim=-1) # [N, D]
43
+ return self.modulate(self.norm(x), shift, scale) # [N, L, D]