xrg
committed on
Commit
•
915f69b
1
Parent(s):
16ef2cb
first commit
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +2 -0
- .vscode/launch.json +23 -0
- app.py +385 -0
- core/__init__.py +0 -0
- core/block.py +124 -0
- core/embedder.py +37 -0
- core/encoders/__init__.py +15 -0
- core/encoders/dino_wrapper.py +68 -0
- core/encoders/dinov2/__init__.py +15 -0
- core/encoders/dinov2/hub/__init__.py +4 -0
- core/encoders/dinov2/hub/backbones.py +166 -0
- core/encoders/dinov2/hub/classifiers.py +268 -0
- core/encoders/dinov2/hub/depth/__init__.py +7 -0
- core/encoders/dinov2/hub/depth/decode_heads.py +747 -0
- core/encoders/dinov2/hub/depth/encoder_decoder.py +351 -0
- core/encoders/dinov2/hub/depth/ops.py +28 -0
- core/encoders/dinov2/hub/depthers.py +246 -0
- core/encoders/dinov2/hub/utils.py +39 -0
- core/encoders/dinov2/layers/__init__.py +20 -0
- core/encoders/dinov2/layers/attention.py +89 -0
- core/encoders/dinov2/layers/block.py +296 -0
- core/encoders/dinov2/layers/dino_head.py +58 -0
- core/encoders/dinov2/layers/drop_path.py +34 -0
- core/encoders/dinov2/layers/layer_scale.py +27 -0
- core/encoders/dinov2/layers/mlp.py +40 -0
- core/encoders/dinov2/layers/patch_embed.py +88 -0
- core/encoders/dinov2/layers/swiglu_ffn.py +72 -0
- core/encoders/dinov2/models/__init__.py +43 -0
- core/encoders/dinov2/models/vision_transformer.py +443 -0
- core/encoders/dinov2_wrapper.py +67 -0
- core/geometry/__init__.py +7 -0
- core/geometry/camera/__init__.py +16 -0
- core/geometry/camera/perspective_camera.py +51 -0
- core/geometry/render/__init__.py +8 -0
- core/geometry/render/neural_render.py +121 -0
- core/geometry/rep_3d/__init__.py +18 -0
- core/geometry/rep_3d/dmtet.py +504 -0
- core/geometry/rep_3d/dmtet_utils.py +20 -0
- core/geometry/rep_3d/extract_texture_map.py +40 -0
- core/geometry/rep_3d/flexicubes.py +579 -0
- core/geometry/rep_3d/flexicubes_geometry.py +120 -0
- core/geometry/rep_3d/tables.py +791 -0
- core/instant_utils/__init__.py +0 -0
- core/instant_utils/camera_util.py +111 -0
- core/instant_utils/infer_util.py +97 -0
- core/instant_utils/mesh_util.py +181 -0
- core/instant_utils/train_util.py +26 -0
- core/lrm_reconstructor.py +158 -0
- core/models.py +783 -0
- core/modulate.py +43 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
workspace_test
|
2 |
+
__pycache__
|
.vscode/launch.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"version": "0.2.0",
|
3 |
+
"configurations": [
|
4 |
+
{
|
5 |
+
"name": "app",
|
6 |
+
"type": "debugpy",
|
7 |
+
"request": "launch",
|
8 |
+
"program": "./app.py",
|
9 |
+
"console": "integratedTerminal",
|
10 |
+
"env": {
|
11 |
+
"CUDA_VISIBLE_DEVICES": "2"
|
12 |
+
},
|
13 |
+
// "args": [
|
14 |
+
// "tiny_trf_trans_nerf",//"tiny_trf_trans_nerf" tiny_trf_trans_nerf_123plus
|
15 |
+
// "--resume",
|
16 |
+
// "pretrained/last6view060804_24.ckpt",//"pretrained/last_060302_49.ckpt",//"pretrained/last_060302_49.ckpt",
|
17 |
+
// "--output_size",
|
18 |
+
// "64"
|
19 |
+
// ],
|
20 |
+
"justMyCode": true
|
21 |
+
},
|
22 |
+
]
|
23 |
+
}
|
app.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tyro
|
3 |
+
import imageio
|
4 |
+
import numpy as np
|
5 |
+
import tqdm
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.transforms.functional as TF
|
10 |
+
from safetensors.torch import load_file
|
11 |
+
import rembg
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
import kiui
|
15 |
+
from kiui.op import recenter
|
16 |
+
from kiui.cam import orbit_camera
|
17 |
+
from core.utils import get_rays, grid_distortion, orbit_camera_jitter
|
18 |
+
|
19 |
+
from core.options import AllConfigs, Options
|
20 |
+
from core.models import LTRFM_Mesh,LTRFM_NeRF
|
21 |
+
from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl
|
22 |
+
from mvdream.pipeline_mvdream import MVDreamPipeline
|
23 |
+
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
24 |
+
from huggingface_hub import hf_hub_download
|
25 |
+
|
26 |
+
# ---------------------------------------------------------------------------
# Module-level setup: constants, options, reconstruction model, diffusion
# pipelines and the background remover. Everything below runs at import time.
# ---------------------------------------------------------------------------

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# File names written under <opt.workspace>/gradio by process().
GRADIO_VIDEO_PATH = 'gradio_output.mp4'
GRADIO_OBJ_PATH = 'gradio_output_rgb.obj'
GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj'
GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj'

# Options are hard-coded for the demo instead of being parsed from the CLI.
#opt = tyro.cli(AllConfigs)

ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM6v01.ckpt")

opt = Options(
    input_size=512,
    down_channels=(32, 64, 128, 256, 512),
    down_attention=(False, False, False, False, True),
    up_channels=(512, 256, 128),
    up_attention=(True, False, False, False),
    volume_mode='TRF_NeRF',
    splat_size=64,
    output_size=62, #crop patch
    data_mode='s5',
    num_views=8,
    gradient_accumulation_steps=1, #2
    mixed_precision='bf16',
    resume=ckpt_path,
)


# model
if opt.volume_mode == 'TRF_Mesh':
    model = LTRFM_Mesh(opt)
elif opt.volume_mode == 'TRF_NeRF':
    model = LTRFM_NeRF(opt)
else:
    # The previous fallback instantiated `LGM`, which is never imported in this
    # file and therefore failed with a NameError. Fail with a clear message.
    raise ValueError(
        f"Unsupported volume_mode {opt.volume_mode!r}; expected 'TRF_Mesh' or 'TRF_NeRF'."
    )

# resume pretrained checkpoint
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else: #ckpt
        # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
        ckpt_dict = torch.load(opt.resume, map_location='cpu')
        ckpt = ckpt_dict["model"]

    # Copy matching tensors in place; mismatching or unknown keys are reported
    # and skipped instead of aborting the load.
    state_dict = model.state_dict()
    for k, v in ckpt.items():
        k = k.replace('module.', '')
        if k in state_dict:
            if state_dict[k].shape == v.shape:
                state_dict[k].copy_(v)
            else:
                print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
        else:
            print(f'[WARN] unexpected param {k}: {v.shape}')
    print('[INFO] load resume success!')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.half().to(device)
model.eval()

# Projection matrix built from opt.fovy / opt.znear / opt.zfar.
# NOTE(review): not referenced again in the visible part of this file.
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
proj_matrix[0, 0] = 1 / tan_half_fov
proj_matrix[1, 1] = 1 / tan_half_fov
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
proj_matrix[2, 3] = 1

# load dreams (text-conditioned multi-view diffusion)
pipe_text = MVDreamPipeline.from_pretrained(
    'ashawkey/mvdream-sd2.1-diffusers', # remote weights
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # local_files_only=True,
)
pipe_text = pipe_text.to(device)

# mvdream (image-conditioned multi-view diffusion)
pipe_image = MVDreamPipeline.from_pretrained(
    "ashawkey/imagedream-ipmv-diffusers", # remote weights
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # local_files_only=True,
)
pipe_image = pipe_image.to(device)


print('Loading 123plus model ...')
pipe_image_plus = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="zero123plus",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # Kept disabled like the two pipelines above: forcing local files makes the
    # first start on a machine without a populated cache fail.
    # local_files_only=True,
)
pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipe_image_plus.scheduler.config, timestep_spacing='trailing'
)

unet_path = './pretrained/diffusion_pytorch_model.bin'

print('Loading custom white-background unet ...')
# Prefer a local copy of the UNet weights and fall back to the hub otherwise.
if os.path.exists(unet_path):
    unet_ckpt_path = unet_path
else:
    unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
# NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipe_image_plus.unet.load_state_dict(state_dict, strict=True)
pipe_image_plus = pipe_image_plus.to(device)

# load rembg
bg_remover = rembg.new_session()
|
138 |
+
|
139 |
+
# process function
|
140 |
+
def process(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None):
    """Generate multi-view images from text or an image and export coloured meshes.

    Args:
        condition_input_image: Conditioning image (converted with ``np.array``),
            or ``None`` to run the text-conditioned pipeline.
        prompt: Text prompt forwarded to the MVDream pipelines.
        prompt_neg: Negative prompt forwarded to the MVDream pipelines.
        input_elevation: Elevation forwarded to the MVDream pipelines; the
            zero123plus branch does not use it.
        input_num_steps: Number of diffusion inference steps.
        input_seed: Seed passed to ``kiui.seed_everything``.
        mv_moedl_option: ``'mvdream'`` selects the image-conditioned MVDream
            pipeline; any other value selects zero123plus. Ignored for text input.

    Returns:
        ``(mv_image_grid, processed_image, rgb_obj_path, albedo_obj_path,
        shading_obj_path)``. ``processed_image`` is ``None`` for text input.
    """
    # seed
    kiui.seed_everything(input_seed)

    output_dir = os.path.join(opt.workspace, "gradio")
    os.makedirs(output_dir, exist_ok=True)
    output_obj_rgb_path = os.path.join(output_dir, GRADIO_OBJ_PATH)
    output_obj_albedo_path = os.path.join(output_dir, GRADIO_OBJ_ALBEDO_PATH)
    output_obj_shading_path = os.path.join(output_dir, GRADIO_OBJ_SHADING_PATH)

    # text-conditioned
    if condition_input_image is None:
        mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation)
        mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
        # bg removal
        mv_image = []
        for i in range(4):
            image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4]
            # to white bg
            image = image.astype(np.float32) / 255
            image = recenter(image, image[..., 0] > 0, border_ratio=0.2)
            image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
            mv_image.append(image)

        # Views are reordered to 1, 2, 3, 0 for both the preview and the model input.
        mv_image_grid = np.concatenate([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=1)
        input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)

        processed_image = None
    # image-conditioned (may also input text, but no text usually works too)
    else:
        condition_input_image = np.array(condition_input_image) # uint8
        # bg removal
        carved_image = rembg.remove(condition_input_image, session=bg_remover) # [H, W, 4]
        mask = carved_image[..., -1] > 0
        image = recenter(carved_image, mask, border_ratio=0.2)
        image = image.astype(np.float32) / 255.0
        # composite onto a white background using the alpha channel
        processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])

        if mv_moedl_option == 'mvdream':
            mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation)

            mv_image_grid = np.concatenate([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=1)
            input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
        else:
            from PIL import Image
            from einops import rearrange, repeat

            # input_image=input_image* 255
            processed_image = Image.fromarray((processed_image * 255).astype(np.uint8))
            mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0]
            mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0
            mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
            # The pipeline returns one 3x2 tiled image: build a 2x3 preview grid
            # and split it into six separate views for the reconstructor.
            mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy()
            mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy()
            input_image = mv_image

    # generate gaussians
    # [4, 256, 256, 3], float32
    input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)

    images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False)

    data = {}
    input_image = input_image.unsqueeze(0) # [1, 4, 9, H, W]
    images_input_vit = images_input_vit.unsqueeze(0)
    data['input_vit'] = images_input_vit

    # One source camera per input view.
    cam_poses = []
    if mv_moedl_option == 'mvdream' or condition_input_image is None:
        # MVDream views: four azimuths 90 degrees apart at elevation 0.
        elevation = 0
        azimuth = np.arange(0, 360, 90, dtype=np.int32)
        for azi in tqdm.tqdm(azimuth):
            cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
            cam_poses.append(cam_pose)
    else:
        # zero123plus views: six azimuths 60 degrees apart, elevation alternating
        # between -20 (even index) and 30 (odd index).
        azimuth = np.arange(30, 360, 60, dtype=np.int32)
        for view_idx, azi in enumerate(tqdm.tqdm(azimuth)):
            elevation = -20 if view_idx % 2 == 0 else 30
            cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
            cam_poses.append(cam_pose)

    cam_poses = torch.cat(cam_poses, 0)
    # Rescale translations to opt.cam_radius, then express every pose relative
    # to the first camera placed at distance opt.cam_radius on the z axis.
    radius = torch.norm(cam_poses[0, :3, 3])
    cam_poses[:, :3, 3] *= opt.cam_radius / radius
    transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32).to(device) @ torch.inverse(cam_poses[0])
    cam_poses = transform.unsqueeze(0) @ cam_poses

    cam_poses = cam_poses.unsqueeze(0)
    data['source_camera'] = cam_poses

    # NOTE(review): original indentation was lost; the mesh export is assumed to
    # run inside the no_grad context as well.
    with torch.no_grad():
        if opt.volume_mode == 'TRF_Mesh':
            with torch.autocast(device_type='cuda', dtype=torch.float32):
                svd_volume = model.forward_svd_volume(input_image, data)
        else:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                svd_volume = model.forward_svd_volume(input_image, data)

        #time-consuming
        export_texmap = False

        mesh_out = model.extract_mesh(svd_volume, use_texture_map=export_texmap)

        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out

            # NOTE: this branch writes textured meshes to separate files; the
            # three paths returned below are not written when it is enabled.
            for i in range(len(tex_map)):
                # The original referenced undefined `name` / `seed` here.
                mesh_path = os.path.join(output_dir, f'gradio_output_tex_{i}_{input_seed}.obj')
                save_obj_with_mtl(
                    vertices.data.cpu().numpy(),
                    uvs.data.cpu().numpy(),
                    faces.data.cpu().numpy(),
                    mesh_tex_idx.data.cpu().numpy(),
                    tex_map[i].permute(1, 2, 0).data.cpu().numpy(),
                    mesh_path,
                )
        else:
            vertices, faces, vertex_colors = mesh_out

            save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path)
            save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path)
            save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path)

    return mv_image_grid, processed_image, output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path
|
272 |
+
|
273 |
+
# gradio UI
|
274 |
+
|
275 |
+
_TITLE = '''LDM: Large Tensorial SDF Model for Textured Mesh Generation'''

_DESCRIPTION = '''


* Input can be text prompt, image.
* If you find the output unsatisfying, try using different seeds!
'''

# Example images are optional: a missing "example" directory must not prevent
# the app from starting (os.listdir would raise FileNotFoundError).
_EXAMPLE_DIR = "example"
_IMAGE_EXAMPLES = [
    os.path.join(_EXAMPLE_DIR, img_name) for img_name in sorted(os.listdir(_EXAMPLE_DIR))
] if os.path.isdir(_EXAMPLE_DIR) else []

# NOTE(review): the nesting of the layout below was reconstructed from a
# flattened listing; confirm it against the rendered page.
block = gr.Blocks(title=_TITLE).queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        # Left column: inputs and generation settings.
        with gr.Column(scale=1):
            with gr.Tab("Image-to-3D"):
                # input image
                with gr.Row():
                    condition_input_image = gr.Image(
                        label="Input Image",
                        image_mode="RGBA",
                        type="pil"
                    )

                    processed_image = gr.Image(
                        label="Processed Image",
                        image_mode="RGBA",
                        type="pil",
                        interactive=False
                    )

                with gr.Row():
                    mv_moedl_option = gr.Radio([
                        "zero123plus",
                        "mvdream"
                    ], value="zero123plus",
                        label="Multi-view Diffusion")

                if _IMAGE_EXAMPLES:
                    with gr.Row(variant="panel"):
                        gr.Examples(
                            examples=_IMAGE_EXAMPLES,
                            inputs=[condition_input_image],
                            fn=lambda x: process(condition_input_image=x, prompt=''),
                            cache_examples=False,
                            examples_per_page=20,
                            label='Image-to-3D Examples'
                        )

            with gr.Tab("Text-to-3D"):
                # input prompt
                with gr.Row():
                    input_text = gr.Textbox(label="prompt")
                # negative prompt
                with gr.Row():
                    input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate')

                with gr.Row(variant="panel"):
                    gr.Examples(
                        examples=[
                            "a hamburger",
                            "a furry red fox head",
                            "a teddy bear",
                            "a motorbike",
                        ],
                        inputs=[input_text],
                        fn=lambda x: process(condition_input_image=None, prompt=x),
                        cache_examples=False,
                        label='Text-to-3D Examples'
                    )

            # elevation
            input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
            # inference steps
            input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
            # gen button
            button_gen = gr.Button("Generate")

        # Right column: multi-view preview and the three exported meshes.
        with gr.Column(scale=1):
            with gr.Row():
                # multi-view results
                mv_image_grid = gr.Image(interactive=False, show_label=False)
            with gr.Row():
                output_obj_rgb_path = gr.Model3D(
                    label="RGB Model (OBJ Format)",
                    interactive=False,
                )
            with gr.Row():
                output_obj_albedo_path = gr.Model3D(
                    label="Albedo Model (OBJ Format)",
                    interactive=False,
                )
            with gr.Row():
                output_obj_shading_path = gr.Model3D(
                    label="Shading Model (OBJ Format)",
                    interactive=False,
                )

    # Inputs are passed positionally in the order of process()'s parameters;
    # outputs follow the order of its return tuple.
    button_gen.click(process, inputs=[condition_input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed, mv_moedl_option], outputs=[mv_image_grid, processed_image, output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path])


block.launch(server_name="0.0.0.0", share=False)
|
core/__init__.py
ADDED
File without changes
|
core/block.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch.nn as nn
|
17 |
+
|
18 |
+
from .modulate import ModLN
|
19 |
+
|
20 |
+
|
21 |
+
class BasicBlock(nn.Module):
    """
    Simplest transformer block: self-attention followed by an MLP.
    Designed for PF-LRM architecture.
    """
    # Both sub-layers are pre-normalised and wrapped in a residual connection.
    def __init__(self, inner_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        hidden_dim = int(inner_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(hidden_dim, inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x):
        # x: [N, L, D]
        normed = self.norm1(x)
        attn_out = self.self_attn(normed, normed, normed, need_weights=False)[0]
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
|
50 |
+
|
51 |
+
|
52 |
+
class ConditionBlock(nn.Module):
    """
    Transformer block with a cross-attention condition input.
    Designed for SparseLRM architecture.
    """
    # Sub-layers in order: cross-attention, self-attention, MLP; each is
    # pre-normalised and wrapped in a residual connection.
    def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        hidden_dim = int(inner_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            kdim=cond_dim,
            vdim=cond_dim,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(hidden_dim, inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        cross_out = self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
        x = x + cross_out
        normed = self.norm2(x)
        self_out = self.self_attn(normed, normed, normed, need_weights=False)[0]
        x = x + self_out
        return x + self.mlp(self.norm3(x))
|
87 |
+
|
88 |
+
|
89 |
+
class ConditionModulationBlock(nn.Module):
    """
    Transformer block with a cross-attention condition and a modulation vector
    that is fed to every normalisation layer.
    Designed for raw LRM architecture.
    """
    # Sub-layers in order: cross-attention, self-attention, MLP; each is
    # preceded by a ModLN layer and wrapped in a residual connection.
    def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0.):
        super().__init__()
        hidden_dim = int(inner_dim * mlp_ratio)

        self.norm1 = ModLN(inner_dim, mod_dim, eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            kdim=cond_dim,
            vdim=cond_dim,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm2 = ModLN(inner_dim, mod_dim, eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )
        self.norm3 = ModLN(inner_dim, mod_dim, eps)
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(hidden_dim, inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond, mod):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        # mod: [N, D_mod]
        cross_out = self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
        x = x + cross_out
        normed = self.norm2(x, mod)
        self_out = self.self_attn(normed, normed, normed, need_weights=False)[0]
        x = x + self_out
        return x + self.mlp(self.norm3(x, mod))
|
core/embedder.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class CameraEmbedder(nn.Module):
    """
    Map raw camera features to an embedding vector with a two-layer MLP.

    Reference:
    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
    """
    def __init__(self, raw_dim: int, embed_dim: int):
        super().__init__()
        # Linear -> SiLU -> Linear, kept in an nn.Sequential named `mlp`.
        layers = [
            nn.Linear(raw_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    @torch.compile
    def forward(self, x):
        embedded = self.mlp(x)
        return embedded
|
core/encoders/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# Empty
|
core/encoders/dino_wrapper.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from transformers import ViTImageProcessor, ViTModel
|
19 |
+
from accelerate.logging import get_logger
|
20 |
+
|
21 |
+
|
22 |
+
logger = get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class DinoWrapper(nn.Module):
    """
    Dino v1 wrapper using huggingface transformer implementation.
    """
    def __init__(self, model_name: str, freeze: bool = True):
        super().__init__()
        built_model, built_processor = self._build_dino(model_name)
        self.model = built_model
        self.processor = built_processor
        if freeze:
            self._freeze()

    @torch.compile
    def forward_model(self, inputs):
        # Compiled separately from forward(); forwards the processor output to the ViT.
        return self.model(**inputs, interpolate_pos_encoding=True)

    def forward(self, image):
        # image: [N, C, H, W], on cpu
        # RGB image with [0,1] scale and properly sized
        processed = self.processor(
            images=image,
            return_tensors="pt",
            do_rescale=False,
            do_resize=False,
        )
        inputs = processed.to(self.model.device)
        # This resampling of positional embedding uses bicubic interpolation
        return self.forward_model(inputs).last_hidden_state

    def _freeze(self):
        # Put the backbone in eval mode and stop gradients for all its parameters.
        logger.warning("======== Freezing DinoWrapper ========")
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
        # Load model and processor; on a proxy error, wait and retry until the
        # retry budget is used up, then re-raise the error.
        import requests
        try:
            vit = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
            image_processor = ViTImageProcessor.from_pretrained(model_name)
            return vit, image_processor
        except requests.exceptions.ProxyError as err:
            if not proxy_error_retries > 0:
                raise err
            print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...")
            import time
            time.sleep(proxy_error_cooldown)
            return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
|
core/encoders/dinov2/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# Empty
|
core/encoders/dinov2/hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
core/encoders/dinov2/hub/backbones.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
12 |
+
|
13 |
+
|
14 |
+
class Weights(Enum):
    """Names of the pretrained weight sets accepted by the backbone factories in this module."""

    # The only weight set declared here; the factory docstrings refer to it as the LVD-142M dataset.
    LVD142M = "LVD142M"
|
16 |
+
|
17 |
+
|
18 |
+
def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """
    Build a DINOv2 vision transformer and optionally load its pretrained checkpoint.

    Args:
        arch_name: Name of the constructor looked up in ``models.vision_transformer``.
        img_size, patch_size, init_values, ffn_layer, block_chunks,
        num_register_tokens, interpolate_antialias, interpolate_offset:
            Forwarded to the constructor as keyword arguments.
        pretrained: When True, download the checkpoint and load it into the model.
        weights: ``Weights`` member or its name; an unknown name raises ``AssertionError``.
        **kwargs: Extra constructor arguments; they override the named ones above.

    Returns:
        The constructed model.
    """
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = {
        "img_size": img_size,
        "patch_size": patch_size,
        "init_values": init_values,
        "ffn_layer": ffn_layer,
        "block_chunks": block_chunks,
        "num_register_tokens": num_register_tokens,
        "interpolate_antialias": interpolate_antialias,
        "interpolate_offset": interpolate_offset,
        **kwargs,
    }
    model = vars(vits)[arch_name](**vit_kwargs)

    if not pretrained:
        return model

    model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
    url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
    checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")

    # ********** Modified by Zexin He in 2023-2024 **********
    # With a modulation_dim the checkpoint's norm1/norm2 entries are renamed to
    # norm1.norm/norm2.norm and the load is non-strict.
    uses_modulation = vit_kwargs.get("modulation_dim") is not None
    state_dict = {}
    for key, tensor in checkpoint.items():
        if 'mask_token' in key:  # DDP concern
            continue
        if uses_modulation:
            key = key.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm')
        state_dict[key] = tensor
    model.load_state_dict(state_dict, strict=not uses_modulation)
    # ********************************************************

    return model
|
72 |
+
|
73 |
+
|
74 |
+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.

    Args:
        pretrained: Forwarded to ``_make_dinov2_model``; when True the checkpoint is loaded.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_model``.

    Returns:
        The model built by ``_make_dinov2_model`` with ``arch_name="vit_small"``.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
79 |
+
|
80 |
+
|
81 |
+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.

    Args:
        pretrained: Forwarded to ``_make_dinov2_model``; when True the checkpoint is loaded.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_model``.

    Returns:
        The model built by ``_make_dinov2_model`` with ``arch_name="vit_base"``.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
86 |
+
|
87 |
+
|
88 |
+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.

    Args:
        pretrained: Forwarded to ``_make_dinov2_model``; when True the checkpoint is loaded.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_model``.

    Returns:
        The model built by ``_make_dinov2_model`` with ``arch_name="vit_large"``.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
93 |
+
|
94 |
+
|
95 |
+
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.

    Args:
        pretrained: Forwarded to ``_make_dinov2_model``; when True the checkpoint is loaded.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_model``.
            ``ffn_layer`` is fixed here, so it must not be passed again.

    Returns:
        The model built by ``_make_dinov2_model`` with ``arch_name="vit_giant2"``
        and ``ffn_layer="swiglufused"``.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )
|
106 |
+
|
107 |
+
|
108 |
+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Args:
        pretrained: Forwarded to ``_make_dinov2_model``; when True the checkpoint is loaded.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_model``.
            ``num_register_tokens``, ``interpolate_antialias`` and ``interpolate_offset``
            are fixed here, so they must not be passed again.

    Returns:
        The model built by ``_make_dinov2_model`` with ``arch_name="vit_small"``,
        4 register tokens, antialiased interpolation and a zero interpolation offset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
121 |
+
|
122 |
+
|
123 |
+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Args:
        pretrained: Forwarded to ``_make_dinov2_model``; when True the checkpoint is loaded.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_model``.
            ``num_register_tokens``, ``interpolate_antialias`` and ``interpolate_offset``
            are fixed here, so they must not be passed again.

    Returns:
        The model built by ``_make_dinov2_model`` with ``arch_name="vit_base"``,
        4 register tokens, antialiased interpolation and a zero interpolation offset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
136 |
+
|
137 |
+
|
138 |
+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Args:
        pretrained: Forwarded to ``_make_dinov2_model``; when True the checkpoint is loaded.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_model``.
            ``num_register_tokens``, ``interpolate_antialias`` and ``interpolate_offset``
            are fixed here, so they must not be passed again.

    Returns:
        The model built by ``_make_dinov2_model`` with ``arch_name="vit_large"``,
        4 register tokens, antialiased interpolation and a zero interpolation offset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
151 |
+
|
152 |
+
|
153 |
+
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Args:
        pretrained: Forwarded to ``_make_dinov2_model``; when True the checkpoint is loaded.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_model``.
            ``ffn_layer``, ``num_register_tokens``, ``interpolate_antialias`` and
            ``interpolate_offset`` are fixed here, so they must not be passed again.

    Returns:
        The model built by ``_make_dinov2_model`` with ``arch_name="vit_giant2"``,
        ``ffn_layer="swiglufused"``, 4 register tokens, antialiased interpolation
        and a zero interpolation offset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
core/encoders/dinov2/hub/classifiers.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from .backbones import _make_dinov2_model
|
13 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
14 |
+
|
15 |
+
|
16 |
+
class Weights(Enum):
    """Names of the pretrained weight sets accepted by the linear-classifier factories in this module."""

    # The only weight set declared here; the factory docstrings refer to it as ImageNet-1k.
    IMAGENET1K = "IMAGENET1K"
|
18 |
+
|
19 |
+
|
20 |
+
def _make_dinov2_linear_classification_head(
    *,
    arch_name: str = "vit_large",
    patch_size: int = 14,
    embed_dim: int = 1024,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    num_register_tokens: int = 0,
    **kwargs,
):
    """
    Create the 1000-way linear head used on top of a DINOv2 backbone.

    Args:
        arch_name, patch_size, num_register_tokens: Used to compose the checkpoint name.
        embed_dim: Backbone embedding width; the head takes ``(1 + layers) * embed_dim`` inputs.
        layers: Must be 1 or 4, otherwise ``AssertionError`` is raised.
        pretrained: When True, download the head checkpoint and load it strictly.
        weights: ``Weights`` member or its name; an unknown name raises ``AssertionError``.
        **kwargs: Accepted and ignored.

    Returns:
        The ``nn.Linear`` head.
    """
    if not (layers == 1 or layers == 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        if weights not in Weights.__members__:
            raise AssertionError(f"Unsupported weights: {weights}")
        weights = Weights[weights]

    in_features = (1 + layers) * embed_dim
    linear_head = nn.Linear(in_features, 1_000)
    if not pretrained:
        return linear_head

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
    # The single-layer checkpoint carries no layer count in its file name.
    layers_str = "" if layers != 4 else str(layers)
    url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
    linear_head.load_state_dict(
        torch.hub.load_state_dict_from_url(url, map_location="cpu"),
        strict=True,
    )
    return linear_head
|
50 |
+
|
51 |
+
|
52 |
+
class _LinearClassifierWrapper(nn.Module):
    """
    Combine a backbone with a linear head into a single classifier module.

    Args:
        backbone: Module providing ``forward_features`` (used when ``layers == 1``)
            and ``get_intermediate_layers`` (used when ``layers == 4``).
        linear_head: Module applied to the concatenated features.
        layers: Number of backbone layers whose class tokens feed the head (1 or 4).
    """

    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
        super().__init__()
        self.backbone = backbone
        self.linear_head = linear_head
        self.layers = layers

    def forward(self, x):
        """
        Run the backbone and classify the concatenated token features.

        Raises:
            AssertionError: If ``self.layers`` is neither 1 nor 4.
        """
        if self.layers == 1:
            x = self.backbone.forward_features(x)
            cls_token = x["x_norm_clstoken"]
            patch_tokens = x["x_norm_patchtokens"]
            # Class token followed by the patch tokens averaged over dim 1.
            # fmt: off
            linear_input = torch.cat([
                cls_token,
                patch_tokens.mean(dim=1),
            ], dim=1)
            # fmt: on
        elif self.layers == 4:
            x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
            # Each entry is indexed as (item 0, item 1): item 1 of all four entries,
            # then item 0 of the last entry averaged over dim 1.
            # fmt: off
            linear_input = torch.cat([
                x[0][1],
                x[1][1],
                x[2][1],
                x[3][1],
                x[3][0].mean(dim=1),
            ], dim=1)
            # fmt: on
        else:
            # Raise explicitly: `assert False` is removed under `python -O`, which
            # would surface as an UnboundLocalError on `linear_input` below.
            raise AssertionError(f"Unsupported number of layers: {self.layers}")
        return self.linear_head(linear_input)
|
84 |
+
|
85 |
+
|
86 |
+
def _make_dinov2_linear_classifier(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    **kwargs,
):
    """
    Assemble a DINOv2 backbone and its linear head into one classifier.

    The backbone is built first; its ``embed_dim`` and ``patch_size`` then size
    and name the head. ``weights`` applies to the head only, and ``**kwargs``
    is forwarded to the backbone factory only.

    Returns:
        A ``_LinearClassifierWrapper`` holding the backbone and the head.
    """
    backbone_options = dict(
        arch_name=arch_name,
        pretrained=pretrained,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    backbone = _make_dinov2_model(**backbone_options, **kwargs)

    head = _make_dinov2_linear_classification_head(
        arch_name=arch_name,
        patch_size=backbone.patch_size,
        embed_dim=backbone.embed_dim,
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=num_register_tokens,
    )
    return _LinearClassifierWrapper(backbone=backbone, linear_head=head, layers=layers)
|
119 |
+
|
120 |
+
|
121 |
+
def dinov2_vits14_lc(
    *,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.

    Args:
        layers: Number of backbone layers feeding the head; only 1 and 4 are accepted downstream.
        pretrained: Forwarded to ``_make_dinov2_linear_classifier``.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_linear_classifier``.

    Returns:
        The classifier built by ``_make_dinov2_linear_classifier`` with ``arch_name="vit_small"``.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
138 |
+
|
139 |
+
|
140 |
+
def dinov2_vitb14_lc(
    *,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.

    Args:
        layers: Number of backbone layers feeding the head; only 1 and 4 are accepted downstream.
        pretrained: Forwarded to ``_make_dinov2_linear_classifier``.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_linear_classifier``.

    Returns:
        The classifier built by ``_make_dinov2_linear_classifier`` with ``arch_name="vit_base"``.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
157 |
+
|
158 |
+
|
159 |
+
def dinov2_vitl14_lc(
    *,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.

    Args:
        layers: Number of backbone layers feeding the head; only 1 and 4 are accepted downstream.
        pretrained: Forwarded to ``_make_dinov2_linear_classifier``.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_linear_classifier``.

    Returns:
        The classifier built by ``_make_dinov2_linear_classifier`` with ``arch_name="vit_large"``.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
176 |
+
|
177 |
+
|
178 |
+
def dinov2_vitg14_lc(
    *,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.

    Args:
        layers: Number of backbone layers feeding the head; only 1 and 4 are accepted downstream.
        pretrained: Forwarded to ``_make_dinov2_linear_classifier``.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_linear_classifier``.
            ``ffn_layer`` is fixed here, so it must not be passed again.

    Returns:
        The classifier built by ``_make_dinov2_linear_classifier`` with
        ``arch_name="vit_giant2"`` and ``ffn_layer="swiglufused"``.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_giant2",
        layers=layers,
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
196 |
+
|
197 |
+
|
198 |
+
def dinov2_vits14_reg_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.

    Args:
        layers: Number of backbone layers feeding the head; only 1 and 4 are accepted downstream.
        pretrained: Forwarded to ``_make_dinov2_linear_classifier``.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_linear_classifier``.
            ``num_register_tokens``, ``interpolate_antialias`` and ``interpolate_offset``
            are fixed here, so they must not be passed again.

    Returns:
        The classifier built by ``_make_dinov2_linear_classifier`` with ``arch_name="vit_small"``,
        4 register tokens, antialiased interpolation and a zero interpolation offset.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
214 |
+
|
215 |
+
|
216 |
+
def dinov2_vitb14_reg_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.

    Args:
        layers: Number of backbone layers feeding the head; only 1 and 4 are accepted downstream.
        pretrained: Forwarded to ``_make_dinov2_linear_classifier``.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_linear_classifier``.
            ``num_register_tokens``, ``interpolate_antialias`` and ``interpolate_offset``
            are fixed here, so they must not be passed again.

    Returns:
        The classifier built by ``_make_dinov2_linear_classifier`` with ``arch_name="vit_base"``,
        4 register tokens, antialiased interpolation and a zero interpolation offset.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
232 |
+
|
233 |
+
|
234 |
+
def dinov2_vitl14_reg_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.

    Args:
        layers: Number of backbone layers feeding the head; only 1 and 4 are accepted downstream.
        pretrained: Forwarded to ``_make_dinov2_linear_classifier``.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_linear_classifier``.
            ``num_register_tokens``, ``interpolate_antialias`` and ``interpolate_offset``
            are fixed here, so they must not be passed again.

    Returns:
        The classifier built by ``_make_dinov2_linear_classifier`` with ``arch_name="vit_large"``,
        4 register tokens, antialiased interpolation and a zero interpolation offset.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
250 |
+
|
251 |
+
|
252 |
+
def dinov2_vitg14_reg_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.

    Args:
        layers: Number of backbone layers feeding the head; only 1 and 4 are accepted downstream.
        pretrained: Forwarded to ``_make_dinov2_linear_classifier``.
        weights: ``Weights`` member or its name, forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded unchanged to ``_make_dinov2_linear_classifier``.
            ``ffn_layer``, ``num_register_tokens``, ``interpolate_antialias`` and
            ``interpolate_offset`` are fixed here, so they must not be passed again.

    Returns:
        The classifier built by ``_make_dinov2_linear_classifier`` with ``arch_name="vit_giant2"``,
        ``ffn_layer="swiglufused"``, 4 register tokens, antialiased interpolation
        and a zero interpolation offset.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_giant2",
        layers=layers,
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
core/encoders/dinov2/hub/depth/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .decode_heads import BNHead, DPTHead
|
7 |
+
from .encoder_decoder import DepthEncoderDecoder
|
core/encoders/dinov2/hub/depth/decode_heads.py
ADDED
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import copy
|
7 |
+
from functools import partial
|
8 |
+
import math
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from .ops import resize
|
15 |
+
|
16 |
+
|
17 |
+
# XXX: (Untested) replacement for mmcv.imdenormalize()
|
18 |
+
def _imdenormalize(img, mean, std, to_bgr=True):
    """
    Undo a per-channel normalization: ``img * std + mean``.

    Args:
        img: Array whose last axis is the channel axis (``mean``/``std`` are
            broadcast over it).
        mean: Per-channel means, as an array or any sequence of numbers.
        std: Per-channel standard deviations, as an array or any sequence of numbers.
        to_bgr: When True, reverse the channel order of the result.

    Returns:
        The denormalized array (float64 after broadcasting with ``mean``/``std``).
        With ``to_bgr`` the result is a reversed view, not a copy.
    """
    import numpy as np

    # np.asarray lets callers pass plain lists/tuples as well as ndarrays.
    mean = np.asarray(mean, dtype=np.float64).reshape(1, -1)
    std = np.asarray(std, dtype=np.float64).reshape(1, -1)
    img = (img * std) + mean
    if to_bgr:
        # Reverse the channel (last) axis. `img[::-1]` would reverse the first
        # axis instead, i.e. flip the image rows rather than swap the channels.
        img = img[..., ::-1]
    return img
|
27 |
+
|
28 |
+
|
29 |
+
class DepthBaseDecodeHead(nn.Module):
    """Base class for depth decode heads.

    Stores the shared configuration, builds the final ``conv_depth`` layer and
    provides the prediction, loss and image-logging helpers. ``forward`` is a
    placeholder here.

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
            Default: 96.
        conv_layer (nn.Module): Conv layers. Default: None.
        act_layer (nn.Module): Activation layers. Default: nn.ReLU.
        loss_decode: Decode loss. ``losses`` treats it either as an
            ``nn.ModuleList`` of losses or as a single loss; each loss is
            called as ``loss(depth_pred, depth_gt)`` and must expose
            ``loss_name``. Default: ().
        sampler (dict|None): The config of depth map sampler. Accepted but
            not used by this constructor. Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting.
            Default: 1e-3.
        max_depth (float|None): Max depth in dataset setting.
            Default: None.
        norm_layer (dict|None): Norm layers.
            Default: None.
        classify (bool): Whether predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in cls. step.
            Default: 256.
        bins_strategy (str): The discrete strategy used in cls. step,
            'UD' or 'SID'. Default: 'UD'.
        norm_strategy (str): The norm strategy on cls. probability
            distribution, 'linear', 'softmax' or 'sigmoid'. Default: 'linear'
        scale_up (bool): Whether predict depth in a scale-up manner.
            Default: False.
    """

    def __init__(
        self,
        in_channels,
        conv_layer=None,
        act_layer=nn.ReLU,
        channels=96,
        loss_decode=(),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_layer=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super(DepthBaseDecodeHead, self).__init__()

        self.in_channels = in_channels
        self.channels = channels
        # NOTE(review): stored as "conf_layer" although the parameter is "conv_layer" —
        # looks like a typo; confirm nothing reads this attribute before renaming.
        self.conf_layer = conv_layer
        self.act_layer = act_layer
        self.loss_decode = loss_decode
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_layer = norm_layer
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            # bins_strategy / norm_strategy are only set as attributes in classification mode.
            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            # One output channel per depth bin.
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            # Single-channel depth regression.
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, img_metas):
        """Placeholder of forward function."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt):
        """Forward function for training.
        Args:
            img (Tensor): Input images; only ``img[0]`` is used, for logging.
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth

        Returns:
            dict[str, Tensor]: a dictionary of loss components, extended with
                the entries returned by ``log_images`` for the first sample.
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        # Only the first sample of the batch is logged.
        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas):
        """Forward function for testing.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict depth for each pixel of ``feat``.

        In classification mode the ``conv_depth`` output is normalized over
        dim 1 and combined with the bin values as a weighted sum; otherwise
        depth is regressed directly from ``conv_depth``.
        """
        if self.classify:
            logit = self.conv_depth(feat)

            if self.bins_strategy == "UD":
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                # NOTE(review): torch.logspace treats start/end as base-10 exponents, so these
                # bins span 10**min_depth .. 10**max_depth — confirm this is intended.
                bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)

            # following Adabins, default linear
            if self.norm_strategy == "linear":
                logit = torch.relu(logit)
                # The offset keeps the per-pixel sum strictly positive before dividing.
                eps = 0.1
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)

            # Weighted sum over the bin axis, then restore a singleton dim 1.
            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)

        else:
            if self.scale_up:
                # Sigmoid output scaled by max_depth.
                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
            else:
                # ReLU output shifted up by min_depth.
                output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output

    def losses(self, depth_pred, depth_gt):
        """Compute depth loss.

        The prediction is resized to ``depth_gt.shape[2:]`` first. Losses that
        share a ``loss_name`` are summed into one entry.

        Returns:
            dict: Loss values keyed by ``loss_name``.
        """
        loss = dict()
        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )
        # NOTE(review): with the default loss_decode=() this wraps the empty tuple in a list,
        # and the .loss_name lookup below would fail — confirm a loss is always supplied.
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
            else:
                loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
        return loss

    def log_images(self, img_path, depth_pred, depth_gt, img_meta):
        """Prepare the input image and depth maps for logging.

        Args:
            img_path: Despite the name, this is used as a tensor
                (``detach``/``cpu``/``permute`` are called on it).
            depth_pred (Tensor): Predicted depth; divided by its maximum.
            depth_gt (Tensor): Ground-truth depth; divided by its maximum.
            img_meta (dict): Must contain ``img_norm_cfg`` with the keys
                ``mean``, ``std`` and ``to_rgb``.

        Returns:
            dict: ``img_rgb`` (uint8 numpy array), ``img_depth_pred`` and
                ``img_depth_gt`` (detached CPU tensors).
        """
        import numpy as np

        # Move axis 0 to the end (presumably CHW -> HWC — confirm against the caller).
        show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
        show_img = show_img.numpy().astype(np.float32)
        show_img = _imdenormalize(
            show_img,
            img_meta["img_norm_cfg"]["mean"],
            img_meta["img_norm_cfg"]["std"],
            img_meta["img_norm_cfg"]["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255)
        show_img = show_img.astype(np.uint8)
        # Reverse the last (channel) axis.
        show_img = show_img[:, :, ::-1]
        # The two transposes together move the last axis back to the front.
        show_img = show_img.transpose(0, 2, 1)
        show_img = show_img.transpose(1, 0, 2)

        depth_pred = depth_pred / torch.max(depth_pred)
        depth_gt = depth_gt / torch.max(depth_gt)

        depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
        depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())

        return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
|
221 |
+
|
222 |
+
|
223 |
+
class BNHead(DepthBaseDecodeHead):
    """Depth head made of a single 1x1 convolution on gathered features.

    The batch-norm layer the class is named after is commented out in the
    code, so features are passed to the prediction step unnormalized.

    Args:
        input_transform (str): How multi-level inputs are combined. Any value
            containing ``"concat"`` concatenates the selected levels along
            dim 1 (after resizing them if it also contains ``"resize"``);
            ``"multiple_select"`` returns the selected levels as a list;
            anything else indexes ``inputs`` directly with ``in_index``.
        in_index: Indices of the feature levels to use.
        upsample (int): Factor applied to the spatial size of the first
            selected level when resizing.
        **kwargs: Forwarded to ``DepthBaseDecodeHead``.
    """

    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample
        # self.bn = nn.SyncBatchNorm(self.in_channels)
        # One output channel per depth bin when classifying, otherwise a single regression channel.
        out_dim = self.n_bins if self.classify else 1
        self.conv_depth = nn.Conv2d(self.channels, out_dim, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Select (and optionally resize and concatenate) feature levels.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs.
        """
        mode = self.input_transform

        if "concat" in mode:
            chosen = [inputs[i] for i in self.in_index]
            if "resize" in mode:
                reference = chosen[0].shape[2:]
                resized = []
                for feat in chosen:
                    resized.append(
                        resize(
                            input=feat,
                            size=[s * self.upsample for s in reference],
                            mode="bilinear",
                            align_corners=self.align_corners,
                        )
                    )
                chosen = resized
            return torch.cat(chosen, dim=1)

        if mode == "multiple_select":
            return [inputs[i] for i in self.in_index]

        return inputs[self.in_index]

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Turn per-level backbone outputs into the tensor fed to the predictor.

        Each entry of ``inputs`` is a sequence. When it has two elements they
        are treated as ``(features, cls_token)`` and the class token is
        broadcast and concatenated on the channel axis; otherwise only the
        first element is used. 2-D features get two trailing singleton dims.

        Args:
            inputs (list): Per-level backbone outputs.

        Returns:
            The result of ``_transform_inputs`` on the processed levels.
        """
        # accept lists (for cls token)
        levels = []
        for entry in inputs:
            feat = entry[0]
            if len(entry) == 2:
                cls_token = entry[1]
                if feat.dim() == 2:
                    feat = feat[:, :, None, None]
                cls_token = cls_token[:, :, None, None].expand_as(feat)
                feat = torch.cat((feat, cls_token), 1)
            elif feat.dim() == 2:
                feat = feat[:, :, None, None]
            levels.append(feat)
        # feats = self.bn(x)
        return self._transform_inputs(levels)

    def forward(self, inputs, img_metas=None, **kwargs):
        """Gather the features and run the depth prediction step on them."""
        feats = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        return self.depth_pred(feats)
|
297 |
+
|
298 |
+
|
299 |
+
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).

    Besides, some additional features are provided:
        1. Automatically set `bias` of the conv layer.
        2. Spectral norm is supported.
        3. "reflect" padding is supported through an explicit padding layer
           placed in front of the convolution.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
        stride (int | tuple[int]): Stride of the convolution.
        padding (int | tuple[int]): Padding added to both sides of the input.
        dilation (int | tuple[int]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to
            output channels.
        bias (bool | str): If specified as `auto`, bias is enabled only when
            `norm_layer` is None. Default: "auto".
        conv_layer: Callable building the convolution. Default: nn.Conv2d.
        norm_layer: Callable building the normalization layer; it is called
            as ``norm_layer(num_features=...)``. Default: None (no norm).
        act_layer: Callable building the activation layer, or None for no
            activation. Default: nn.ReLU.
        inplace (bool): Passed as ``inplace`` to the activation unless the
            activation class is one of nn.Tanh / nn.PReLU / nn.Sigmoid /
            nn.GELU, which do not accept it. Default: True.
        with_spectral_norm (bool): Whether to wrap the conv in spectral norm.
            Default: False.
        padding_mode (str): 'zeros' and 'circular' leave padding to the conv
            layer; 'reflect' uses an explicit ``nn.ReflectionPad2d``.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            permutation of "conv", "norm" and "act".
            Default: ('conv', 'norm', 'act').

    Raises:
        AssertionError: If ``order`` is not a permutation of the three layer
            names, or ``padding_mode`` is not supported.
    """

    _abbr_ = "conv_block"

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias="auto",
        conv_layer=nn.Conv2d,
        norm_layer=None,
        act_layer=nn.ReLU,
        inplace=True,
        with_spectral_norm=False,
        padding_mode="zeros",
        order=("conv", "norm", "act"),
    ):
        super().__init__()
        # Function-scope imports, as in the original code, with the module path spelled correctly.
        from torch.nn.modules.batchnorm import _BatchNorm
        from torch.nn.modules.instancenorm import _InstanceNorm

        official_padding_mode = ["zeros", "circular"]
        self.conv_layer = conv_layer
        self.norm_layer = norm_layer
        self.act_layer = act_layer
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == {"conv", "norm", "act"}

        self.with_norm = norm_layer is not None
        self.with_activation = act_layer is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == "auto":
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            if padding_mode == "reflect":
                self.pad = self._build_reflect_pad(padding)
            else:
                raise AssertionError(f"Unsupported padding mode: {padding_mode}")

        # reset padding to 0 for conv module when an explicit pad layer is used
        conv_padding = 0 if self.with_explicit_padding else padding
        # NOTE(review): `padding_mode` itself is never forwarded to the conv
        # layer, so "circular" currently behaves like "zeros" -- confirm intent.
        self.conv = self.conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layer
        if self.with_norm:
            # A norm placed after the conv sees its output channels, otherwise its input channels.
            if order.index("norm") > order.index("conv"):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            norm = norm_layer(num_features=norm_channels)
            # The submodule cannot be registered as "norm": that name is taken by
            # the read-only `norm` property, which resolves through `norm_name`.
            if isinstance(norm, _BatchNorm):
                self.norm_name = "bn"
            elif isinstance(norm, _InstanceNorm):
                self.norm_name = "in"
            else:
                self.norm_name = "norm_module"
            self.add_module(self.norm_name, norm)
            if self.with_bias and isinstance(norm, (_BatchNorm, _InstanceNorm)):
                warnings.warn("Unnecessary conv bias before batch/instance norm")
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            # These activations have no 'inplace' argument.
            no_inplace = (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)
            if isinstance(act_layer, type) and issubclass(act_layer, no_inplace):
                self.activate = act_layer()
            else:
                self.activate = act_layer(inplace=inplace)

        # Use msra init by default
        self.init_weights()

    @staticmethod
    def _build_reflect_pad(padding):
        """Return a reflection-padding layer equivalent to conv-style ``padding``."""
        if isinstance(padding, (tuple, list)) and len(padding) == 2:
            # Conv padding is (height, width); the pad layer wants (left, right, top, bottom).
            pad_h, pad_w = padding
            padding = (pad_w, pad_w, pad_h, pad_h)
        return nn.ReflectionPad2d(padding)

    @property
    def norm(self):
        """The normalization submodule, or None when the block has no norm."""
        if self.norm_name:
            return getattr(self, self.norm_name)
        return None

    def init_weights(self):
        """Initialize conv (Kaiming normal, fan-out) and norm parameters.

        Conv layers that define their own ``init_weights()`` are left
        untouched so their custom initialization is not overridden. PyTorch
        conv layers are re-initialized here.
        """
        if not hasattr(self.conv, "init_weights"):
            activation = getattr(self, "activate", None)
            if isinstance(activation, nn.LeakyReLU):
                nonlinearity = "leaky_relu"
                a = activation.negative_slope
            else:
                nonlinearity = "relu"
                a = 0
            if hasattr(self.conv, "weight") and self.conv.weight is not None:
                nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
            if hasattr(self.conv, "bias") and self.conv.bias is not None:
                nn.init.constant_(self.conv.bias, 0)
        if self.with_norm:
            norm = self.norm
            if hasattr(norm, "weight") and norm.weight is not None:
                nn.init.constant_(norm.weight, 1)
            if hasattr(norm, "bias") and norm.bias is not None:
                nn.init.constant_(norm.bias, 0)

    def forward(self, x, activate=True, norm=True):
        """Apply the layers in ``self.order``.

        Args:
            x (Tensor): Input tensor.
            activate (bool): Set False to skip the activation for this call.
            norm (bool): Set False to skip the normalization for this call.

        Returns:
            Tensor: Output of the last applied layer.
        """
        for layer in self.order:
            if layer == "conv":
                if self.with_explicit_padding:
                    x = self.pad(x)
                x = self.conv(x)
            elif layer == "norm" and norm and self.with_norm:
                x = self.norm(x)
            elif layer == "act" and activate and self.with_activation:
                x = self.activate(x)
        return x
|
497 |
+
|
498 |
+
|
499 |
+
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with fixed settings.

    Args:
        scale_factor: Multiplier for the spatial size.
        mode (str): Interpolation algorithm name.
        align_corners (bool): Forwarded to the interpolation call.
            Default: False.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Return ``x`` resampled with the stored scale factor and mode."""
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
|
510 |
+
|
511 |
+
|
512 |
+
class HeadDepth(nn.Module):
    """Small convolutional head mapping ``features`` channels to one channel.

    The stack is: 3x3 conv halving the channels, 2x bilinear upsampling,
    3x3 conv to 32 channels, ReLU, and a final 1x1 conv to a single channel.

    Args:
        features (int): Number of input channels.
    """

    def __init__(self, features):
        super().__init__()
        half = features // 2
        layers = [
            nn.Conv2d(features, half, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(half, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Run the head; the output has one channel and twice the spatial size."""
        return self.head(x)
|
526 |
+
|
527 |
+
|
528 |
+
class ReassembleBlocks(nn.Module):
    """Project ViT stage outputs and bring them to four different scales.

    Each stage output is optionally combined with its class token, passed
    through a 1x1 projection and then through a fixed per-stage resize layer
    (4x transposed conv, 2x transposed conv, identity, stride-2 conv).

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): Output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): How the class token is used: 'ignore', 'add'
            or 'project'. Default: 'ignore'.
        patch_size (int): The patch size (stored only). Default: 16.
    """

    def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
        super().__init__()

        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        # One 1x1 projection per stage, without activation.
        self.projects = nn.ModuleList()
        for width in out_channels:
            self.projects.append(
                ConvModule(in_channels=in_channels, out_channels=width, kernel_size=1, act_layer=None)
            )

        upsample_x4 = nn.ConvTranspose2d(
            in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
        )
        upsample_x2 = nn.ConvTranspose2d(
            in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
        )
        keep = nn.Identity()
        downsample_x2 = nn.Conv2d(
            in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
        )
        self.resize_layers = nn.ModuleList([upsample_x4, upsample_x2, keep, downsample_x2])

        if self.readout_type == "project":
            # Each readout maps [feature ; class token] back to in_channels.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))

    def forward(self, inputs):
        """Process every stage output.

        Args:
            inputs (list): One ``(features, cls_token)`` pair per stage.

        Returns:
            list[Tensor]: The projected and resized feature map of each stage.
        """
        assert isinstance(inputs, list)
        results = []
        for idx, pair in enumerate(inputs):
            assert len(pair) == 2
            feat, cls_token = pair[0], pair[1]
            original_shape = feat.shape
            if self.readout_type == "project":
                # Flatten the spatial dims to tokens, fuse with the class token, restore the map.
                tokens = feat.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(tokens)
                tokens = self.readout_projects[idx](torch.cat((tokens, readout), -1))
                feat = tokens.permute(0, 2, 1).reshape(original_shape)
            elif self.readout_type == "add":
                feat = (feat.flatten(2) + cls_token.unsqueeze(-1)).reshape(original_shape)
            feat = self.projects[idx](feat)
            feat = self.resize_layers[idx](feat)
            results.append(feat)
        return results
|
598 |
+
|
599 |
+
|
600 |
+
class PreActResidualConvUnit(nn.Module):
    """Pre-activation residual unit: two act-conv-norm blocks plus a skip.

    Args:
        in_channels (int): Number of channels in the input feature map.
        act_layer (nn.Module): Activation layer.
        norm_layer (nn.Module): Norm layer.
        stride (int): Stride of the first block. Default: 1.
        dilation (int): Dilation (and padding) of the first block. Default: 1.
    """

    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
        super().__init__()

        # Settings shared by both 3x3 blocks.
        common = dict(
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            **common,
        )
        self.conv2 = ConvModule(in_channels, in_channels, 3, padding=1, **common)

    def forward(self, inputs):
        """Return ``conv2(conv1(inputs)) + inputs``."""
        # Copy first: the blocks start with an activation that may run in place.
        shortcut = inputs.clone()
        residual = self.conv2(self.conv1(inputs))
        return residual + shortcut
|
642 |
+
|
643 |
+
|
644 |
+
class FeatureFusionBlock(nn.Module):
    """Merge a feature map with an optional skip feature and upsample by 2.

    Args:
        in_channels (int): Input channels.
        act_layer (nn.Module): Activation layer for the residual units.
        norm_layer (nn.Module): Normalization layer.
        expand (bool): If True the output projection halves the channels.
            Default: False.
        align_corners (bool): ``align_corners`` setting for the final
            bilinear upsample. Default: True.
    """

    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
        super().__init__()

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners
        self.out_channels = in_channels // 2 if self.expand else in_channels

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)

        unit_kwargs = dict(in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer)
        self.res_conv_unit1 = PreActResidualConvUnit(**unit_kwargs)
        self.res_conv_unit2 = PreActResidualConvUnit(**unit_kwargs)

    def forward(self, *inputs):
        """Fuse the inputs.

        Args:
            *inputs: The main feature map, optionally followed by one skip
                feature map that is added after a residual unit.

        Returns:
            Tensor: The fused map, upsampled by 2 and projected.
        """
        fused = inputs[0]
        if len(inputs) == 2:
            skip = inputs[1]
            if fused.shape != skip.shape:
                # Match the skip feature to the main map's spatial size.
                skip = resize(skip, size=(fused.shape[2], fused.shape[3]), mode="bilinear", align_corners=False)
            fused = fused + self.res_conv_unit1(skip)
        fused = self.res_conv_unit2(fused)
        fused = resize(fused, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        return self.project(fused)
|
688 |
+
|
689 |
+
|
690 |
+
class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. Default: False.
        **kwargs: Forwarded to ``DepthBaseDecodeHead``.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.in_channels = self.in_channels
        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        # NOTE(review): math.pow returns a float, so with expand_channels=True
        # the conv input channel counts below are floats, and they no longer
        # match what reassemble_blocks outputs -- confirm the intended behavior.
        if expand_channels:
            self.post_process_channels = [
                channel * math.pow(2, i) for i, channel in enumerate(post_process_channels)
            ]
        else:
            self.post_process_channels = list(post_process_channels)

        # One 3x3 conv per stage bringing every scale to self.channels.
        self.convs = nn.ModuleList(
            [
                ConvModule(width, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)
                for width in self.post_process_channels
            ]
        )
        self.fusion_blocks = nn.ModuleList(
            [FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer) for _ in range(len(self.convs))]
        )
        # The first fusion block is only ever called with a single input.
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)

        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels
        self.conv_depth = HeadDepth(self.channels)

    def forward(self, inputs, img_metas):
        """Decode backbone stage outputs into a depth prediction.

        Args:
            inputs: One entry per reassemble stage.
            img_metas: Unused by this method.

        Returns:
            The result of ``self.depth_pred`` on the fused feature map.
        """
        assert len(inputs) == self.num_reassemble_blocks
        feats = self.reassemble_blocks(list(inputs))
        feats = [self.convs[idx](feat) for idx, feat in enumerate(feats)]

        # Fuse from the last stage towards the first.
        blocks = self.fusion_blocks
        out = blocks[0](feats[-1])
        for depth in range(1, len(blocks)):
            out = blocks[depth](out, feats[-(depth + 1)])

        out = self.project(out)
        return self.depth_pred(out)
|
core/encoders/dinov2/hub/depth/encoder_decoder.py
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from .ops import resize
|
13 |
+
|
14 |
+
|
15 |
+
def add_prefix(inputs, prefix):
    """Return a copy of ``inputs`` whose keys are prefixed with ``prefix.``.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: A new dict mapping ``"{prefix}.{key}"`` to the original values.
    """
    return {f"{prefix}.{name}": value for name, value in inputs.items()}
|
32 |
+
|
33 |
+
|
34 |
+
class DepthEncoderDecoder(nn.Module):
|
35 |
+
"""Encoder Decoder depther.
|
36 |
+
|
37 |
+
EncoderDecoder typically consists of backbone and decode_head.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, backbone, decode_head):
|
41 |
+
super(DepthEncoderDecoder, self).__init__()
|
42 |
+
|
43 |
+
self.backbone = backbone
|
44 |
+
self.decode_head = decode_head
|
45 |
+
self.align_corners = self.decode_head.align_corners
|
46 |
+
|
47 |
+
def extract_feat(self, img):
|
48 |
+
"""Extract features from images."""
|
49 |
+
return self.backbone(img)
|
50 |
+
|
51 |
+
def encode_decode(self, img, img_metas, rescale=True, size=None):
|
52 |
+
"""Encode images with backbone and decode into a depth estimation
|
53 |
+
map of the same size as input."""
|
54 |
+
x = self.extract_feat(img)
|
55 |
+
out = self._decode_head_forward_test(x, img_metas)
|
56 |
+
# crop the pred depth to the certain range.
|
57 |
+
out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
|
58 |
+
if rescale:
|
59 |
+
if size is None:
|
60 |
+
if img_metas is not None:
|
61 |
+
size = img_metas[0]["ori_shape"][:2]
|
62 |
+
else:
|
63 |
+
size = img.shape[2:]
|
64 |
+
out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
|
65 |
+
return out
|
66 |
+
|
67 |
+
def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
    """Compute the decode head's training losses, keyed with a ``decode.`` prefix.

    Args:
        img: Input images, passed through to the head.
        x: Backbone features.
        img_metas: Per-image meta dicts.
        depth_gt: Ground-truth depth.
        **kwargs: Extra arguments forwarded to the head's ``forward_train``.

    Returns:
        dict: The head's loss dict with every key prefixed by ``decode.``.
    """
    head_losses = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
    return dict(add_prefix(head_losses, "decode"))
|
74 |
+
|
75 |
+
def _decode_head_forward_test(self, x, img_metas):
|
76 |
+
"""Run forward function and calculate loss for decode head in
|
77 |
+
inference."""
|
78 |
+
depth_pred = self.decode_head.forward_test(x, img_metas)
|
79 |
+
return depth_pred
|
80 |
+
|
81 |
+
def forward_dummy(self, img):
|
82 |
+
"""Dummy forward function."""
|
83 |
+
depth = self.encode_decode(img, None)
|
84 |
+
|
85 |
+
return depth
|
86 |
+
|
87 |
+
def forward_train(self, img, img_metas, depth_gt, **kwargs):
|
88 |
+
"""Forward function for training.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
img (Tensor): Input images.
|
92 |
+
img_metas (list[dict]): List of image info dict where each dict
|
93 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
94 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
95 |
+
For details on the values of these keys see
|
96 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
97 |
+
depth_gt (Tensor): Depth gt
|
98 |
+
used if the architecture supports depth estimation task.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
dict[str, Tensor]: a dictionary of loss components
|
102 |
+
"""
|
103 |
+
|
104 |
+
x = self.extract_feat(img)
|
105 |
+
|
106 |
+
losses = dict()
|
107 |
+
|
108 |
+
# the last of x saves the info from neck
|
109 |
+
loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
|
110 |
+
|
111 |
+
losses.update(loss_decode)
|
112 |
+
|
113 |
+
return losses
|
114 |
+
|
115 |
+
def whole_inference(self, img, img_meta, rescale, size=None):
|
116 |
+
"""Inference with full image."""
|
117 |
+
return self.encode_decode(img, img_meta, rescale, size=size)
|
118 |
+
|
119 |
+
def slide_inference(self, img, img_meta, rescale, stride, crop_size):
|
120 |
+
"""Inference by sliding-window with overlap.
|
121 |
+
|
122 |
+
If h_crop > h_img or w_crop > w_img, the small patch will be used to
|
123 |
+
decode without padding.
|
124 |
+
"""
|
125 |
+
|
126 |
+
h_stride, w_stride = stride
|
127 |
+
h_crop, w_crop = crop_size
|
128 |
+
batch_size, _, h_img, w_img = img.size()
|
129 |
+
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
130 |
+
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
131 |
+
preds = img.new_zeros((batch_size, 1, h_img, w_img))
|
132 |
+
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
|
133 |
+
for h_idx in range(h_grids):
|
134 |
+
for w_idx in range(w_grids):
|
135 |
+
y1 = h_idx * h_stride
|
136 |
+
x1 = w_idx * w_stride
|
137 |
+
y2 = min(y1 + h_crop, h_img)
|
138 |
+
x2 = min(x1 + w_crop, w_img)
|
139 |
+
y1 = max(y2 - h_crop, 0)
|
140 |
+
x1 = max(x2 - w_crop, 0)
|
141 |
+
crop_img = img[:, :, y1:y2, x1:x2]
|
142 |
+
depth_pred = self.encode_decode(crop_img, img_meta, rescale)
|
143 |
+
preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
|
144 |
+
|
145 |
+
count_mat[:, :, y1:y2, x1:x2] += 1
|
146 |
+
assert (count_mat == 0).sum() == 0
|
147 |
+
if torch.onnx.is_in_onnx_export():
|
148 |
+
# cast count_mat to constant while exporting to ONNX
|
149 |
+
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
|
150 |
+
preds = preds / count_mat
|
151 |
+
return preds
|
152 |
+
|
153 |
+
def inference(self, img, img_meta, rescale, size=None, mode="whole"):
|
154 |
+
"""Inference with slide/whole style.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
img (Tensor): The input image of shape (N, 3, H, W).
|
158 |
+
img_meta (dict): Image info dict where each dict has: 'img_shape',
|
159 |
+
'scale_factor', 'flip', and may also contain
|
160 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
161 |
+
For details on the values of these keys see
|
162 |
+
`depth/datasets/pipelines/formatting.py:Collect`.
|
163 |
+
rescale (bool): Whether rescale back to original shape.
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
Tensor: The output depth map.
|
167 |
+
"""
|
168 |
+
|
169 |
+
assert mode in ["slide", "whole"]
|
170 |
+
ori_shape = img_meta[0]["ori_shape"]
|
171 |
+
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
|
172 |
+
if mode == "slide":
|
173 |
+
depth_pred = self.slide_inference(img, img_meta, rescale)
|
174 |
+
else:
|
175 |
+
depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
|
176 |
+
output = depth_pred
|
177 |
+
flip = img_meta[0]["flip"]
|
178 |
+
if flip:
|
179 |
+
flip_direction = img_meta[0]["flip_direction"]
|
180 |
+
assert flip_direction in ["horizontal", "vertical"]
|
181 |
+
if flip_direction == "horizontal":
|
182 |
+
output = output.flip(dims=(3,))
|
183 |
+
elif flip_direction == "vertical":
|
184 |
+
output = output.flip(dims=(2,))
|
185 |
+
|
186 |
+
return output
|
187 |
+
|
188 |
+
def simple_test(self, img, img_meta, rescale=True):
|
189 |
+
"""Simple test with single image."""
|
190 |
+
depth_pred = self.inference(img, img_meta, rescale)
|
191 |
+
if torch.onnx.is_in_onnx_export():
|
192 |
+
# our inference backend only support 4D output
|
193 |
+
depth_pred = depth_pred.unsqueeze(0)
|
194 |
+
return depth_pred
|
195 |
+
depth_pred = depth_pred.cpu().numpy()
|
196 |
+
# unravel batch dim
|
197 |
+
depth_pred = list(depth_pred)
|
198 |
+
return depth_pred
|
199 |
+
|
200 |
+
def aug_test(self, imgs, img_metas, rescale=True):
|
201 |
+
"""Test with augmentations.
|
202 |
+
|
203 |
+
Only rescale=True is supported.
|
204 |
+
"""
|
205 |
+
# aug_test rescale all imgs back to ori_shape for now
|
206 |
+
assert rescale
|
207 |
+
# to save memory, we get augmented depth logit inplace
|
208 |
+
depth_pred = self.inference(imgs[0], img_metas[0], rescale)
|
209 |
+
for i in range(1, len(imgs)):
|
210 |
+
cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
|
211 |
+
depth_pred += cur_depth_pred
|
212 |
+
depth_pred /= len(imgs)
|
213 |
+
depth_pred = depth_pred.cpu().numpy()
|
214 |
+
# unravel batch dim
|
215 |
+
depth_pred = list(depth_pred)
|
216 |
+
return depth_pred
|
217 |
+
|
218 |
+
def forward_test(self, imgs, img_metas, **kwargs):
|
219 |
+
"""
|
220 |
+
Args:
|
221 |
+
imgs (List[Tensor]): the outer list indicates test-time
|
222 |
+
augmentations and inner Tensor should have a shape NxCxHxW,
|
223 |
+
which contains all images in the batch.
|
224 |
+
img_metas (List[List[dict]]): the outer list indicates test-time
|
225 |
+
augs (multiscale, flip, etc.) and the inner list indicates
|
226 |
+
images in a batch.
|
227 |
+
"""
|
228 |
+
for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
|
229 |
+
if not isinstance(var, list):
|
230 |
+
raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
|
231 |
+
num_augs = len(imgs)
|
232 |
+
if num_augs != len(img_metas):
|
233 |
+
raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
|
234 |
+
# all images in the same aug batch all of the same ori_shape and pad
|
235 |
+
# shape
|
236 |
+
for img_meta in img_metas:
|
237 |
+
ori_shapes = [_["ori_shape"] for _ in img_meta]
|
238 |
+
assert all(shape == ori_shapes[0] for shape in ori_shapes)
|
239 |
+
img_shapes = [_["img_shape"] for _ in img_meta]
|
240 |
+
assert all(shape == img_shapes[0] for shape in img_shapes)
|
241 |
+
pad_shapes = [_["pad_shape"] for _ in img_meta]
|
242 |
+
assert all(shape == pad_shapes[0] for shape in pad_shapes)
|
243 |
+
|
244 |
+
if num_augs == 1:
|
245 |
+
return self.simple_test(imgs[0], img_metas[0], **kwargs)
|
246 |
+
else:
|
247 |
+
return self.aug_test(imgs, img_metas, **kwargs)
|
248 |
+
|
249 |
+
def forward(self, img, img_metas, return_loss=True, **kwargs):
    """Dispatch to :func:`forward_train` or :func:`forward_test`.

    The value of ``return_loss`` selects the branch and also changes the
    expected inputs. When ``return_loss=True``, img and img_metas are
    single-nested (i.e. Tensor and List[dict]); when ``return_loss=False``,
    they should be double nested (i.e. List[Tensor], List[List[dict]]),
    with the outer list indicating test time augmentations.
    """
    handler = self.forward_train if return_loss else self.forward_test
    return handler(img, img_metas, **kwargs)
|
263 |
+
|
264 |
+
def train_step(self, data_batch, optimizer, **kwargs):
    """Run one training iteration and package its losses for the runner.

    Back propagation and the optimizer update are not performed here.

    Args:
        data_batch (dict): The output of dataloader; it is unpacked as
            keyword arguments when calling the model and must contain
            the key ``img_metas``.
        optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
            runner is passed to ``train_step()``. This argument is unused
            and reserved.

    Returns:
        dict: with the keys ``loss`` (tensor for back propagation),
        ``log_vars`` (variables to be sent to the logger),
        ``num_samples`` (``len(data_batch["img_metas"])``, used for
        averaging the logs) and ``log_imgs`` (every model output whose
        key contains ``"img"``).
    """
    model_outputs = self(**data_batch)

    # Outputs whose key contains "img" are images to log, the rest are losses.
    log_imgs = {key: value for key, value in model_outputs.items() if "img" in key}
    real_losses = {key: value for key, value in model_outputs.items() if "img" not in key}

    loss, log_vars = self._parse_losses(real_losses)

    return {
        "loss": loss,
        "log_vars": log_vars,
        "num_samples": len(data_batch["img_metas"]),
        "log_imgs": log_imgs,
    }
|
306 |
+
|
307 |
+
def val_step(self, data_batch, **kwargs):
    """Run one validation iteration.

    The model is called with ``data_batch`` and ``kwargs`` unpacked as
    keyword arguments and its output is returned unchanged. Note that the
    evaluation after training epochs is not implemented with this method,
    but an evaluation hook.
    """
    return self(**data_batch, **kwargs)
|
316 |
+
|
317 |
+
@staticmethod
|
318 |
+
def _parse_losses(losses):
|
319 |
+
import torch.distributed as dist
|
320 |
+
|
321 |
+
"""Parse the raw outputs (losses) of the network.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
losses (dict): Raw output of the network, which usually contain
|
325 |
+
losses and other necessary information.
|
326 |
+
|
327 |
+
Returns:
|
328 |
+
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
|
329 |
+
which may be a weighted sum of all losses, log_vars contains
|
330 |
+
all the variables to be sent to the logger.
|
331 |
+
"""
|
332 |
+
log_vars = OrderedDict()
|
333 |
+
for loss_name, loss_value in losses.items():
|
334 |
+
if isinstance(loss_value, torch.Tensor):
|
335 |
+
log_vars[loss_name] = loss_value.mean()
|
336 |
+
elif isinstance(loss_value, list):
|
337 |
+
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
|
338 |
+
else:
|
339 |
+
raise TypeError(f"{loss_name} is not a tensor or list of tensors")
|
340 |
+
|
341 |
+
loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
|
342 |
+
|
343 |
+
log_vars["loss"] = loss
|
344 |
+
for loss_name, loss_value in log_vars.items():
|
345 |
+
# reduce loss when distributed training
|
346 |
+
if dist.is_available() and dist.is_initialized():
|
347 |
+
loss_value = loss_value.data.clone()
|
348 |
+
dist.all_reduce(loss_value.div_(dist.get_world_size()))
|
349 |
+
log_vars[loss_name] = loss_value.item()
|
350 |
+
|
351 |
+
return loss, log_vars
|
core/encoders/dinov2/hub/depth/ops.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
    """Resize ``input`` with ``F.interpolate``, optionally warning about alignment.

    Args:
        input: tensor to resize; its dimensions from index 2 on are read as
            (height, width) when the warning check runs.
        size: target (height, width), forwarded to ``F.interpolate``.
        scale_factor: forwarded to ``F.interpolate``.
        mode: interpolation mode, forwarded to ``F.interpolate``.
        align_corners: forwarded to ``F.interpolate``.
        warning: when true, warn if ``align_corners`` is set and the input
            is upsampled to a size that is not of the form ``n * (in - 1) + 1``
            along both axes.

    Returns:
        The interpolated tensor.
    """
    if warning and size is not None and align_corners:
        input_h, input_w = tuple(int(x) for x in input.shape[2:])
        output_h, output_w = tuple(int(x) for x in size)
        # Upsampling along either axis. The width must be compared against
        # the *input* width, not the output height.
        if output_h > input_h or output_w > input_w:
            if (
                (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                and (output_h - 1) % (input_h - 1)
                and (output_w - 1) % (input_w - 1)
            ):
                warnings.warn(
                    f"When align_corners={align_corners}, "
                    "the output would more aligned if "
                    f"input size {(input_h, input_w)} is `x+1` and "
                    f"out size {(output_h, output_w)} is `nx+1`"
                )
    return F.interpolate(input, size, scale_factor, mode, align_corners)
|
core/encoders/dinov2/hub/depthers.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from functools import partial
|
8 |
+
from typing import Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .backbones import _make_dinov2_model
|
13 |
+
from .depth import BNHead, DepthEncoderDecoder, DPTHead
|
14 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
|
15 |
+
|
16 |
+
|
17 |
+
class Weights(Enum):
    """Identifiers of the depth-head weight sets selectable by the factories below.

    The lower-cased ``value`` is embedded in the checkpoint file name when
    pretrained heads are downloaded.
    """

    NYU = "NYU"
    KITTI = "KITTI"
|
20 |
+
|
21 |
+
|
22 |
+
def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
    """Return the ``(min_depth, max_depth)`` pair to use for a depth head.

    Pretrained KITTI weights use ``(0.001, 80.0)``; every other case
    (not pretrained, NYU weights, or any other value) uses the default
    ``(0.001, 10.0)``.
    """
    # The range only departs from the default for pretrained KITTI weights.
    if pretrained and weights == Weights.KITTI:
        return (0.001, 80.0)
    return (0.001, 10.0)
|
34 |
+
|
35 |
+
|
36 |
+
def _make_dinov2_linear_depth_head(
    *,
    embed_dim: int,
    layers: int,
    min_depth: float,
    max_depth: float,
    **kwargs,
):
    """Build the ``BNHead`` used as linear depth decoder.

    Args:
        embed_dim: channel count of each backbone feature map.
        layers: number of backbone feature maps fed to the head (1 or 4).
        min_depth: lower bound of the depth range passed to the head.
        max_depth: upper bound of the depth range passed to the head.
        **kwargs: accepted for call compatibility; not used.

    Returns:
        The configured ``BNHead`` instance.

    Raises:
        AssertionError: if ``layers`` is neither 1 nor 4.
    """
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")

    # One input index per backbone feature map: [0] or [0, 1, 2, 3].
    in_index = list(range(layers))

    return BNHead(
        classify=True,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        upsample=4,
        in_channels=[embed_dim] * len(in_index),
        in_index=in_index,
        input_transform="resize_concat",
        channels=embed_dim * len(in_index) * 2,
        align_corners=False,
        # Honour the caller-supplied range instead of a hard-coded 0.001/80,
        # so the range chosen for the selected weights reaches the head.
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )
|
68 |
+
|
69 |
+
|
70 |
+
def _make_dinov2_linear_depther(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a DINOv2 backbone with a linear depth head.

    Args:
        arch_name: backbone architecture; one of ``vit_small``, ``vit_base``,
            ``vit_large`` or ``vit_giant2``.
        layers: number of intermediate backbone layers fed to the head (1 or 4).
        pretrained: whether to load pretrained backbone and head weights.
        weights: weight set, as a ``Weights`` member or its name.
        depth_range: ``(min_depth, max_depth)``; derived from ``pretrained``
            and ``weights`` when omitted.
        **kwargs: forwarded to ``_make_dinov2_model``.

    Returns:
        A ``DepthEncoderDecoder`` whose backbone forward returns the selected
        intermediate layers.

    Raises:
        AssertionError: if ``layers`` or ``weights`` is unsupported.
    """
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_depth_head = _make_dinov2_linear_depth_head(
        embed_dim=embed_dim,
        layers=layers,
        min_depth=min_depth,
        max_depth=max_depth,
    )

    layer_count = {
        "vit_small": 12,
        "vit_base": 12,
        "vit_large": 24,
        "vit_giant2": 40,
    }[arch_name]

    if layers == 4:
        out_index = {
            "vit_small": [2, 5, 8, 11],
            "vit_base": [2, 5, 8, 11],
            "vit_large": [4, 11, 17, 23],
            "vit_giant2": [9, 19, 29, 39],
        }[arch_name]
    else:
        # Single-layer head: use only the last backbone layer.
        out_index = [layer_count - 1]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
    # Replace the backbone forward so it yields the chosen intermediate layers.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Pad inputs to a multiple of the patch size before they reach the backbone.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

    if pretrained:
        layers_str = str(layers) if layers == 4 else ""
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Accept both wrapped checkpoints ({"state_dict": ...}) and bare state
        # dicts; previously the bare form left ``state_dict`` unbound.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        model.load_state_dict(state_dict, strict=False)

    return model
|
141 |
+
|
142 |
+
|
143 |
+
def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Build a linear-head depther on the ``vit_small`` backbone."""
    return _make_dinov2_linear_depther(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
147 |
+
|
148 |
+
|
149 |
+
def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Build a linear-head depther on the ``vit_base`` backbone."""
    return _make_dinov2_linear_depther(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
153 |
+
|
154 |
+
|
155 |
+
def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Build a linear-head depther on the ``vit_large`` backbone."""
    return _make_dinov2_linear_depther(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
159 |
+
|
160 |
+
|
161 |
+
def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Build a linear-head depther on the ``vit_giant2`` backbone.

    The backbone is always requested with ``ffn_layer="swiglufused"``.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_giant2",
        layers=layers,
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
165 |
+
|
166 |
+
|
167 |
+
def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
    """Build the ``DPTHead`` depth decoder for a backbone of width ``embed_dim``.

    Four input feature maps of ``embed_dim`` channels are expected; the
    post-process channel counts are ``embed_dim`` divided by 8, 4, 2 and 1.
    """
    post_process_channels = [embed_dim // divisor for divisor in (8, 4, 2, 1)]
    return DPTHead(
        in_channels=[embed_dim] * 4,
        channels=256,
        embed_dims=embed_dim,
        post_process_channels=post_process_channels,
        readout_type="project",
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )
|
178 |
+
|
179 |
+
|
180 |
+
def _make_dinov2_dpt_depther(
    *,
    arch_name: str = "vit_large",
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a DINOv2 backbone with a DPT depth head.

    Args:
        arch_name: backbone architecture; one of ``vit_small``, ``vit_base``,
            ``vit_large`` or ``vit_giant2``.
        pretrained: whether to load pretrained backbone and head weights.
        weights: weight set, as a ``Weights`` member or its name.
        depth_range: ``(min_depth, max_depth)``; derived from ``pretrained``
            and ``weights`` when omitted.
        **kwargs: forwarded to ``_make_dinov2_model``.

    Returns:
        A ``DepthEncoderDecoder`` whose backbone forward returns four
        intermediate layers.

    Raises:
        AssertionError: if ``weights`` is an unknown name.
    """
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

    out_index = {
        "vit_small": [2, 5, 8, 11],
        "vit_base": [2, 5, 8, 11],
        "vit_large": [4, 11, 17, 23],
        "vit_giant2": [9, 19, 29, 39],
    }[arch_name]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
    # Replace the backbone forward so it yields the chosen intermediate layers.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Pad inputs to a multiple of the patch size before they reach the backbone.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))

    if pretrained:
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Accept both wrapped checkpoints ({"state_dict": ...}) and bare state
        # dicts; previously the bare form left ``state_dict`` unbound.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        model.load_state_dict(state_dict, strict=False)

    return model
|
229 |
+
|
230 |
+
|
231 |
+
def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Build a DPT-head depther on the ``vit_small`` backbone."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
233 |
+
|
234 |
+
|
235 |
+
def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Build a DPT-head depther on the ``vit_base`` backbone."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
237 |
+
|
238 |
+
|
239 |
+
def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Build a DPT-head depther on the ``vit_large`` backbone."""
    return _make_dinov2_dpt_depther(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
241 |
+
|
242 |
+
|
243 |
+
def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Build a DPT-head depther on the ``vit_giant2`` backbone.

    The backbone is always requested with ``ffn_layer="swiglufused"``.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
core/encoders/dinov2/hub/utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import itertools
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
# Base URL used as the prefix when building checkpoint download URLs.
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
15 |
+
|
16 |
+
|
17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
18 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
19 |
+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
20 |
+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
21 |
+
|
22 |
+
|
23 |
+
class CenterPadding(nn.Module):
    """Pad every dimension after the first two up to a multiple of ``multiple``.

    The padding of each dimension is split between both sides, with the
    extra element (if any) going to the trailing side.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return the (leading, trailing) padding that rounds ``size`` up."""
        padded_size = math.ceil(size / self.multiple) * self.multiple
        total_pad = padded_size - size
        leading = total_pad // 2
        return leading, total_pad - leading

    @torch.inference_mode()
    def forward(self, x):
        # F.pad expects the pads of the last dimension first, so walk the
        # trailing dimensions from last to first.
        pads = []
        for extent in reversed(x.shape[2:]):
            pads.extend(self._get_pad(extent))
        return F.pad(x, pads)
|
core/encoders/dinov2/layers/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# ******************************************************************************
|
7 |
+
# Code modified by Zexin He in 2023-2024.
|
8 |
+
# Modifications are marked with clearly visible comments
|
9 |
+
# licensed under the Apache License, Version 2.0.
|
10 |
+
# ******************************************************************************
|
11 |
+
|
12 |
+
from .dino_head import DINOHead
|
13 |
+
from .mlp import Mlp
|
14 |
+
from .patch_embed import PatchEmbed
|
15 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
16 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
17 |
+
# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
|
18 |
+
from .block import Block, BlockWithModulation
|
19 |
+
# ********************************************************
|
20 |
+
from .attention import MemEffAttention
|
core/encoders/dinov2/layers/attention.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
from torch import Tensor
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
|
18 |
+
# Module-level logger registered under the "dinov2" name.
logger = logging.getLogger("dinov2")


# xFormers is attempted unless the XFORMERS_DISABLED environment variable is
# set (to any value, including an empty string).
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        # Reuse the handler below to mark xFormers as unavailable.
        raise ImportError
except ImportError:
    # Reached both when the import fails and when xFormers was disabled above.
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")
|
34 |
+
|
35 |
+
|
36 |
+
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        dim: channel count of the input and output tokens.
        num_heads: number of attention heads; each head gets
            ``dim // num_heads`` channels.
        qkv_bias: whether the fused QKV projection has a bias.
        proj_bias: whether the output projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Queries are scaled by 1 / sqrt(head_dim) before the dot product.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention to ``x`` of shape (batch, tokens, channels)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (batch, tokens, 3 * C) -> three tensors of (batch, heads, tokens, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
70 |
+
|
71 |
+
|
72 |
+
class MemEffAttention(Attention):
    """Attention that uses the xFormers memory-efficient kernel when available."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Apply attention to ``x``, optionally with an xFormers ``attn_bias``.

        Without xFormers the parent implementation is used, and passing an
        ``attn_bias`` raises ``AssertionError``.
        """
        if XFORMERS_AVAILABLE:
            B, N, C = x.shape
            fused = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
            query, key, value = unbind(fused, 2)

            attended = memory_efficient_attention(query, key, value, attn_bias=attn_bias)
            attended = attended.reshape([B, N, C])

            return self.proj_drop(self.proj(attended))

        # Fallback path: plain attention cannot honour an attention bias.
        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
|
core/encoders/dinov2/layers/block.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
9 |
+
|
10 |
+
# ******************************************************************************
|
11 |
+
# Code modified by Zexin He in 2023-2024.
|
12 |
+
# Modifications are marked with clearly visible comments
|
13 |
+
# licensed under the Apache License, Version 2.0.
|
14 |
+
# ******************************************************************************
|
15 |
+
|
16 |
+
import logging
|
17 |
+
import os
|
18 |
+
from typing import Callable, List, Any, Tuple, Dict
|
19 |
+
import warnings
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from torch import nn, Tensor
|
23 |
+
|
24 |
+
from .attention import Attention, MemEffAttention
|
25 |
+
from .drop_path import DropPath
|
26 |
+
from .layer_scale import LayerScale
|
27 |
+
from .mlp import Mlp
|
28 |
+
|
29 |
+
|
30 |
+
# Module-level logger registered under the "dinov2" name.
logger = logging.getLogger("dinov2")


# xFormers is attempted unless the XFORMERS_DISABLED environment variable is
# set (to any value, including an empty string).
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Block)")
    else:
        warnings.warn("xFormers is disabled (Block)")
        # Reuse the handler below to mark xFormers as unavailable.
        raise ImportError
except ImportError:
    # Reached both when the import fails and when xFormers was disabled above.
    XFORMERS_AVAILABLE = False

    warnings.warn("xFormers is not available (Block)")
|
47 |
+
|
48 |
+
|
49 |
+
class Block(nn.Module):
    """Transformer block: pre-norm attention and FFN branches with residuals.

    Args:
        dim: channel count of the tokens.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the FFN as a multiple of ``dim``.
        qkv_bias: whether the attention QKV projection has a bias.
        proj_bias: whether the attention output projection has a bias.
        ffn_bias: whether the FFN layers have a bias.
        drop: dropout probability for the attention projection and the FFN.
        attn_drop: dropout probability for the attention weights.
        init_values: initial value for ``LayerScale``; a falsy value
            disables layer scaling.
        drop_path: stochastic-depth rate; 0 disables it.
        act_layer: activation constructor passed to the FFN.
        norm_layer: normalization constructor, called with ``dim``.
        attn_class: attention module constructor.
        ffn_layer: FFN module constructor.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention and FFN residual branches to ``x``."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch uses its own drop-path module (was drop_path1).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
121 |
+
|
122 |
+
|
123 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
124 |
+
# Override forward with modulation input
|
125 |
+
class BlockWithModulation(Block):
    """``Block`` variant whose normalization layers also receive a modulation input.

    Both norm layers are called as ``norm(x, mod)``, so the ``norm_layer``
    given at construction must accept that second argument.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, mod: Tensor) -> Tensor:
        """Apply the attention and FFN residual branches, modulated by ``mod``.

        Raises:
            NotImplementedError: in training mode with a drop path ratio
                larger than 0.1.
        """

        def attn_residual_func(x: Tensor, mod: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x, mod)))

        def ffn_residual_func(x: Tensor, mod: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x, mod)))

        if self.training and self.sample_drop_ratio > 0.1:
            raise NotImplementedError("Modulation with drop path ratio larger than 0.1 is not supported yet")
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, mod))
            # The FFN branch uses its own drop-path module (was drop_path1).
            x = x + self.drop_path2(ffn_residual_func(x, mod))
        else:
            x = x + attn_residual_func(x, mod)
            x = x + ffn_residual_func(x, mod)
        return x
|
145 |
+
# ********************************************************
|
146 |
+
|
147 |
+
|
148 |
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch.

    A random subset of samples (at least one) is passed through
    ``residual_func``; the result is scaled by ``batch / subset_size`` and
    added back to the matching rows of ``x``. Other rows are unchanged.

    Args:
        x: 3-D input tensor whose first dimension is the batch.
        residual_func: computes the residual for the selected samples.
        sample_drop_ratio: fraction of the batch to leave out.

    Returns:
        A tensor shaped like ``x``.
    """
    batch_size, _, _ = x.shape
    keep_count = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch_size, device=x.device)[:keep_count]

    # Residual of the kept samples only, flattened to (keep_count, -1).
    residual = residual_func(x[kept_rows]).flatten(1)

    # Scale up so the expected residual matches the no-drop case.
    scale = batch_size / keep_count

    summed = torch.index_add(x.flatten(1), 0, kept_rows, residual.to(dtype=x.dtype), alpha=scale)
    return summed.view_as(x)
|
170 |
+
|
171 |
+
|
172 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick a random subset of batch indices and its residual scale factor.

    Returns:
        A tuple ``(indices, scale)`` where ``indices`` holds at least one
        randomly chosen batch index of the 3-D tensor ``x`` and ``scale``
        is ``batch / len(indices)``.
    """
    batch_size, _, _ = x.shape
    keep_count = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch_size, device=x.device)[:keep_count]
    return kept_rows, batch_size / keep_count
|
178 |
+
|
179 |
+
|
180 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a scaled residual to the rows of ``x`` selected by ``brange``.

    Without ``scaling_vector`` both tensors are flattened from dimension 1
    and combined with ``torch.index_add`` (the result keeps that flattened
    shape). With a ``scaling_vector`` the xFormers ``scaled_index_add``
    kernel is used on the unflattened tensors.
    """
    if scaling_vector is not None:
        return scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )

    flat_residual = residual.flatten(1)
    return torch.index_add(
        x.flatten(1), 0, brange, flat_residual.to(dtype=x.dtype), alpha=residual_scale_factor
    )
|
190 |
+
|
191 |
+
|
192 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
193 |
+
|
194 |
+
|
195 |
+
def get_attn_bias_and_cat(x_list, branges=None):
    """
    Concatenate the tensors (index-selecting rows first when ``branges`` is
    given) and return them together with the attention bias for that layout,
    which is built once per layout and then served from ``attn_bias_cache``.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [b.shape[0] for b in branges]

    # The cache key is the (batch size, x.shape[1]) pair of every input.
    layout_key = tuple((bs, x.shape[1]) for bs, x in zip(batch_sizes, x_list))
    if layout_key not in attn_bias_cache:
        # one sequence-length entry per sample of every input tensor
        seqlens = [x.shape[1] for bs, x in zip(batch_sizes, x_list) for _ in range(bs)]
        new_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        new_bias._batch_sizes = batch_sizes
        attn_bias_cache[layout_key] = new_bias

    if branges is None:
        # fold each tensor's first two dims into a single sequence of batch size 1
        folded = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(folded, dim=1)
    else:
        flattened = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flattened, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[layout_key], cat_tensors
|
217 |
+
|
218 |
+
|
219 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """List version of the sample-dropping residual add.

    For every tensor in ``x_list`` a random row subset is drawn, the subsets
    are concatenated and run through ``residual_func`` in a single call, and
    each resulting residual is added back onto its source tensor.

    Args:
        x_list: Input tensors.
        residual_func: Called as ``residual_func(x_cat, attn_bias=attn_bias)``.
        sample_drop_ratio: Fraction of rows left out per tensor.
        scaling_vector: Optional scaling forwarded to ``add_residual``.

    Returns:
        A list with one output per input tensor, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    sampled = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [rows for rows, _ in sampled]
    residual_scale_factors = [scale for _, scale in sampled]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) add each residual back onto its own tensor and restore its shape
    return [
        add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)
        for x, brange, residual, residual_scale_factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
|
240 |
+
|
241 |
+
|
242 |
+
class NestedTensorBlock(Block):
    """Block that accepts either a single tensor or a list of tensors.

    A ``Tensor`` input is delegated to ``Block.forward``. A ``list`` input is
    processed by ``forward_nested``, which requires xFormers.
    """

    # ********** Modified by Zexin He in 2023-2024 **********
    # NOTE: this runs once, when the class body is evaluated.
    warnings.warn("NestedTensorBlock is deprecated for now!", DeprecationWarning)
    # ********************************************************

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            # Layer scale is not applied inside the residual functions here;
            # its gamma is handed over as the scaling vector instead.
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # check ls2 itself (the original tested ls1 before reading ls2.gamma)
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on the input type.

        Raises:
            AssertionError: if a list is given while xFormers is unavailable,
                or if the input is neither a ``Tensor`` nor a ``list``.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(
                f"Expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}"
            )
|
core/encoders/dinov2/layers/dino_head.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn.init import trunc_normal_
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
|
12 |
+
class DINOHead(nn.Module):
    """Projection head: an MLP, L2 normalisation, then a weight-normalised linear layer."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        """Build the MLP (output width ``bottleneck_dim``) and the final ``out_dim`` projection."""
        super().__init__()
        depth = max(nlayers, 1)  # at least one layer
        self.mlp = _build_mlp(depth, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # The custom init runs before last_layer exists, so it only touches the MLP.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for ``nn.Linear`` modules; others are left alone."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the MLP, L2-normalise along the last dim, and apply the last layer."""
        hidden = self.mlp(x)
        # a larger epsilon is used for float16 inputs
        if hidden.dtype == torch.float16:
            eps = 1e-6
        else:
            eps = 1e-12
        hidden = nn.functional.normalize(hidden, dim=-1, p=2, eps=eps)
        return self.last_layer(hidden)
|
42 |
+
|
43 |
+
|
44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
45 |
+
if nlayers == 1:
|
46 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
47 |
+
else:
|
48 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
49 |
+
if use_bn:
|
50 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
51 |
+
layers.append(nn.GELU())
|
52 |
+
for _ in range(nlayers - 2):
|
53 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
54 |
+
if use_bn:
|
55 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
56 |
+
layers.append(nn.GELU())
|
57 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
58 |
+
return nn.Sequential(*layers)
|
core/encoders/dinov2/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
9 |
+
|
10 |
+
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole entries along the first dimension of ``x``.

    ``x`` is returned untouched when ``drop_prob`` is 0 or ``training`` is
    false. Otherwise one Bernoulli keep-flag is drawn per index of dim 0,
    and kept entries are divided by the keep probability.
    """
    if drop_prob == 0.0 or not training:
        return x

    keep_prob = 1 - drop_prob
    # one flag per index of dim 0, broadcast over every remaining dim
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask
|
24 |
+
|
25 |
+
|
26 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        """Store the drop probability; ``None`` means "never drop"."""
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using this module's probability and training flag."""
        # The default drop_prob=None used to reach `1 - None` in training mode
        # and raise TypeError; treat it as a rate of 0.0 instead.
        rate = 0.0 if self.drop_prob is None else self.drop_prob
        return drop_path(x, rate, self.training)
|
core/encoders/dinov2/layers/layer_scale.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
7 |
+
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import Tensor
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        """
        Args:
            dim: Length of ``gamma``.
            init_values: Initial value(s) multiplied into a ones vector.
            inplace: When true, ``forward`` modifies its input in place.
        """
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma`` (the same tensor object when ``inplace``)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
core/encoders/dinov2/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
9 |
+
|
10 |
+
|
11 |
+
from typing import Callable, Optional
|
12 |
+
|
13 |
+
from torch import Tensor, nn
|
14 |
+
|
15 |
+
|
16 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear, activation, dropout, Linear, dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """
        Args:
            in_features: Input width.
            hidden_features: Hidden width; falls back to ``in_features`` when falsy.
            out_features: Output width; falls back to ``in_features`` when falsy.
            act_layer: Factory called with no arguments to build the activation.
            drop: Dropout probability, shared by both dropout applications.
            bias: Whether both linear layers carry a bias.
        """
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply fc1, activation, dropout, fc2, dropout in that order."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
core/encoders/dinov2/layers/patch_embed.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
9 |
+
|
10 |
+
from typing import Callable, Optional, Tuple, Union
|
11 |
+
|
12 |
+
from torch import Tensor
|
13 |
+
import torch.nn as nn
|
14 |
+
|
15 |
+
|
16 |
+
def make_2tuple(x):
    """Return ``x`` as a pair: a 2-tuple is passed through, an int is duplicated.

    Any other input trips an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
|
23 |
+
|
24 |
+
|
25 |
+
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: When false, ``forward`` returns (B,H,W,D) instead of (B,N,D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size and stride are both the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Project an image batch to patch tokens.

        Raises:
            AssertionError: if the input height or width is not a multiple of
                the patch size.
        """
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)  # patch-grid size after the projection
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate the FLOPs of the projection, plus the norm when one is configured."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # self.norm is never None (it falls back to nn.Identity), so the norm
        # cost is only counted when a real normalization layer is present.
        if self.norm is not None and not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
core/encoders/dinov2/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
from typing import Callable, Optional
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
from torch import Tensor, nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block built from two linear layers.

    ``w12`` produces both the gate and the value in one projection; ``w3``
    maps the gated hidden state to the output width.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """
        Args:
            in_features: Input width.
            hidden_features: Hidden width; falls back to ``in_features`` when falsy.
            out_features: Output width; falls back to ``in_features`` when falsy.
            act_layer: Accepted but not used by this class.
            drop: Accepted but not used by this class.
            bias: Whether both linear layers carry a bias.
        """
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_width, bias=bias)
        self.w3 = nn.Linear(hidden_width, out_width, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """Compute ``w3(silu(a) * b)`` where ``a, b`` are the two halves of ``w12(x)``."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
35 |
+
|
36 |
+
|
37 |
+
# xFormers is considered enabled unless the XFORMERS_DISABLED environment
# variable is set (to any value).
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ
try:
    if not XFORMERS_ENABLED:
        warnings.warn("xFormers is disabled (SwiGLU)")
        # jump to the fallback below, exactly as for a failed import
        raise ImportError

    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
    warnings.warn("xFormers is available (SwiGLU)")
except ImportError:
    # Fall back to the pure-PyTorch implementation defined in this module.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    warnings.warn("xFormers is not available (SwiGLU)")
|
52 |
+
|
53 |
+
|
54 |
+
class SwiGLUFFNFused(SwiGLU):
    """``SwiGLU`` with the hidden width shrunk to about two thirds and rounded up to a multiple of 8."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """
        Args:
            in_features: Input width.
            hidden_features: Requested hidden width before shrinking; falls
                back to ``in_features`` when falsy.
            out_features: Output width; falls back to ``in_features`` when falsy.
            act_layer: Accepted but not forwarded to the base class.
            drop: Accepted but not forwarded to the base class.
            bias: Forwarded to the base class.
        """
        out_width = out_features or in_features
        requested_hidden = hidden_features or in_features
        # take two thirds of the requested width, then round up to a multiple of 8
        two_thirds = int(requested_hidden * 2 / 3)
        aligned_hidden = (two_thirds + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=aligned_hidden,
            out_features=out_width,
            bias=bias,
        )
|
core/encoders/dinov2/models/__init__.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from . import vision_transformer as vits
|
9 |
+
|
10 |
+
|
11 |
+
# Module logger, registered under the "dinov2" name.
logger = logging.getLogger("dinov2")
|
12 |
+
|
13 |
+
|
14 |
+
def build_model(args, only_teacher=False, img_size=224):
    """Instantiate the ViT(s) named by ``args.arch``.

    ``args.arch`` is rewritten in place with any ``"_memeff"`` suffix removed.
    When the resulting name does not contain ``"vit"`` nothing is built and
    ``None`` is returned.

    Returns:
        ``(teacher, embed_dim)`` when ``only_teacher`` is true, otherwise
        ``(student, teacher, embed_dim)``. Only the student receives the
        drop-path settings.
    """
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" not in args.arch:
        return None

    shared_kwargs = {
        "img_size": img_size,
        "patch_size": args.patch_size,
        "init_values": args.layerscale,
        "ffn_layer": args.ffn_layer,
        "block_chunks": args.block_chunks,
        "qkv_bias": args.qkv_bias,
        "proj_bias": args.proj_bias,
        "ffn_bias": args.ffn_bias,
        "num_register_tokens": args.num_register_tokens,
        "interpolate_offset": args.interpolate_offset,
        "interpolate_antialias": args.interpolate_antialias,
    }
    # The architecture name selects a factory from the vision_transformer module.
    make_vit = vits.__dict__[args.arch]
    teacher = make_vit(**shared_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim

    student = make_vit(
        **shared_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim
|
40 |
+
|
41 |
+
|
42 |
+
def build_model_from_cfg(cfg, only_teacher=False):
    """Call ``build_model`` with ``cfg.student`` and ``cfg.crops.global_crops_size`` as the image size."""
    student_args = cfg.student
    crop_size = cfg.crops.global_crops_size
    return build_model(student_args, only_teacher=only_teacher, img_size=crop_size)
|
core/encoders/dinov2/models/vision_transformer.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
# ******************************************************************************
|
11 |
+
# Code modified by Zexin He in 2023-2024.
|
12 |
+
# Modifications are marked with clearly visible comments
|
13 |
+
# licensed under the Apache License, Version 2.0.
|
14 |
+
# ******************************************************************************
|
15 |
+
|
16 |
+
from functools import partial
|
17 |
+
import math
|
18 |
+
import logging
|
19 |
+
from typing import Sequence, Tuple, Union, Callable
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.utils.checkpoint
|
24 |
+
from torch.nn.init import trunc_normal_
|
25 |
+
|
26 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
27 |
+
# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
|
28 |
+
from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, Block, BlockWithModulation
|
29 |
+
# ********************************************************
|
30 |
+
|
31 |
+
|
32 |
+
# Module logger, registered under the "dinov2" name.
logger = logging.getLogger("dinov2")
|
33 |
+
|
34 |
+
|
35 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` over a module tree.

    Args:
        fn: Callback invoked with keyword arguments ``module`` and ``name``.
        module: Root of the tree to walk.
        name: Dotted name of ``module``; children get ``name.child``.
        depth_first: If true a module is visited after its children,
            otherwise before them.
        include_root: Whether ``module`` itself is visited; descendants
            always are.

    Returns:
        ``module``, unchanged by this function itself.
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)

    for child_name, child_module in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child_module, name=qualified, depth_first=depth_first, include_root=True)

    if visit_after_children:
        fn(module=module, name=name)
    return module
|
44 |
+
|
45 |
+
|
46 |
+
class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` that can be called: its members are applied one after another."""

    def forward(self, x):
        """Feed ``x`` through every contained module in list order and return the result."""
        out = x
        for layer in self:
            out = layer(out)
        return out
|
51 |
+
|
52 |
+
|
53 |
+
class DinoVisionTransformer(nn.Module):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
img_size=224,
|
57 |
+
patch_size=16,
|
58 |
+
in_chans=3,
|
59 |
+
embed_dim=768,
|
60 |
+
depth=12,
|
61 |
+
num_heads=12,
|
62 |
+
mlp_ratio=4.0,
|
63 |
+
qkv_bias=True,
|
64 |
+
ffn_bias=True,
|
65 |
+
proj_bias=True,
|
66 |
+
drop_path_rate=0.0,
|
67 |
+
drop_path_uniform=False,
|
68 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
69 |
+
embed_layer=PatchEmbed,
|
70 |
+
act_layer=nn.GELU,
|
71 |
+
block_fn=Block,
|
72 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
73 |
+
modulation_dim: int = None,
|
74 |
+
# ********************************************************
|
75 |
+
ffn_layer="mlp",
|
76 |
+
block_chunks=1,
|
77 |
+
num_register_tokens=0,
|
78 |
+
interpolate_antialias=False,
|
79 |
+
interpolate_offset=0.1,
|
80 |
+
):
|
81 |
+
"""
|
82 |
+
Args:
|
83 |
+
img_size (int, tuple): input image size
|
84 |
+
patch_size (int, tuple): patch size
|
85 |
+
in_chans (int): number of input channels
|
86 |
+
embed_dim (int): embedding dimension
|
87 |
+
depth (int): depth of transformer
|
88 |
+
num_heads (int): number of attention heads
|
89 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
90 |
+
qkv_bias (bool): enable bias for qkv if True
|
91 |
+
proj_bias (bool): enable bias for proj in attn if True
|
92 |
+
ffn_bias (bool): enable bias for ffn if True
|
93 |
+
drop_path_rate (float): stochastic depth rate
|
94 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
95 |
+
weight_init (str): weight init scheme
|
96 |
+
init_values (float): layer-scale init values
|
97 |
+
embed_layer (nn.Module): patch embedding layer
|
98 |
+
act_layer (nn.Module): MLP activation layer
|
99 |
+
block_fn (nn.Module): transformer block class
|
100 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
101 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
102 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
103 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
104 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
105 |
+
"""
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
109 |
+
block_norm_layer = None
|
110 |
+
if modulation_dim is not None:
|
111 |
+
from ....modulate import ModLN
|
112 |
+
block_norm_layer = partial(ModLN, mod_dim=modulation_dim)
|
113 |
+
else:
|
114 |
+
block_norm_layer = nn.LayerNorm
|
115 |
+
block_norm_layer = partial(block_norm_layer, eps=1e-6)
|
116 |
+
# ********************************************************
|
117 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
118 |
+
|
119 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
120 |
+
self.num_tokens = 1
|
121 |
+
self.n_blocks = depth
|
122 |
+
self.num_heads = num_heads
|
123 |
+
self.patch_size = patch_size
|
124 |
+
self.num_register_tokens = num_register_tokens
|
125 |
+
self.interpolate_antialias = interpolate_antialias
|
126 |
+
self.interpolate_offset = interpolate_offset
|
127 |
+
|
128 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
129 |
+
num_patches = self.patch_embed.num_patches
|
130 |
+
|
131 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
132 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
133 |
+
assert num_register_tokens >= 0
|
134 |
+
self.register_tokens = (
|
135 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
136 |
+
)
|
137 |
+
|
138 |
+
if drop_path_uniform is True:
|
139 |
+
dpr = [drop_path_rate] * depth
|
140 |
+
else:
|
141 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
142 |
+
|
143 |
+
if ffn_layer == "mlp":
|
144 |
+
logger.info("using MLP layer as FFN")
|
145 |
+
ffn_layer = Mlp
|
146 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
147 |
+
logger.info("using SwiGLU layer as FFN")
|
148 |
+
ffn_layer = SwiGLUFFNFused
|
149 |
+
elif ffn_layer == "identity":
|
150 |
+
logger.info("using Identity layer as FFN")
|
151 |
+
|
152 |
+
def f(*args, **kwargs):
|
153 |
+
return nn.Identity()
|
154 |
+
|
155 |
+
ffn_layer = f
|
156 |
+
else:
|
157 |
+
raise NotImplementedError
|
158 |
+
|
159 |
+
blocks_list = [
|
160 |
+
block_fn(
|
161 |
+
dim=embed_dim,
|
162 |
+
num_heads=num_heads,
|
163 |
+
mlp_ratio=mlp_ratio,
|
164 |
+
qkv_bias=qkv_bias,
|
165 |
+
proj_bias=proj_bias,
|
166 |
+
ffn_bias=ffn_bias,
|
167 |
+
drop_path=dpr[i],
|
168 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
169 |
+
norm_layer=block_norm_layer,
|
170 |
+
# ********************************************************
|
171 |
+
act_layer=act_layer,
|
172 |
+
ffn_layer=ffn_layer,
|
173 |
+
init_values=init_values,
|
174 |
+
)
|
175 |
+
for i in range(depth)
|
176 |
+
]
|
177 |
+
if block_chunks > 0:
|
178 |
+
self.chunked_blocks = True
|
179 |
+
chunked_blocks = []
|
180 |
+
chunksize = depth // block_chunks
|
181 |
+
for i in range(0, depth, chunksize):
|
182 |
+
# this is to keep the block index consistent if we chunk the block list
|
183 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
184 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
185 |
+
else:
|
186 |
+
self.chunked_blocks = False
|
187 |
+
self.blocks = nn.ModuleList(blocks_list)
|
188 |
+
|
189 |
+
self.norm = norm_layer(embed_dim)
|
190 |
+
self.head = nn.Identity()
|
191 |
+
|
192 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
193 |
+
# hacking unused mask_token for better DDP
|
194 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
195 |
+
# ********************************************************
|
196 |
+
|
197 |
+
self.init_weights()
|
198 |
+
|
199 |
+
def init_weights(self):
|
200 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
201 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
202 |
+
if self.register_tokens is not None:
|
203 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
204 |
+
named_apply(init_weights_vit_timm, self)
|
205 |
+
|
206 |
+
def interpolate_pos_encoding(self, x, w, h):
|
207 |
+
previous_dtype = x.dtype
|
208 |
+
npatch = x.shape[1] - 1
|
209 |
+
N = self.pos_embed.shape[1] - 1
|
210 |
+
if npatch == N and w == h:
|
211 |
+
return self.pos_embed
|
212 |
+
pos_embed = self.pos_embed.float()
|
213 |
+
class_pos_embed = pos_embed[:, 0]
|
214 |
+
patch_pos_embed = pos_embed[:, 1:]
|
215 |
+
dim = x.shape[-1]
|
216 |
+
w0 = w // self.patch_size
|
217 |
+
h0 = h // self.patch_size
|
218 |
+
# we add a small number to avoid floating point error in the interpolation
|
219 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
220 |
+
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
221 |
+
|
222 |
+
sqrt_N = math.sqrt(N)
|
223 |
+
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
224 |
+
patch_pos_embed = nn.functional.interpolate(
|
225 |
+
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
226 |
+
scale_factor=(sx, sy),
|
227 |
+
mode="bicubic",
|
228 |
+
antialias=self.interpolate_antialias,
|
229 |
+
)
|
230 |
+
|
231 |
+
assert int(w0) == patch_pos_embed.shape[-2]
|
232 |
+
assert int(h0) == patch_pos_embed.shape[-1]
|
233 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
234 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
235 |
+
|
236 |
+
def prepare_tokens_with_masks(self, x, masks=None):
    """Turn an image batch into the token sequence fed to the transformer blocks.

    The image is patch-embedded, the class token is prepended, the positional
    encoding is added and, if the model has register tokens, they are inserted
    right after the class token.

    Raises:
        NotImplementedError: if ``masks`` is given; masking was removed from
            this modified DINOv2 (the mask token parameter no longer exists).
    """
    _, _, w, h = x.shape
    tokens = self.patch_embed(x)
    if masks is not None:
        # ********** Modified by Zexin He in 2023-2024 **********
        raise NotImplementedError("Masking is not supported in hacked DINOv2")
        # ********************************************************

    cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
    tokens = torch.cat((cls_tokens, tokens), dim=1)
    tokens = tokens + self.interpolate_pos_encoding(tokens, w, h)

    if self.register_tokens is None:
        return tokens

    registers = self.register_tokens.expand(tokens.shape[0], -1, -1)
    return torch.cat((tokens[:, :1], registers, tokens[:, 1:]), dim=1)
|
259 |
+
|
260 |
+
def forward_features_list(self, x_list, masks_list):
    """Run the backbone on a list of image batches and return one feature dict per entry.

    Every block receives the whole list of token tensors at once. Each output
    dict holds the normalised class / register / patch tokens, the
    pre-normalisation tokens and the mask that was passed in for that entry.
    """
    tokens_list = [
        self.prepare_tokens_with_masks(images, masks)
        for images, masks in zip(x_list, masks_list)
    ]
    for blk in self.blocks:
        tokens_list = blk(tokens_list)

    first_patch = self.num_register_tokens + 1
    results = []
    for tokens, masks in zip(tokens_list, masks_list):
        normed = self.norm(tokens)
        results.append(
            {
                "x_norm_clstoken": normed[:, 0],
                "x_norm_regtokens": normed[:, 1:first_patch],
                "x_norm_patchtokens": normed[:, first_patch:],
                "x_prenorm": tokens,
                "masks": masks,
            }
        )
    return results
|
279 |
+
|
280 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
281 |
+
def forward_features(self, x, masks=None, mod=None):
    """Run the backbone on one image batch and return its token features.

    Args:
        x: Image batch tensor. A list input is rejected (see Raises).
        masks: Forwarded to ``prepare_tokens_with_masks`` and echoed back in
            the result under ``"masks"``.
        mod: Optional modulation input. When given it is passed as the second
            argument to every block; when ``None`` blocks are called with the
            tokens only.

    Returns:
        dict with the normalised class token (``x_norm_clstoken``), register
        tokens (``x_norm_regtokens``), patch tokens (``x_norm_patchtokens``),
        the pre-normalisation tokens (``x_prenorm``) and ``masks``.

    Raises:
        DeprecationWarning: if ``x`` is a list. The list path is disabled in
            this modified model, so the exception is raised unconditionally.
    """
    if isinstance(x, list):
        # The original code had an unreachable
        # `return self.forward_features_list(x, masks)` after this raise.
        raise DeprecationWarning("forward_features_list is deprecated, use forward_features")

    x = self.prepare_tokens_with_masks(x, masks)

    # Only pass `mod` when supplied so that plain (unmodulated) blocks keep
    # their single-argument call signature.
    block_args = () if mod is None else (mod,)
    for blk in self.blocks:
        x = blk(x, *block_args)

    x_norm = self.norm(x)
    first_patch = self.num_register_tokens + 1
    return {
        "x_norm_clstoken": x_norm[:, 0],
        "x_norm_regtokens": x_norm[:, 1:first_patch],
        "x_norm_patchtokens": x_norm[:, first_patch:],
        "x_prenorm": x,
        "masks": masks,
    }
|
303 |
+
# ********************************************************
|
304 |
+
|
305 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
306 |
+
x = self.prepare_tokens_with_masks(x)
|
307 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
308 |
+
output, total_block_len = [], len(self.blocks)
|
309 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
310 |
+
for i, blk in enumerate(self.blocks):
|
311 |
+
x = blk(x)
|
312 |
+
if i in blocks_to_take:
|
313 |
+
output.append(x)
|
314 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
315 |
+
return output
|
316 |
+
|
317 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
318 |
+
x = self.prepare_tokens_with_masks(x)
|
319 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
320 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
321 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
322 |
+
for block_chunk in self.blocks:
|
323 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
324 |
+
x = blk(x)
|
325 |
+
if i in blocks_to_take:
|
326 |
+
output.append(x)
|
327 |
+
i += 1
|
328 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
329 |
+
return output
|
330 |
+
|
331 |
+
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return patch-token features from selected transformer blocks.

    Args:
        x: Image batch with four dimensions.
        n: Int (last ``n`` blocks) or explicit block indices.
        reshape: If true, patch tokens are reshaped to a channel-first grid
            using ``self.patch_size`` and the two trailing image dimensions.
        return_class_token: If true, each entry is a
            ``(patch_tokens, class_token)`` pair.
        norm: If true, ``self.norm`` is applied to every selected output.

    Returns:
        Tuple with one entry per selected block; class and register tokens are
        stripped from the patch tokens.
    """
    collect = (
        self._get_intermediate_layers_chunked
        if self.chunked_blocks
        else self._get_intermediate_layers_not_chunked
    )
    features = collect(x, n)
    if norm:
        features = [self.norm(feat) for feat in features]

    first_patch = 1 + self.num_register_tokens
    class_tokens = [feat[:, 0] for feat in features]
    patch_tokens = [feat[:, first_patch:] for feat in features]

    if reshape:
        B, _, w, h = x.shape
        grid_w, grid_h = w // self.patch_size, h // self.patch_size
        patch_tokens = [
            tokens.reshape(B, grid_w, grid_h, -1).permute(0, 3, 1, 2).contiguous()
            for tokens in patch_tokens
        ]

    if return_class_token:
        return tuple(zip(patch_tokens, class_tokens))
    return tuple(patch_tokens)
|
356 |
+
|
357 |
+
def forward(self, *args, is_training=False, **kwargs):
    """Compute features; return the full feature dict or the head output.

    All positional and keyword arguments are forwarded to
    ``forward_features``. With ``is_training=True`` the feature dict is
    returned as is, otherwise ``self.head`` is applied to the normalised
    class token.
    """
    features = self.forward_features(*args, **kwargs)
    if not is_training:
        return self.head(features["x_norm_clstoken"])
    return features
|
363 |
+
|
364 |
+
|
365 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)

    Only ``nn.Linear`` modules are touched: the weight gets a truncated normal
    with std 0.02 and the bias, if any, is zeroed. ``name`` is accepted but
    not used by this function.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
|
371 |
+
|
372 |
+
|
373 |
+
# ********** Modified by Zexin He in 2023-2024 **********
|
374 |
+
# block class selected from Block and BlockWithModulation
|
375 |
+
|
376 |
+
def _block_cls(**kwargs):
    """Select the transformer block class from the model keyword arguments.

    Returns ``BlockWithModulation`` when a non-``None`` ``modulation_dim`` is
    present in ``kwargs`` and plain ``Block`` otherwise.
    """
    if kwargs.get("modulation_dim", None) is None:
        return Block
    return BlockWithModulation
|
383 |
+
|
384 |
+
|
385 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ``DinoVisionTransformer`` with embed dim 384, depth 12 and 6 heads.

    Extra keyword arguments are forwarded to the model constructor and are
    also inspected by ``_block_cls`` to choose the block class.
    """
    block_fn = partial(_block_cls(**kwargs), attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=block_fn,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
397 |
+
|
398 |
+
|
399 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ``DinoVisionTransformer`` with embed dim 768, depth 12 and 12 heads.

    Extra keyword arguments are forwarded to the model constructor and are
    also inspected by ``_block_cls`` to choose the block class.
    """
    block_fn = partial(_block_cls(**kwargs), attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=block_fn,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
411 |
+
|
412 |
+
|
413 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ``DinoVisionTransformer`` with embed dim 1024, depth 24 and 16 heads.

    Extra keyword arguments are forwarded to the model constructor and are
    also inspected by ``_block_cls`` to choose the block class.
    """
    block_fn = partial(_block_cls(**kwargs), attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=block_fn,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
425 |
+
|
426 |
+
|
427 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64

    Depth is 40. Extra keyword arguments are forwarded to the model
    constructor and are also inspected by ``_block_cls`` to choose the block
    class.
    """
    block_fn = partial(_block_cls(**kwargs), attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=block_fn,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
442 |
+
|
443 |
+
# ********************************************************
|
core/encoders/dinov2_wrapper.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
# from accelerate.logging import get_logger
|
19 |
+
|
20 |
+
|
21 |
+
# logger = get_logger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class Dinov2Wrapper(nn.Module):
    """
    Dino v2 wrapper using original implementation, hacked with modulation.

    Args:
        model_name: Name of a factory function in ``.dinov2.hub.backbones``.
        modulation_dim: Passed to the backbone factory; ``None`` builds the
            plain (unmodulated) backbone.
        freeze: If true, put the backbone in eval mode and disable gradients
            on all of its parameters.

    Raises:
        ValueError: if ``freeze`` is requested together with a modulation
            dimension.
    """
    def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True):
        super().__init__()
        # Validate before building: constructing the backbone is the
        # expensive step, so an invalid configuration should fail first.
        if freeze and modulation_dim is not None:
            raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.")
        self.modulation_dim = modulation_dim
        self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
        if freeze:
            self._freeze()

    def _freeze(self):
        """Switch the backbone to eval mode and stop gradients on its parameters."""
        #logger.warning(f"======== Freezing Dinov2Wrapper ========")
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
        """Look up ``model_name`` in the bundled DINOv2 hub module and call it."""
        from importlib import import_module
        dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
        model_fn = getattr(dinov2_hub, model_name)
        #logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
        return model_fn(modulation_dim=modulation_dim, pretrained=pretrained)

    #@torch.compile
    def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
        """Encode ``image`` and return the class token followed by the patch tokens.

        The backbone is always called with ``is_training=True`` so that it
        returns its feature dict; register tokens are not part of the output.
        """
        # image: [N, C, H, W]
        # mod: [N, D] or None
        # RGB image with [0,1] scale and properly sized
        if self.modulation_dim is None:
            assert mod is None, "Unexpected modulation input in dinov2 forward."
            outs = self.model(image, is_training=True)
        else:
            assert mod is not None, "Modulation input is required in modulated dinov2 forward."
            outs = self.model(image, mod=mod, is_training=True)
        return torch.cat([
            outs["x_norm_clstoken"].unsqueeze(dim=1),
            outs["x_norm_patchtokens"],
        ], dim=1)
|
core/geometry/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
core/geometry/camera/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
|
13 |
+
class Camera(nn.Module):
    """Base class for camera models; it adds no state beyond ``nn.Module``."""

    def __init__(self):
        super().__init__()
|
core/geometry/camera/perspective_camera.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from . import Camera
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def projection(fovy, n=1.0, f=50.0, near_plane=None):
    """Build a 4x4 perspective projection matrix as a float32 NumPy array.

    Args:
        fovy: Field of view in degrees.
        n: Scales the two focal entries of the matrix.
        f: Far plane distance.
        near_plane: Near plane used in the depth row; falls back to ``n``
            when ``None``.
    """
    half_fov_tan = np.tan(fovy / 180.0 * np.pi * 0.5)
    near = n if near_plane is None else near_plane
    depth_range = f - near

    mtx = np.zeros((4, 4))
    mtx[0, 0] = n / half_fov_tan
    mtx[1, 1] = n / -half_fov_tan  # second axis is flipped
    mtx[2, 2] = -(f + near) / depth_range
    mtx[2, 3] = -(2 * f * near) / depth_range
    mtx[3, 2] = -1
    return mtx.astype(np.float32)
|
24 |
+
|
25 |
+
def projection_2(opt):
    """Build a 4x4 float32 projection tensor from ``opt.fovy``, ``opt.znear`` and ``opt.zfar``.

    ``opt.fovy`` is interpreted in degrees. Only five entries are non-zero:
    the two focal terms on the diagonal, the two depth terms and a constant 1
    at ``[2, 3]``.
    """
    zfar, znear = opt.zfar, opt.znear
    inv_tan_half_fov = 1 / np.tan(0.5 * np.deg2rad(opt.fovy))
    depth_range = zfar - znear

    proj = torch.zeros(4, 4, dtype=torch.float32)
    proj[0, 0] = inv_tan_half_fov
    proj[1, 1] = inv_tan_half_fov
    proj[2, 2] = (zfar + znear) / depth_range
    proj[3, 2] = - (zfar * znear) / depth_range
    proj[2, 3] = 1
    return proj
|
37 |
+
|
38 |
+
|
39 |
+
class PerspectiveCamera(Camera):
    """Camera holding a fixed perspective projection matrix built from ``opt.fovy``."""

    def __init__(self, opt, device='cuda'):
        super().__init__()
        self.device = device
        # Projection parameters are fixed here: far=1000, n=1, near plane=0.1.
        proj_np = projection(opt.fovy, f=1000.0, n=1.0, near_plane=0.1)
        self.proj_mtx = torch.from_numpy(proj_np).to(self.device).unsqueeze(dim=0)
        #self.proj_mtx= projection_2(opt).to(self.device).unsqueeze(dim=0)

    def project(self, points_bxnx4):
        """Multiply homogeneous points by the transposed projection matrix."""
        return torch.matmul(points_bxnx4, self.proj_mtx.transpose(1, 2))
|
core/geometry/render/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class Renderer:
    """Minimal renderer base class; both methods are no-ops here."""

    def __init__(self):
        """Nothing to set up in the base class."""

    def forward(self):
        """Placeholder that does nothing and returns ``None``."""
        return None
|
core/geometry/render/neural_render.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import nvdiffrast.torch as dr
|
12 |
+
from . import Renderer
|
13 |
+
|
14 |
+
_FG_LUT = None
|
15 |
+
|
16 |
+
|
17 |
+
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Thin wrapper around ``dr.interpolate``.

    ``attr`` is made contiguous first. Attribute derivatives are requested
    for all attributes (``diff_attrs='all'``) only when ``rast_db`` is given.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(
        attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)
|
21 |
+
|
22 |
+
|
23 |
+
def xfm_points(points, matrix, use_python=True):
    '''Transform points.
    Args:
        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
        use_python: Use PyTorch's torch.matmul (for validation); not read by this implementation
    Returns:
        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
    '''
    # Append w = 1 to every point, then apply the row-vector form p @ M^T.
    homogeneous = torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0)
    transformed = torch.matmul(homogeneous, matrix.transpose(1, 2))
    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(transformed)), "Output of xfm_points contains inf or NaN"
    return transformed
|
36 |
+
|
37 |
+
|
38 |
+
def dot(x, y):
    """Dot product over the last dimension, keeping that dimension with size 1."""
    return (x * y).sum(-1, keepdim=True)
|
40 |
+
|
41 |
+
|
42 |
+
def compute_vertex_normal(v_pos, t_pos_idx):
    """Compute unit vertex normals by accumulating face normals.

    Args:
        v_pos: Vertex positions, indexed as ``[num_vertices, 3]``.
        t_pos_idx: Integer triangle indices, indexed as ``[num_faces, 3]``.

    Returns:
        Tensor shaped like ``v_pos`` holding one unit normal per vertex.
        Vertices whose accumulated normal is (near) zero, including vertices
        not referenced by any face, get the default ``(0, 0, 1)``.

    Raises:
        AssertionError: if the result contains a non-finite value.
    """
    i0, i1, i2 = t_pos_idx[:, 0], t_pos_idx[:, 1], t_pos_idx[:, 2]
    v0, v1, v2 = v_pos[i0, :], v_pos[i1, :], v_pos[i2, :]

    # `dim` must be explicit: without it torch.cross picks the first dimension
    # of size 3, which is the face dimension when the mesh has exactly 3 faces.
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

    # Splat face normals to vertices (unnormalised cross products are summed).
    v_nrm = torch.zeros_like(v_pos)
    for corner in (i0, i1, i2):
        v_nrm.scatter_add_(0, corner[:, None].repeat(1, 3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value
    sq_len = torch.sum(v_nrm * v_nrm, -1, keepdim=True)
    v_nrm = torch.where(
        sq_len > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
    )
    v_nrm = F.normalize(v_nrm, dim=1)
    assert torch.all(torch.isfinite(v_nrm))

    return v_nrm
|
67 |
+
|
68 |
+
|
69 |
+
class NeuralRender(Renderer):
    """Rasterises a mesh with nvdiffrast and returns per-pixel feature buffers."""

    def __init__(self, device='cuda', camera_model=None):
        super().__init__()
        self.device = device
        self.ctx = dr.RasterizeCudaContext(device=device)
        self.projection_mtx = None
        self.camera = camera_model

    def render_mesh(
            self,
            mesh_v_pos_bxnx3,
            mesh_t_pos_idx_fx3,
            camera_mv_bx4x4,
            mesh_v_feat_bxnxd,
            resolution=256,
            spp=1,
            device='cuda',
            hierarchical_mask=False
    ):
        """Render per-vertex features, masks, depth and normals for a mesh.

        Returns, in order: interpolated vertex features, antialiased mask,
        hard mask, rasterizer output, clip-space vertices, ``None`` (mask
        pyramid placeholder), depth and a colour-encoded normal map.

        NOTE(review): the original indentation was not available; statements
        after the rasterisation are placed outside the ``DepthPeeler`` context
        -- confirm against the upstream file.
        """
        assert not hierarchical_mask

        if torch.is_tensor(camera_mv_bx4x4):
            mtx_in = camera_mv_bx4x4
        else:
            mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device)
        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera

        # vertex normals in world coordinates (first batch element only)
        v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long())

        # Render the image,
        # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
        mask_pyramid = None
        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes
        # Append the 4 homogeneous camera-space coordinates to the vertex features.
        feat_and_pos = torch.cat(
            [mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1)

        raster_size = [resolution * spp, resolution * spp]
        with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, raster_size) as peeler:
            # Single depth layer only.
            rast, db = peeler.rasterize_next_layer()
            gb_feat, _ = interpolate(feat_and_pos, rast, mesh_t_pos_idx_fx3)

        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
        antialias_mask = dr.antialias(
            hard_mask.clone().contiguous(), rast, v_pos_clip,
            mesh_t_pos_idx_fx3)

        # The last 4 channels are the appended position; channel -2 is its third component.
        depth = gb_feat[..., -2:-1]
        ori_mesh_feature = gb_feat[..., :-4]

        normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
        normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
        normal = F.normalize(normal, dim=-1)
        # Map normals from [-1, 1] to [0, 1]; black background
        normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float())

        return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
|
core/geometry/rep_3d/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
|
13 |
+
class Geometry:
    """Minimal geometry base class; both methods are no-ops here."""

    def __init__(self):
        """Nothing to set up in the base class."""

    def forward(self):
        """Placeholder that does nothing and returns ``None``."""
        return None
|
core/geometry/rep_3d/dmtet.py
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
from . import Geometry
|
13 |
+
from .dmtet_utils import get_center_boundary_index
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
|
17 |
+
###############################################################################
|
18 |
+
# DMTet utility functions
|
19 |
+
###############################################################################
|
20 |
+
def create_mt_variable(device):
    """Create the marching-tetrahedra lookup tensors on ``device``.

    Returns:
        ``(triangle_table, num_triangles_table, base_tet_edges, v_id)``, all
        ``torch.long``: a 16x6 table of edge slots per case (``-1`` = unused),
        the triangle count per case, the 6 tet edges as flattened vertex-slot
        pairs, and the powers of two ``[1, 2, 4, 8]``.
    """
    def as_long(values):
        return torch.tensor(values, dtype=torch.long, device=device)

    no_triangle = [-1] * 6
    triangle_table = as_long([
        no_triangle,
        [1, 0, 2, -1, -1, -1],
        [4, 0, 3, -1, -1, -1],
        [1, 4, 2, 1, 3, 4],
        [3, 1, 5, -1, -1, -1],
        [2, 3, 0, 2, 5, 3],
        [1, 4, 0, 1, 5, 4],
        [4, 2, 5, -1, -1, -1],
        [4, 5, 2, -1, -1, -1],
        [4, 1, 0, 4, 5, 1],
        [3, 2, 0, 3, 5, 2],
        [1, 3, 5, -1, -1, -1],
        [4, 1, 2, 4, 3, 1],
        [3, 0, 4, -1, -1, -1],
        [2, 0, 1, -1, -1, -1],
        no_triangle,
    ])
    num_triangles_table = as_long([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0])
    base_tet_edges = as_long([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3])
    v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
    return triangle_table, num_triangles_table, base_tet_edges, v_id
|
45 |
+
|
46 |
+
|
47 |
+
def sort_edges(edges_ex2):
    """Order the two endpoints of every edge so the smaller index comes first.

    Returns a tensor of shape ``[E, 1, 2]`` (each gathered column keeps a
    size-1 dimension before stacking). Callers in this file flatten it again
    before indexing.
    """
    with torch.no_grad():
        # 1 where the first endpoint is larger, i.e. where the pair must be swapped.
        swap = (edges_ex2[:, 0] > edges_ex2[:, 1]).long().unsqueeze(dim=1)
        low = edges_ex2.gather(1, swap)
        high = edges_ex2.gather(1, 1 - swap)
    return torch.stack((low, high), dim=-1)
|
54 |
+
|
55 |
+
|
56 |
+
###############################################################################
|
57 |
+
# marching tetrahedrons (differentiable)
|
58 |
+
###############################################################################
|
59 |
+
|
60 |
+
def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id):
    """Extract a triangle mesh from a tetrahedral grid with per-vertex SDF values.

    Vertices with ``sdf > 0`` count as occupied. A mesh vertex is placed on
    every unique tet edge with exactly one occupied endpoint, at the linear
    zero crossing of the SDF; triangles come from ``triangle_table``.

    Returns:
        ``(verts, faces)``: interpolated vertex positions and triangle indices
        into ``verts``.

    NOTE(review): the original indentation was not available; the ``no_grad``
    scope below covers only the index bookkeeping so the interpolation stays
    differentiable (per the section header) -- confirm against upstream.
    """
    device = sdf_n.device
    with torch.no_grad():
        occupied = sdf_n > 0
        occ_fx4 = occupied[tet_fx4.reshape(-1)].reshape(-1, 4)
        occupied_per_tet = torch.sum(occ_fx4, -1)
        # Surface tets have a mix of occupied and empty corners.
        valid_tets = (occupied_per_tet > 0) & (occupied_per_tet < 4)

        # find all vertices
        tet_edges = sort_edges(tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2))
        unique_edges, edge_to_unique = torch.unique(tet_edges, dim=0, return_inverse=True)
        unique_edges = unique_edges.long()

        crossing = occupied[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
        unique_to_vert = torch.full(
            (unique_edges.shape[0],), -1, dtype=torch.long, device=device)
        unique_to_vert[crossing] = torch.arange(crossing.sum(), dtype=torch.long, device=device)
        edge_to_vert = unique_to_vert[edge_to_unique]  # map edges to verts

        endpoints = unique_edges[crossing].reshape(-1)

    end_pos = pos_nx3[endpoints].reshape(-1, 2, 3)
    end_sdf = sdf_n[endpoints].reshape(-1, 2, 1)
    # Negate the second endpoint so both entries share a sign and sum to the SDF gap.
    end_sdf[:, -1] *= -1
    weights = torch.flip(end_sdf, [1]) / end_sdf.sum(1, keepdim=True)
    verts = (end_pos * weights).sum(1)

    edge_to_vert = edge_to_vert.reshape(-1, 6)

    # 4-bit case index built from the occupancy of the four corners.
    tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
    num_triangles = num_triangles_table[tetindex]
    one_tri = num_triangles == 1
    two_tri = num_triangles == 2

    # Generate triangle indices
    faces = torch.cat(
        (
            torch.gather(
                edge_to_vert[one_tri], 1,
                triangle_table[tetindex[one_tri]][:, :3]).reshape(-1, 3),
            torch.gather(
                edge_to_vert[two_tri], 1,
                triangle_table[tetindex[two_tri]][:, :6]).reshape(-1, 3),
        ), dim=0)
    return verts, faces
|
105 |
+
|
106 |
+
|
107 |
+
def create_tetmesh_variables(device='cuda'):
    """Create the lookup tensors used to build a tetrahedral mesh per marching-tets case.

    Returns:
        ``(tet_table, num_tets_table)``, both ``torch.long``: a 16x12 table
        holding up to three tets (4 slots each, ``-1`` = unused) per case and
        the number of tets per case.
    """
    def as_long(values):
        return torch.tensor(values, dtype=torch.long, device=device)

    pad = [-1] * 8
    empty = [-1] * 12
    tet_table = as_long([
        empty,
        [0, 4, 5, 6] + pad,
        [1, 4, 7, 8] + pad,
        [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8],
        [2, 5, 7, 9] + pad,
        [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9],
        [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9],
        [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9],
        [3, 6, 8, 9] + pad,
        [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9],
        [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9],
        [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9],
        [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8],
        [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8],
        [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6],
        empty,
    ])
    num_tets_table = as_long([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0])
    return tet_table, num_tets_table
|
127 |
+
|
128 |
+
|
129 |
+
def marching_tets_tetmesh(
        pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
        return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
    """Marching tetrahedra that can additionally return the occupied tetrahedral mesh.

    Vertices with ``sdf > 0`` count as occupied. Surface vertices are placed
    on every unique tet edge with exactly one occupied endpoint, at the linear
    zero crossing of the SDF.

    Args:
        pos_nx3, sdf_n, tet_fx4: Grid vertex positions, SDF values and tets.
        triangle_table, num_triangles_table, base_tet_edges, v_id: Lookup
            tensors as produced by ``create_mt_variable``.
        return_tet_mesh: If true, also build the tet mesh.
        ori_v: Per-grid-vertex positions used for the occupied tet-mesh
            vertices; required when ``return_tet_mesh`` is true.
        num_tets_table, tet_table: Lookup tensors as produced by
            ``create_tetmesh_variables``; required when ``return_tet_mesh``
            is true.

    Returns:
        ``(verts, faces)``, or ``(verts, faces, tet_verts, tets)`` when
        ``return_tet_mesh`` is true. ``tet_verts`` is the surface vertices
        followed by the occupied grid vertices.

    NOTE(review): the original indentation was not available; the ``no_grad``
    scope below covers only the index bookkeeping, mirroring ``marching_tets``
    -- confirm against upstream.
    """
    # All index tensors live on the SDF's device. The tet-mesh branch used a
    # hard-coded "cuda", which broke CPU and non-default-GPU inputs.
    device = sdf_n.device
    with torch.no_grad():
        occ_n = sdf_n > 0
        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
        occ_sum = torch.sum(occ_fx4, -1)
        # Surface tets have a mix of occupied and empty corners.
        valid_tets = (occ_sum > 0) & (occ_sum < 4)

        # find all vertices
        all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
        all_edges = sort_edges(all_edges)
        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

        unique_edges = unique_edges.long()
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
        mapping = torch.full((unique_edges.shape[0],), -1, dtype=torch.long, device=device)
        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=device)
        idx_map = mapping[idx_map]  # map edges to verts

        interp_v = unique_edges[mask_edges]

    edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
    edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
    # Negate the second endpoint so both entries share a sign and sum to the SDF gap.
    edges_to_interp_sdf[:, -1] *= -1

    denominator = edges_to_interp_sdf.sum(1, keepdim=True)
    edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
    verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

    idx_map = idx_map.reshape(-1, 6)

    # 4-bit case index built from the occupancy of the four corners.
    tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
    num_triangles = num_triangles_table[tetindex]

    # Generate triangle indices
    faces = torch.cat(
        (
            torch.gather(
                input=idx_map[num_triangles == 1], dim=1,
                index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
            torch.gather(
                input=idx_map[num_triangles == 2], dim=1,
                index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
        ), dim=0)
    if not return_tet_mesh:
        return verts, faces

    # Re-index occupied grid vertices compactly; they follow the surface
    # vertices in `tet_verts`, hence the `+ verts.shape[0]` offsets below.
    occupied_verts = ori_v[occ_n]
    mapping = torch.full((pos_nx3.shape[0],), -1, dtype=torch.long, device=device)
    mapping[occ_n] = torch.arange(occupied_verts.shape[0], device=device)
    tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4))

    idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1)  # t x 10
    tet_verts = torch.cat([verts, occupied_verts], 0)
    num_tets = num_tets_table[tetindex]

    tets = torch.cat(
        (
            torch.gather(
                input=idx_map[num_tets == 1], dim=1,
                index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape(-1, 4),
            torch.gather(
                input=idx_map[num_tets == 3], dim=1,
                index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape(-1, 4),
        ), dim=0)
    # add fully occupied tets
    fully_occupied = occ_fx4.sum(-1) == 4
    tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0]
    tets = torch.cat([tets, tet_fully_occupied])

    return verts, faces, tet_verts, tets
|
201 |
+
|
202 |
+
|
203 |
+
###############################################################################
|
204 |
+
# Compact tet grid
|
205 |
+
###############################################################################
|
206 |
+
|
207 |
+
def compact_tets(pos_nx3, sdf_n, tet_fx4):
    """Keep only the tetrahedra that straddle the SDF zero level set.

    A tetrahedron is kept when at least one, but not all four, of its
    vertices has a strictly positive SDF value. The vertices referenced by
    the kept tetrahedra are re-indexed densely.

    Args:
        pos_nx3: per-vertex positions, indexed by the entries of ``tet_fx4``.
        sdf_n: per-vertex SDF values.
        tet_fx4: vertex indices of each tetrahedron (four per row).

    Returns:
        Tuple ``(new_pos, new_sdf, new_tets)``: positions and SDF values of
        the surviving vertices, and the kept tetrahedra expressed in the new
        vertex numbering.
    """
    with torch.no_grad():
        # Per-tet count of vertices with positive SDF.
        inside = sdf_n > 0
        inside_per_tet = inside[tet_fx4.reshape(-1)].reshape(-1, 4)
        n_inside = torch.sum(inside_per_tet, -1)
        crosses_surface = (n_inside > 0) & (n_inside < 4)

        # Deduplicate the vertices of the surface tets; the inverse map is
        # exactly the tet connectivity in the compacted numbering.
        surface_vertex_ids = tet_fx4[crosses_surface].reshape(-1)
        kept_ids, remapped = torch.unique(surface_vertex_ids, dim=0, return_inverse=True)
        return pos_nx3[kept_ids], sdf_n[kept_ids], remapped.reshape(-1, 4)
|
221 |
+
|
222 |
+
|
223 |
+
###############################################################################
|
224 |
+
# Subdivide volume
|
225 |
+
###############################################################################
|
226 |
+
|
227 |
+
def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
    """Subdivide every tetrahedron into eight by inserting edge midpoints.

    The connectivity of the first batch element (``tet_bxfx4[0]``) is used
    for the whole batch. Positions and SDF values are concatenated on the
    last axis and averaged per unique edge, so a midpoint receives the mean
    of its two endpoints.

    Args:
        tet_pos_bxnx3: batched vertex positions.
        tet_bxfx4: batched tetrahedron vertex indices.
        grid_sdf: batched per-vertex SDF values; it is concatenated to the
            positions on the last axis, so it must carry a trailing axis
            (presumably of size 1 - confirm against the caller).

    Returns:
        Tuple ``(new_v, tet, new_sdf)``: the original vertices followed by
        the midpoints, the refined connectivity expanded over the batch, and
        channel 3 of the concatenated values.
    """
    device = tet_pos_bxnx3.device
    n_batch, n_verts = tet_pos_bxnx3.shape[0], tet_pos_bxnx3.shape[1]
    tet_fx4 = tet_bxfx4[0]

    # --- new vertices -------------------------------------------------------
    # Six edges per tet, listed as endpoint pairs (ab, ac, ad, bc, bd, cd).
    edge_corner_pairs = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]
    per_tet_edges = sort_edges(tet_fx4[:, edge_corner_pairs].reshape(-1, 2))
    unique_edges, midpoint_ids = torch.unique(per_tet_edges, dim=0, return_inverse=True)
    # Midpoints are appended after the existing vertices.
    midpoint_ids = midpoint_ids + n_verts

    pos_and_sdf = torch.cat([tet_pos_bxnx3, grid_sdf], -1)
    n_channels = pos_and_sdf.shape[-1]
    endpoints = pos_and_sdf[:, unique_edges.reshape(-1)].reshape(n_batch, -1, 2, n_channels)
    refined = torch.cat([pos_and_sdf, endpoints.mean(2)], 1)
    new_v, new_sdf = refined[..., :3], refined[..., 3]

    # --- new tets -----------------------------------------------------------
    a, b, c, d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3]
    ab, ac, ad, bc, bd, cd = (midpoint_ids[k::6] for k in range(6))

    child_corners = (
        (a, ab, ac, ad),
        (b, bc, ab, bd),
        (c, ac, bc, cd),
        (d, ad, cd, bd),
        (ab, ac, ad, bd),
        (ab, ac, bd, bc),
        (cd, ac, bd, ad),
        (cd, ac, bc, bd),
    )
    children = torch.cat([torch.stack(list(corners), dim=1) for corners in child_corners], dim=0)
    children = children.reshape(1, -1, 4).expand(n_batch, -1, -1)
    tet = children.long().to(device)

    return new_v, tet, new_sdf
|
267 |
+
|
268 |
+
|
269 |
+
###############################################################################
|
270 |
+
# Adjacency
|
271 |
+
###############################################################################
|
272 |
+
def tet_to_tet_adj_sparse(tet_tx4):
    """Build a row-normalized sparse tet-to-tet adjacency matrix.

    Two tetrahedra are adjacent when they share a triangular face (a face
    that occurs exactly twice among all tet faces). Every tetrahedron is
    also connected to itself. Each non-zero entry of row ``i`` equals
    ``1 / (number of entries in row i)``, so multiplying by the matrix
    averages a per-tet quantity over the tet and its face neighbours.

    Args:
        tet_tx4: integer tensor of tetrahedron vertex indices, four per row.

    Returns:
        A sparse COO float tensor of shape ``(t, t)`` on the device of
        ``tet_tx4``, where ``t`` is the number of tetrahedra.
    """
    with torch.no_grad():
        t = tet_tx4.shape[0]
        device = tet_tx4.device
        # Local corner triples of the four faces of a tetrahedron.
        idx_array = torch.tensor(
            [0, 1, 2,
             1, 0, 3,
             2, 3, 0,
             3, 2, 1], dtype=torch.long, device=device,
        ).reshape(4, 3).unsqueeze(0).expand(t, -1, -1)  # (t, 4, 3)

        # All faces of all tets, and the tet each face came from.
        all_faces = torch.gather(
            input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1,
        ).reshape(-1, 3)  # (t * 4, 3)
        all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1)

        # Sorting the vertex ids makes a face independent of its orientation,
        # so the two copies of a shared face become identical rows.
        all_faces_sorted, _ = torch.sort(all_faces, dim=1)
        _, inverse_indices, counts = torch.unique(
            all_faces_sorted, dim=0, return_counts=True, return_inverse=True)
        valid = counts[inverse_indices] == 2  # per face slot: is it shared?

        # Sorting by face group puts the two owners of each shared face next
        # to each other; consecutive pairs are then the adjacent tets.
        _, order = torch.sort(inverse_indices[valid])
        owners = all_faces_tet_idx[valid][order]
        pairs = torch.stack([owners[::2], owners[1::2]], dim=-1)

        # Symmetrize and include the self connection.
        self_loops = torch.arange(t, device=device)
        self_loops = torch.stack([self_loops, self_loops], -1)
        tet_adj_idx = torch.unique(
            torch.cat([pairs, torch.flip(pairs, [1]), self_loops]), dim=0)

        # Normalization: each row sums to one. The self loop guarantees that
        # every row has at least one entry, so the division is safe.
        rows = tet_adj_idx[:, 0]
        neighbor_num = torch.bincount(rows, minlength=t).float()
        values = 1.0 / neighbor_num[rows]
        # torch.sparse_coo_tensor replaces the deprecated, device-specific
        # torch.sparse.FloatTensor legacy constructor.
        adj_sparse = torch.sparse_coo_tensor(tet_adj_idx.t(), values, (t, t))
    return adj_sparse
|
322 |
+
|
323 |
+
|
324 |
+
###############################################################################
|
325 |
+
# Compact grid
|
326 |
+
###############################################################################
|
327 |
+
|
328 |
+
def get_tet_bxfx4x3(bxnxz, bxfx4):
    """Gather per-vertex features for the four corners of every tetrahedron.

    Args:
        bxnxz: tensor of shape (batch, n_vertices, z) with per-vertex values.
        bxfx4: tensor of shape (batch, n_tets, 4) with vertex indices; it is
            cast to long before indexing.

    Returns:
        Tensor of shape (batch, n_tets, 4, z) where entry ``[b, f, k]`` is
        ``bxnxz[b, bxfx4[b, f, k]]``.
    """
    batch, n_verts, feat = bxnxz.shape[0], bxnxz.shape[1], bxnxz.shape[2]
    n_tets = bxfx4.shape[1]
    # Repeat the vertex table once per corner slot so that a single gather
    # along the vertex axis can serve all four corners.
    source = bxnxz.unsqueeze(2).expand(batch, n_verts, 4, feat)
    lookup = bxfx4.unsqueeze(-1).expand(batch, n_tets, 4, feat).long()
    return torch.gather(input=source, dim=1, index=lookup)
|
338 |
+
|
339 |
+
|
340 |
+
def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
    """Crop a tet grid to the surface tetrahedra plus their face neighbours.

    Only a batch size of one is supported (asserted). A tetrahedron is a
    surface tet when some but not all of its vertices have positive SDF. The
    selection is then dilated once through the normalized adjacency matrix,
    and the vertices used by the selected tets are re-indexed densely.

    Args:
        tet_pos_bxnx3: vertex positions with a leading batch axis of size 1.
        tet_bxfx4: tetrahedron vertex indices with a leading batch axis.
        grid_sdf: per-vertex SDF values with a leading batch axis.

    Returns:
        Tuple ``(new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf)`` restricted
        to the selected region, with connectivity in the new numbering.
    """
    with torch.no_grad():
        assert tet_pos_bxnx3.shape[0] == 1

        pos_device = tet_pos_bxnx3.device
        n_verts = tet_pos_bxnx3.shape[1]

        # Count positive-SDF corners per tet and mark the straddling ones.
        inside = grid_sdf[0] > 0
        corners_inside = get_tet_bxfx4x3(inside.unsqueeze(0).unsqueeze(-1), tet_bxfx4)
        n_inside = corners_inside.reshape(-1, 4).sum(-1)
        surface = (n_inside > 0) & (n_inside < 4)

        # One multiplication by the (self-inclusive) adjacency matrix spreads
        # the selection to a one ring of neighbours.
        adjacency = tet_to_tet_adj_sparse(tet_bxfx4[0])
        spread = torch.sparse.mm(adjacency, surface.float().unsqueeze(-1))
        keep = spread.squeeze(-1) > 0

        kept_tets = tet_bxfx4[:, keep].long()
        kept_vertex_ids = torch.unique(kept_tets)

        # Old vertex id -> new dense id (entries of unused vertices stay 0).
        old_to_new = torch.zeros((n_verts), device=pos_device, dtype=torch.long)
        old_to_new[kept_vertex_ids] = torch.arange(kept_vertex_ids.shape[0], device=pos_device)

        new_tet_pos_bxnx3 = tet_pos_bxnx3[:, kept_vertex_ids]
        new_tet_bxfx4 = old_to_new[kept_tets.reshape(-1)].reshape(kept_tets.shape)
        new_grid_sdf = grid_sdf[:, kept_vertex_ids]
        return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf
|
365 |
+
|
366 |
+
|
367 |
+
###############################################################################
|
368 |
+
# Regularizer
|
369 |
+
###############################################################################
|
370 |
+
|
371 |
+
def sdf_reg_loss(sdf, all_edges):
    """Cross-entropy regularizer over edges whose endpoints differ in SDF sign.

    For every such edge, each endpoint's SDF value is used as a logit and is
    scored against the binarized sign (``> 0``) of the other endpoint; the
    two mean losses are summed.

    Args:
        sdf: 1-D tensor of per-vertex SDF values.
        all_edges: integer tensor of vertex-index pairs.

    Returns:
        Scalar tensor with the summed binary-cross-entropy terms.
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    endpoint_sdf = sdf[all_edges.reshape(-1)].reshape(-1, 2)
    first, second = endpoint_sdf[..., 0], endpoint_sdf[..., 1]
    # Only edges crossed by the surface contribute.
    crossing = torch.sign(first) != torch.sign(second)
    first, second = first[crossing], second[crossing]
    return bce(first, (second > 0).float()) + bce(second, (first > 0).float())
|
382 |
+
|
383 |
+
|
384 |
+
def sdf_reg_loss_batch(sdf, all_edges):
    """Batched variant of the sign-change cross-entropy SDF regularizer.

    Edges are looked up independently in every batch row; all edges whose
    endpoints differ in sign are pooled across the batch before the two mean
    binary-cross-entropy terms are computed and summed.

    Args:
        sdf: tensor of shape (batch, n_vertices) with SDF values.
        all_edges: integer tensor of vertex-index pairs shared by the batch.

    Returns:
        Scalar tensor with the summed binary-cross-entropy terms.
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits
    endpoint_sdf = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
    first, second = endpoint_sdf[..., 0], endpoint_sdf[..., 1]
    # Boolean mask over (batch, edge); indexing with it flattens the selection.
    crossing = torch.sign(first) != torch.sign(second)
    first, second = first[crossing], second[crossing]
    return bce(first, (second > 0).float()) + bce(second, (first > 0).float())
|
391 |
+
|
392 |
+
|
393 |
+
###############################################################################
|
394 |
+
# Geometry interface
|
395 |
+
###############################################################################
|
396 |
+
class DMTetGeometry(Geometry):
    """Tetrahedral-grid geometry that extracts meshes with marching tetrahedra.

    The tet grid is loaded from ``data/tets/<grid_res>_compress.npz``,
    centered at the origin, normalized so that its longest bounding-box side
    is one, and then scaled. The instance also caches the marching-tets
    lookup tables, the unique grid edges (for SDF regularization) and the
    center/boundary vertex indices.
    """

    def __init__(
            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
            render_type='neural_render', args=None):
        """Load the tet grid and precompute the lookup data.

        Args:
            grid_res: grid resolution; selects the ``.npz`` file to load.
            scale: a single factor applied to all axes, or a list/tuple of
                per-axis factors. With three entries they are (x, y, z);
                with only two entries the second one is reused for z.
            device: device on which all tensors are created.
            renderer: object whose ``render_mesh`` is called by
                :meth:`render_mesh`.
            render_type: rendering backend name; only ``'neural_render'``
                is implemented.
            args: stored unchanged on the instance.
        """
        super(DMTetGeometry, self).__init__()
        self.grid_res = grid_res
        self.device = device
        self.args = args
        # np.load returns an NpzFile that holds the archive open; read both
        # arrays inside a context manager so the file handle is released.
        with np.load('data/tets/%d_compress.npz' % (grid_res)) as tets:
            vertices = tets['vertices']
            tet_indices = tets['tets']
        self.verts = torch.from_numpy(vertices).float().to(self.device)
        # Make sure the tet is zero-centered and length is equal to 1
        bbox_min = self.verts.min(dim=0)[0]
        bbox_max = self.verts.max(dim=0)[0]
        length = (bbox_max - bbox_min).max()
        mid = (bbox_max + bbox_min) / 2.0
        self.verts = (self.verts - mid.unsqueeze(dim=0)) / length
        if isinstance(scale, (list, tuple)):
            # The z axis previously always reused scale[1]; honour a third
            # entry when present, and keep the old behaviour for two entries.
            scale_z = scale[2] if len(scale) > 2 else scale[1]
            self.verts[:, 0] = self.verts[:, 0] * scale[0]
            self.verts[:, 1] = self.verts[:, 1] * scale[1]
            self.verts[:, 2] = self.verts[:, 2] * scale_z
        else:
            self.verts = self.verts * scale
        self.indices = torch.from_numpy(tet_indices).long().to(self.device)
        self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device)
        self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device)
        # Parameters for regularization computation
        edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device)
        all_edges = self.indices[:, edges].reshape(-1, 2)
        all_edges_sorted = torch.sort(all_edges, dim=1)[0]
        self.all_edges = torch.unique(all_edges_sorted, dim=0)

        # Parameters used for fix boundary sdf
        self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts)
        self.renderer = renderer
        self.render_type = render_type

    @staticmethod
    def _flip_winding(faces):
        """Swap the last two vertex columns of each triangle."""
        return torch.cat(
            [faces[:, 0:1],
             faces[:, 2:3],
             faces[:, 1:2], ], dim=-1)

    def getAABB(self):
        """Return the per-axis minimum and maximum of the grid vertices."""
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
        """Extract a triangle mesh with marching tetrahedra.

        Args:
            v_deformed_nx3: (possibly deformed) grid vertex positions.
            sdf_n: per-vertex SDF values.
            with_uv: accepted for interface compatibility; unused here.
            indices: tet connectivity; defaults to the grid loaded at
                construction.

        Returns:
            Tuple ``(verts, faces)``; the faces returned by
            ``marching_tets`` have their last two columns swapped.
        """
        if indices is None:
            indices = self.indices
        verts, faces = marching_tets(
            v_deformed_nx3, sdf_n, indices, self.triangle_table,
            self.num_triangles_table, self.base_tet_edges, self.v_id)
        return verts, self._flip_winding(faces)

    def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
        """Extract both the surface triangle mesh and a tetrahedral mesh.

        Args:
            v_deformed_nx3: (possibly deformed) grid vertex positions.
            sdf_n: per-vertex SDF values.
            with_uv: accepted for interface compatibility; unused here.
            indices: tet connectivity; defaults to the grid loaded at
                construction.

        Returns:
            Tuple ``(verts, faces, tet_verts, tets)`` from
            ``marching_tets_tetmesh``, with the last two columns of
            ``faces`` swapped.
        """
        if indices is None:
            indices = self.indices
        verts, faces, tet_verts, tets = marching_tets_tetmesh(
            v_deformed_nx3, sdf_n, indices, self.triangle_table,
            self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True,
            num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3)
        return verts, self._flip_winding(faces), tet_verts, tets

    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
        """Render one mesh through ``self.renderer``.

        The vertex positions are passed both as geometry and as the
        per-vertex attribute to interpolate.

        Returns:
            Dict with keys ``tex_pos``, ``mask``, ``hard_mask``, ``rast``,
            ``v_pos_clip``, ``mask_pyramid`` and ``depth``.

        Raises:
            NotImplementedError: if ``render_type`` is not
                ``'neural_render'``.
        """
        if self.render_type != 'neural_render':
            raise NotImplementedError

        tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
            mesh_v_nx3.unsqueeze(dim=0),
            mesh_f_fx3.int(),
            camera_mv_bx4x4,
            mesh_v_nx3.unsqueeze(dim=0),
            resolution=resolution,
            device=self.device,
            hierarchical_mask=hierarchical_mask
        )
        return {
            'tex_pos': tex_pos,
            'mask': mask,
            'hard_mask': hard_mask,
            'rast': rast,
            'v_pos_clip': v_pos_clip,
            'mask_pyramid': mask_pyramid,
            'depth': depth,
        }

    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
        """Extract and render one mesh per batch element.

        Each batch element may yield a different mesh, so results are not
        concatenated: every key of the returned dict maps to a list with one
        entry per batch element. Concatenation is left to the caller.
        """
        all_render_output = []
        for i_batch in range(v_deformed_bxnx3.shape[0]):
            verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
            all_render_output.append(
                self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution))

        # Regroup the list of per-mesh dicts into a dict of per-key lists.
        return {k: [out[k] for out in all_render_output] for k in all_render_output[0].keys()}
|
core/geometry/rep_3d/dmtet_utils.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def get_center_boundary_index(verts):
    """Locate the vertex nearest the origin and the vertices on the extremes.

    A vertex counts as a boundary vertex when at least one of its
    coordinates equals the global maximum or the global minimum taken over
    all coordinates of ``verts``.

    Args:
        verts: tensor of vertex positions, one vertex per row.

    Returns:
        Tuple ``(center_idx, boundary_idx)``: the index of the vertex with
        the smallest squared norm, and a 1-D tensor of boundary vertex
        indices.
    """
    center_idx = torch.argmin(torch.sum(verts ** 2, dim=-1))

    touches_extreme = torch.bitwise_or(verts == verts.min(), verts == verts.max())
    # Number of extreme coordinates per vertex; non-zero means "on boundary".
    hits_per_vertex = torch.sum(touches_extreme.float(), dim=-1)
    boundary_idx = torch.nonzero(hits_per_vertex).squeeze(dim=-1)
    return center_idx, boundary_idx
|
core/geometry/rep_3d/extract_texture_map.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import xatlas
|
11 |
+
import numpy as np
|
12 |
+
import nvdiffrast.torch as dr
|
13 |
+
|
14 |
+
|
15 |
+
# ==============================================================================================
|
16 |
+
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Thin wrapper around ``dr.interpolate``.

    The attribute tensor is made contiguous before the call. Attribute
    derivatives are requested for all attributes only when ``rast_db`` is
    supplied.
    """
    if rast_db is None:
        diff_attrs = None
    else:
        diff_attrs = 'all'
    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)
|
18 |
+
|
19 |
+
|
20 |
+
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
    """Unwrap a mesh with xatlas and rasterize its positions into UV space.

    Args:
        ctx: nvdiffrast rasterization context passed to ``dr.rasterize``.
        mesh_v: vertex position tensor.
        mesh_pos_idx: triangle vertex-index tensor.
        resolution: side length of the square texture to rasterize.

    Returns:
        Tuple ``(uvs, mesh_tex_idx, gb_pos, mask)``: the UV coordinates and
        UV-space triangle indices from xatlas, the mesh positions
        interpolated per texel, and a boolean mask of covered texels.
    """
    target_device = mesh_v.device
    verts_np = mesh_v.detach().cpu().numpy()
    faces_np = mesh_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(verts_np, faces_np)

    # Convert to tensors (indices are reinterpreted as int64 via uint64).
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device=target_device)
    mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=target_device)

    # Texture-space vertices: map UVs from [0, 1] to [-1, 1] and pad them to
    # four-component coordinates (z = 0, w = 1).
    uv_clip = uvs[None, ...] * 2.0 - 1.0
    zeros = torch.zeros_like(uv_clip[..., 0:1])
    ones = torch.ones_like(uv_clip[..., 0:1])
    uv_clip4 = torch.cat((uv_clip, zeros, ones), dim=-1)

    # Rasterize the UV layout.
    rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))

    # Interpolate world space position for every texel.
    gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
    mask = rast[..., 3:4] > 0
    return uvs, mesh_tex_idx, gb_pos, mask
|
core/geometry/rep_3d/flexicubes.py
ADDED
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
import torch
|
9 |
+
from .tables import *
|
10 |
+
|
11 |
+
__all__ = [
|
12 |
+
'FlexiCubes'
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
class FlexiCubes:
|
17 |
+
"""
|
18 |
+
This class implements the FlexiCubes method for extracting meshes from scalar fields.
|
19 |
+
It maintains a series of lookup tables and indices to support the mesh extraction process.
|
20 |
+
FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
|
21 |
+
the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
|
22 |
+
the surface representation through gradient-based optimization.
|
23 |
+
|
24 |
+
During instantiation, the class loads DMC tables from a file and transforms them into
|
25 |
+
PyTorch tensors on the specified device.
|
26 |
+
|
27 |
+
Attributes:
|
28 |
+
device (str): Specifies the computational device (default is "cuda").
|
29 |
+
dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
|
30 |
+
associated with each dual vertex in 256 Marching Cubes (MC) configurations.
|
31 |
+
num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
|
32 |
+
the 256 MC configurations.
|
33 |
+
check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
|
34 |
+
of the DMC configurations.
|
35 |
+
tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
|
36 |
+
quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
|
37 |
+
along one diagonal.
|
38 |
+
quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
|
39 |
+
two triangles along the other diagonal.
|
40 |
+
quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
|
41 |
+
during training by connecting all edges to their midpoints.
|
42 |
+
cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
|
43 |
+
eight corners in 3D space, ordered starting from the origin (0,0,0),
|
44 |
+
moving along the x-axis, then y-axis, and finally z-axis.
|
45 |
+
Used as a blueprint for generating a voxel grid.
|
46 |
+
cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
|
47 |
+
to retrieve the case id.
|
48 |
+
cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
|
49 |
+
Used to retrieve edge vertices in DMC.
|
50 |
+
edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
|
51 |
+
their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
|
52 |
+
first edge is oriented along the x-axis.
|
53 |
+
dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
|
54 |
+
across four adjacent cubes to the shared faces of these cubes. For instance,
|
55 |
+
dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
|
56 |
+
the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
|
57 |
+
This tensor is only utilized during isosurface tetrahedralization.
|
58 |
+
adj_pairs (torch.Tensor):
|
59 |
+
A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
|
60 |
+
qef_reg_scale (float):
|
61 |
+
The scaling factor applied to the regularization loss to prevent issues with singularity
|
62 |
+
when solving the QEF. This parameter is only used when a 'grad_func' is specified.
|
63 |
+
weight_scale (float):
|
64 |
+
The scale of weights in FlexiCubes. Should be between 0 and 1.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
    """Create the lookup tables used by FlexiCubes on ``device``.

    Args:
        device: device on which the lookup tensors are created.
        qef_reg_scale: stored as ``self.qef_reg_scale``.
        weight_scale: stored as ``self.weight_scale``.
    """

    def index_table(data):
        # Integer lookup table on the target device, excluded from autograd.
        return torch.tensor(data, dtype=torch.long, device=device, requires_grad=False)

    self.device = device

    # Tables imported from the accompanying tables module.
    self.dmc_table = index_table(dmc_table)
    self.num_vd_table = index_table(num_vd_table)
    self.check_table = index_table(check_table)
    self.tet_table = index_table(tet_table)

    # Quad splitting patterns.
    self.quad_split_1 = index_table([0, 1, 2, 0, 2, 3])
    self.quad_split_2 = index_table([0, 1, 3, 3, 1, 2])
    self.quad_split_train = index_table([0, 1, 1, 2, 2, 3, 3, 0])

    # Unit-cube layout.
    self.cube_corners = torch.tensor(
        [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
         [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]],
        dtype=torch.float, device=device)
    # NOTE: unlike the other tables this one is created without a device
    # argument.
    self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
    self.cube_edges = index_table(
        [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
         2, 0, 3, 1, 7, 5, 6, 4])

    self.edge_dir_table = torch.tensor(
        [0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], dtype=torch.long, device=device)
    self.dir_faces_table = torch.tensor(
        [
            [[5, 4], [3, 2], [4, 5], [2, 3]],
            [[5, 4], [1, 0], [4, 5], [0, 1]],
            [[3, 2], [1, 0], [2, 3], [0, 1]],
        ],
        dtype=torch.long, device=device)
    self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)

    self.qef_reg_scale = qef_reg_scale
    self.weight_scale = weight_scale
|
99 |
+
|
100 |
+
def construct_voxel_grid(self, res):
    """
    Build a regular voxel grid with shared (deduplicated) corner vertices.

    Args:
        res (int or list[int]): Grid resolution. An integer is used for all
            three dimensions; a list or tuple of 3 integers gives the
            resolution along x, y and z respectively.

    Returns:
        (torch.Tensor, torch.Tensor): The grid vertices and, for every cube,
            the indices of its eight corners into those vertices. The
            vertices are centered at the origin and every grid dimension
            has length one.
    """
    if isinstance(res, int):
        res = (res, res, res)
    corner_offsets = torch.arange(8).to(self.device)
    # Every cell of an all-ones grid is non-zero, so nonzero() enumerates
    # the integer coordinates of all cells.
    all_cells = torch.ones(res, device=self.device)

    res_row = torch.tensor([res], dtype=torch.float, device=self.device)
    cell_origins = torch.nonzero(all_cells).float() / res_row  # N, 3
    corner_pos = self.cube_corners.unsqueeze(0) / res_row + cell_origins.unsqueeze(1)
    corner_pos = corner_pos.reshape(-1, 3)

    # Before deduplication each cube owns eight private corner slots.
    cell_ids = torch.arange(cell_origins.shape[0], device=self.device)
    corner_slots = (corner_offsets.unsqueeze(0) + cell_ids.unsqueeze(1) * 8).reshape(-1)

    # Round to five decimals so that corners shared between neighbouring
    # cubes compare equal, then merge them.
    snapped = torch.round(corner_pos * 10**5) / (10**5)
    unique_pos, slot_to_vertex = torch.unique(snapped, dim=0, return_inverse=True)
    cubes = slot_to_vertex[corner_slots.reshape(-1)].reshape(-1, 8)

    return unique_pos - 0.5, cubes
|
132 |
+
|
133 |
+
def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
             gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
    r"""
    Extract a triangle or tetrahedral mesh from a scalar field with FlexiCubes.

    The scalar field is a discrete signed distance field stored on the
    vertices of a voxel grid, complemented by optional per-cube weights. The
    extraction is the differentiable operation described in
    `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_:
    the output surface is differentiable with respect to the vertex
    positions, the scalar values and the weight parameters.

    To extract a surface from a fixed signed distance field without
    optimizing any parameters, pass ``grad_func``, which should return the
    surface gradient at any given 3D position. Dual vertices are then placed
    by solving a Quadratic Error Function (QEF) as in the
    `Manifold Dual Contouring`_ paper, combined with a smart splitting
    strategy. Note that this mode is non-differentiable.

    See the `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_
    SIGGRAPH 2023 paper for details and example usage in optimization.

    Args:
        x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices; may be deformed.
        s_n (torch.Tensor): Scalar field value at each grid vertex. Negative values mark
            vertices inside the isosurface, which determines the orientation of the
            extracted triangles and the volume that gets tetrahedralized.
        cube_fx8 (torch.Tensor): Indices of the 8 vertices of each cube in the grid.
        res (int or list[int]): Grid resolution. An integer applies to all three
            dimensions; a list or tuple of 3 integers gives x, y and z separately.
        beta_fx12 (torch.Tensor, optional): Per-edge weights that adjust the dual vertex
            positions. Defaults to a uniform value for all edges.
        alpha_fx8 (torch.Tensor, optional): Per-corner weights that adjust the dual vertex
            positions. Defaults to a uniform value for all vertices.
        gamma_f (torch.Tensor, optional): Per-cube weights that control how quadrilaterals
            are split into triangles. Defaults to a uniform value for all cubes.
        training (bool, optional): If True, use the differentiable quad splitting meant
            for training. Defaults to False.
        output_tetmesh (bool, optional): If True, return a tetrahedral mesh instead of a
            triangular one. Defaults to False.
        grad_func (callable, optional): Maps Nx3 positions to Nx3 surface gradients. If
            None, the original FlexiCubes algorithm is used. Defaults to None.

    Returns:
        (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
            - Vertices of the extracted triangular/tetrahedral mesh.
            - Faces (or tetrahedra) of the extracted mesh.
            - Regularizer L_dev, computed per dual vertex.

    .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
        https://research.nvidia.com/labs/toronto-ai/flexicubes/
    .. _Manifold Dual Contouring:
        https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
    """

    surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
    if surf_cubes.sum() == 0:
        # No surface cube at all: return empty buffers whose element width
        # matches the requested mesh type (4 for tets, 3 for triangles).
        no_vertices = torch.zeros((0, 3), device=self.device)
        if output_tetmesh:
            no_elements = torch.zeros((0, 4), dtype=torch.long, device=self.device)
        else:
            no_elements = torch.zeros((0, 3), dtype=torch.long, device=self.device)
        no_l_dev = torch.zeros((0), device=self.device)
        return no_vertices, no_elements, no_l_dev

    beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)

    case_ids = self._get_case_id(occ_fx8, surf_cubes, res)

    surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)

    vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
        x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
    vertices, faces, s_edges, edge_indices = self._triangulate(
        s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)

    if output_tetmesh:
        vertices, tets = self._tetrahedralize(
            x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
            surf_cubes, training)
        return vertices, tets, L_dev
    return vertices, faces, L_dev
|
217 |
+
|
218 |
+
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
|
219 |
+
"""
|
220 |
+
Regularizer L_dev as in Equation 8
|
221 |
+
"""
|
222 |
+
dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
|
223 |
+
mean_l2 = torch.zeros_like(vd[:, 0])
|
224 |
+
mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
|
225 |
+
mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
|
226 |
+
return mad
|
227 |
+
|
228 |
+
def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
|
229 |
+
"""
|
230 |
+
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
|
231 |
+
"""
|
232 |
+
n_cubes = surf_cubes.shape[0]
|
233 |
+
|
234 |
+
if beta_fx12 is not None:
|
235 |
+
beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
|
236 |
+
else:
|
237 |
+
beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
|
238 |
+
|
239 |
+
if alpha_fx8 is not None:
|
240 |
+
alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
|
241 |
+
else:
|
242 |
+
alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
|
243 |
+
|
244 |
+
if gamma_f is not None:
|
245 |
+
gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
|
246 |
+
else:
|
247 |
+
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
|
248 |
+
|
249 |
+
return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
|
250 |
+
|
251 |
+
@torch.no_grad()
def _get_case_id(self, occ_fx8, surf_cubes, res):
    """
    Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
    ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
    supplementary material. It should be noted that this function assumes a regular grid.

    Args:
        occ_fx8: per-cube corner occupancy; only the rows selected by ``surf_cubes`` are used.
        surf_cubes: boolean mask over all cubes selecting the surface cubes.
        res: grid resolution, either a scalar (expanded to three equal entries) or a list/tuple of three.

    Returns:
        One case id per surface cube, with ambiguous cases replaced where a neighbour is ambiguous too.
    """
    # Weighted sum of the corner occupancies with `cube_corners_idx` yields the case id of each surface cube.
    case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)

    # Rows of `check_table`: column 0 flags a case that needs checking, columns 1:4 are added to the cube's
    # grid index to reach the neighbour to inspect, and the last column is the replacement case id.
    problem_config = self.check_table.to(self.device)[case_ids]
    to_check = problem_config[..., 0] == 1
    problem_config = problem_config[to_check]
    if not isinstance(res, (list, tuple)):
        res = [res, res, res]

    # 'problem_config' only contains configurations for surface cubes. Next, we construct a 3D array,
    # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
    # This allows efficient checking on adjacent cubes.
    problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
    # The tensor is all zeros here, so this enumerates the (i, j, k) index of every cell.
    # NOTE(review): assumes `surf_cubes` is ordered like this enumeration -- confirm against the grid builder.
    vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
    vol_idx_problem = vol_idx[surf_cubes][to_check]
    problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
    vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]

    # Discard neighbours that fall outside the grid.
    within_range = (
        vol_idx_problem_adj[..., 0] >= 0) & (
        vol_idx_problem_adj[..., 0] < res[0]) & (
        vol_idx_problem_adj[..., 1] >= 0) & (
        vol_idx_problem_adj[..., 1] < res[1]) & (
        vol_idx_problem_adj[..., 2] >= 0) & (
        vol_idx_problem_adj[..., 2] < res[2])

    vol_idx_problem = vol_idx_problem[within_range]
    vol_idx_problem_adj = vol_idx_problem_adj[within_range]
    problem_config = problem_config[within_range]
    problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
                                             vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
    # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
    to_invert = (problem_config_adj[..., 0] == 1)
    # Map the surviving rows back to positions in `case_ids` by applying the same chain of masks.
    idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
    case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
    return case_ids
|
293 |
+
|
294 |
+
@torch.no_grad()
|
295 |
+
def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
|
296 |
+
"""
|
297 |
+
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
|
298 |
+
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
|
299 |
+
and marks the cube edges with this index.
|
300 |
+
"""
|
301 |
+
occ_n = s_n < 0
|
302 |
+
all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
|
303 |
+
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
304 |
+
|
305 |
+
unique_edges = unique_edges.long()
|
306 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
307 |
+
|
308 |
+
surf_edges_mask = mask_edges[_idx_map]
|
309 |
+
counts = counts[_idx_map]
|
310 |
+
|
311 |
+
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
|
312 |
+
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
|
313 |
+
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
|
314 |
+
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
|
315 |
+
idx_map = mapping[_idx_map]
|
316 |
+
surf_edges = unique_edges[mask_edges]
|
317 |
+
return surf_edges, idx_map, counts, surf_edges_mask
|
318 |
+
|
319 |
+
@torch.no_grad()
|
320 |
+
def _identify_surf_cubes(self, s_n, cube_fx8):
|
321 |
+
"""
|
322 |
+
Identifies grid cubes that intersect with the underlying surface by checking if the signs at
|
323 |
+
all corners are not identical.
|
324 |
+
"""
|
325 |
+
occ_n = s_n < 0
|
326 |
+
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
|
327 |
+
_occ_sum = torch.sum(occ_fx8, -1)
|
328 |
+
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
|
329 |
+
return surf_cubes, occ_fx8
|
330 |
+
|
331 |
+
def _linear_interp(self, edges_weight, edges_x):
|
332 |
+
"""
|
333 |
+
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
|
334 |
+
"""
|
335 |
+
edge_dim = edges_weight.dim() - 2
|
336 |
+
assert edges_weight.shape[edge_dim] == 2
|
337 |
+
edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
|
338 |
+
torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
|
339 |
+
denominator = edges_weight.sum(edge_dim)
|
340 |
+
ue = (edges_x * edges_weight).sum(edge_dim) / denominator
|
341 |
+
return ue
|
342 |
+
|
343 |
+
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
|
344 |
+
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
|
345 |
+
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
|
346 |
+
c_bx3 = c_bx3.reshape(-1, 3)
|
347 |
+
A = norm_bxnx3
|
348 |
+
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
|
349 |
+
|
350 |
+
A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
|
351 |
+
B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
|
352 |
+
A = torch.cat([A, A_reg], 1)
|
353 |
+
B = torch.cat([B, B_reg], 1)
|
354 |
+
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
|
355 |
+
return dual_verts
|
356 |
+
|
357 |
+
def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
    """
    Computes the location of dual vertices as described in Section 4.2

    Two code paths exist:
      * ``grad_func is None``: each dual vertex is the beta-weighted average of the alpha-adjusted
        zero-crossings of its edge group, and ``L_dev`` comes from ``_compute_reg_loss``.
      * ``grad_func`` given: each dual vertex is obtained from ``_solve_vd_QEF`` using normalized
        ``grad_func`` outputs at the zero-crossings; ``L_dev`` is then a single zero.

    Returns:
        vd: dual vertex positions.
        L_dev: regularizer values (see above).
        vd_gamma: the owning cube's gamma repeated for each of its dual vertices.
        vd_idx_map: flattened (cube, edge) -> dual vertex index table; slots never written stay 0.
    """
    # Per-edge endpoint weights, positions and scalar values, plus the un-weighted zero-crossings.
    alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
    surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
    surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
    zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)

    idx_map = idx_map.reshape(-1, 12)
    # Number of dual vertices emitted by each cube, looked up from its case id.
    num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
    edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []

    total_num_vd = 0
    vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
    if grad_func is not None:
        normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
        vd = []
    for num in torch.unique(num_vd):
        cur_cubes = (num_vd == num)  # consider cubes with the same numbers of vd emitted (for batching)
        curr_num_vd = cur_cubes.sum() * num
        # Each dual vertex owns up to 7 cube-local edge ids; unused slots hold -1.
        curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
        # Global dual-vertex index for every slot, offset by the vertices emitted in earlier iterations.
        curr_edge_group_to_vd = torch.arange(
            curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
        total_num_vd += curr_num_vd
        curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
            cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)

        # Keep only the slots that hold a real edge.
        curr_mask = (curr_edge_group != -1)
        edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
        edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
        edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
        vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
        vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))

        if grad_func is not None:
            with torch.no_grad():
                cube_e_verts_idx = idx_map[cur_cubes]
                # Padding slots are pointed at edge 0 so that gather/index_select stay in range;
                # their contribution is removed again by multiplying with curr_mask below.
                curr_edge_group[~curr_mask] = 0

                verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
                verts_group_idx[verts_group_idx == -1] = 0
                verts_group_pos = torch.index_select(
                    input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
                # First corner of each cube, used as a local origin for the QEF solve.
                v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
                curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
                verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))

                normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
                    -1, num.item(), 7,
                    3)
                curr_mask = curr_mask.squeeze(2)
                vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
                                             verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
    edge_group = torch.cat(edge_group)
    edge_group_to_vd = torch.cat(edge_group_to_vd)
    edge_group_to_cube = torch.cat(edge_group_to_cube)
    vd_num_edges = torch.cat(vd_num_edges)
    vd_gamma = torch.cat(vd_gamma)

    if grad_func is not None:
        vd = torch.cat(vd)
        L_dev = torch.zeros([1], device=self.device)
    else:
        vd = torch.zeros((total_num_vd, 3), device=self.device)
        beta_sum = torch.zeros((total_num_vd, 1), device=self.device)

        # `edge_group_to_cube * 12 + edge_group` is the flat (cube, local edge) slot index.
        idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)

        x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
        s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)

        zero_crossing_group = torch.index_select(
            input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)

        alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
                                         index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
        # Zero-crossings recomputed with the endpoint values scaled by alpha.
        ue_group = self._linear_interp(s_group * alpha_group, x_group)

        beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
                                  index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
        # Beta-weighted average of the edge crossings belonging to each dual vertex.
        beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
        vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
        L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)

    v_idx = torch.arange(vd.shape[0], device=self.device)  # + total_num_vd

    # Record, for every (cube, edge) slot that belongs to an edge group, the dual vertex it maps to.
    vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
                                                  12 + edge_group, src=v_idx[edge_group_to_vd])

    return vd, L_dev, vd_gamma, vd_idx_map
|
448 |
+
|
449 |
+
def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
    """
    Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
    triangles based on the gamma parameter, as described in Section 4.3.

    When ``training`` is False every quad is split into two triangles along one diagonal; when it is
    True an extra center vertex is appended to ``vd`` for each quad and four triangles fan around it.

    Returns:
        vd: dual vertices (with the appended quad centers in training mode).
        faces: (num_triangles, 3) vertex indices.
        s_edges: scalar values at the two endpoints of each quad's grid edge.
        edge_indices: sorted surface-edge index of every (cube, edge) slot taking part in a quad.
    """
    with torch.no_grad():
        group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.
        group = idx_map.reshape(-1)[group_mask]
        vd_idx = vd_idx_map[group_mask]
        # Sorting by edge id places the four slots of one edge next to each other.
        edge_indices, indices = torch.sort(group, stable=True)
        quad_vd_idx = vd_idx[indices].reshape(-1, 4)

        # Ensure all face directions point towards the positive SDF to maintain consistent winding.
        s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
        flip_mask = s_edges[:, 0] > 0
        # NOTE(review): this concatenation regroups the quads (flip_mask rows first), so row i of
        # quad_vd_idx no longer lines up with row i of s_edges / edge_indices -- confirm that
        # downstream consumers do not rely on that alignment.
        quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
                                 quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
    if grad_func is not None:
        # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
        with torch.no_grad():
            vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
            quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
            gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
            gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
    else:
        # Otherwise the diagonal scores are the products of the gamma weights at opposite corners.
        quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
        gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
            0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
        gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
            1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
    if not training:
        # Hard choice between the two split patterns depending on which diagonal scores higher.
        mask = (gamma_02 > gamma_13).squeeze(1)
        faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
        faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
        faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
        faces = faces.reshape(-1, 3)
    else:
        # Soft split: the center vertex is the gamma-weighted blend of the two diagonal midpoints.
        vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
        vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
                 torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
        vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
                 torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
        weight_sum = (gamma_02 + gamma_13) + 1e-8  # epsilon added to the denominator
        vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
                     weight_sum.unsqueeze(-1)).squeeze(1)
        # Center vertices are appended after the existing dual vertices.
        vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
        vd = torch.cat([vd, vd_center])
        faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
        faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
    return vd, faces, s_edges, edge_indices
|
499 |
+
|
500 |
+
def _tetrahedralize(
        self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
        surf_cubes, training):
    """
    Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.

    The output vertex list is ``vertices`` followed by the grid vertices with ``s_n < 0`` and then the
    centers of the fully-inside cubes; all tetrahedron indices refer to that concatenated list.

    Returns:
        vertices: the concatenated vertex tensor described above.
        tets: (num_tets, 4) vertex indices (surface tetrahedra first, interior ones after).
    """
    occ_n = s_n < 0
    occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
    occ_sum = torch.sum(occ_fx8, -1)

    inside_verts = x_nx3[occ_n]
    # Grid vertex id -> index in the output vertex list (-1 for vertices that are not inside).
    mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
    mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
    """
    For each grid edge connecting two grid vertices with different
    signs, we first form a four-sided pyramid by connecting one
    of the grid vertices with four mesh vertices that correspond
    to the grid edge and then subdivide the pyramid into two tetrahedra
    """
    inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
        s_edges < 0]]
    # One apex entry per triangle of the quad: 2 triangles per quad at inference, 4 in training.
    if not training:
        inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
    else:
        inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)

    tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
    """
    For each grid edge connecting two grid vertices with the
    same sign, the tetrahedron is formed by the two grid vertices
    and two vertices in consecutive adjacent cells
    """
    inside_cubes = (occ_sum == 8)
    inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
    # Cube centers are appended after the mesh vertices and the inside grid vertices.
    inside_cubes_center_idx = torch.arange(
        inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]

    surface_n_inside_cubes = surf_cubes | inside_cubes
    # Per retained cube: slots 0..11 hold the dual vertex of each edge (surface cubes only),
    # slot 12 holds the cube-center vertex (fully-inside cubes only); -1 marks "not available".
    edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
                                        dtype=torch.long, device=x_nx3.device) * -1
    # Re-express both masks relative to the retained (surface or inside) cubes.
    surf_cubes = surf_cubes[surface_n_inside_cubes]
    inside_cubes = inside_cubes[surface_n_inside_cubes]
    edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
    edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx

    all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
    unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
    unique_edges = unique_edges.long()
    # Edges whose two endpoints are both inside.
    mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
    mask = mask_edges[_idx_map]
    counts = counts[_idx_map]
    mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
    mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
    idx_map = mapping[_idx_map]

    # Only interior edges that are referenced by four retained cubes are processed.
    group_mask = (counts == 4) & mask
    group = idx_map.reshape(-1)[group_mask]
    edge_indices, indices = torch.sort(group)
    cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
                            device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
    edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
        0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
    # Identify the face shared by the adjacent cells.
    cube_idx_4 = cube_idx[indices].reshape(-1, 4)
    edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
    shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
    cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
    # Identify an edge of the face with different signs and
    # select the mesh vertex corresponding to the identified edge.
    # Cubes that are not surface cubes keep the default case id 255.
    case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
    case_ids_expand[surf_cubes] = case_ids
    cases = case_ids_expand[cube_idx_4x2]
    quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
    # Drop candidate tetrahedra for which one of the two looked-up vertices is missing.
    mask = (quad_edge == -1).sum(-1) == 0
    inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
    tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]

    tets = torch.cat([tets_surface, tets_inside])
    vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
    return vertices, tets
|
core/geometry/rep_3d/flexicubes_geometry.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
from . import Geometry
|
13 |
+
from .flexicubes import FlexiCubes # replace later
|
14 |
+
from .dmtet import sdf_reg_loss_batch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
def get_center_boundary_index(grid_res, device):
    """
    Returns flat indices of the center vertex and of the boundary vertices of a cubic vertex grid.

    The grid has ``grid_res + 1`` vertices per axis. The "center" is the single vertex at
    ``grid_res // 2 + 1`` along every axis; the "boundary" consists of the two outermost vertex
    layers on each side of every axis. Both results are ``torch.nonzero`` outputs over the
    flattened grid, i.e. column tensors of shape (count, 1).
    """
    side = grid_res + 1
    mid = grid_res // 2 + 1
    flags = torch.zeros((side, side, side), dtype=torch.bool, device=device)

    # Mark only the center, read its flat index, then clear it again.
    flags[mid, mid, mid] = True
    center_indices = torch.nonzero(flags.reshape(-1))
    flags[mid, mid, mid] = False

    # Mark the first two and last two layers along each of the three axes.
    for axis in range(3):
        lead = (slice(None),) * axis
        flags[lead + (slice(None, 2),)] = True
        flags[lead + (slice(-2, None),)] = True
    boundary_indices = torch.nonzero(flags.reshape(-1))

    return center_indices, boundary_indices
|
31 |
+
|
32 |
+
###############################################################################
|
33 |
+
# Geometry interface
|
34 |
+
###############################################################################
|
35 |
+
class FlexiCubesGeometry(Geometry):
    """
    Geometry backend that extracts meshes from a voxel grid with FlexiCubes and renders them
    through an externally supplied renderer.
    """

    def __init__(
            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
            render_type='neural_render', args=None):
        """
        Args:
            grid_res: voxel grid resolution passed to ``FlexiCubes.construct_voxel_grid``.
            scale: scalar applied to all axes, or a list/tuple of per-axis factors ``[sx, sy, sz]``.
                A two-element sequence ``[sx, syz]`` is still accepted and applies ``syz`` to y and z.
            device: device handed to ``FlexiCubes`` and to the renderer.
            renderer: object exposing ``render_mesh`` (used when ``render_type == 'neural_render'``).
            render_type: only ``'neural_render'`` is implemented.
            args: stored as-is on the instance.
        """
        super(FlexiCubesGeometry, self).__init__()
        self.grid_res = grid_res
        self.device = device
        self.args = args
        self.fc = FlexiCubes(device, weight_scale=0.5)
        self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
        if isinstance(scale, (list, tuple)):
            # The z axis previously reused scale[1]; use the dedicated z factor when one is given
            # and fall back to scale[1] so two-element inputs keep working.
            scale_z = scale[2] if len(scale) > 2 else scale[1]
            self.verts[:, 0] = self.verts[:, 0] * scale[0]
            self.verts[:, 1] = self.verts[:, 1] * scale[1]
            self.verts[:, 2] = self.verts[:, 2] * scale_z
        else:
            self.verts = self.verts * scale

        all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
        self.all_edges = torch.unique(all_edges, dim=0)

        # Parameters used for fix boundary sdf
        self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
        self.renderer = renderer
        self.render_type = render_type

    def getAABB(self):
        """Return the per-axis (min, max) corners of the grid vertices."""
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
        """
        Extract a mesh with FlexiCubes.

        Args:
            v_deformed_nx3: grid vertex positions.
            sdf_n: scalar field value per grid vertex.
            weight_n: optional per-cube weights; columns 0:12 are beta, 12:20 alpha and 20 gamma.
                When ``None`` the extractor is called with ``None`` for all three weights.
            with_uv: accepted for interface compatibility; not used here.
            indices: cube-to-vertex indices; defaults to the grid built in ``__init__``.
            is_training: forwarded to the extractor as ``training``.

        Returns:
            (verts, faces, v_reg_loss) as returned by the extractor.
        """
        if indices is None:
            indices = self.indices

        if weight_n is None:
            # Slicing None would raise; pass None through for all three weight tensors instead.
            beta_fx12 = alpha_fx8 = gamma_f = None
        else:
            beta_fx12 = weight_n[:, :12]
            alpha_fx8 = weight_n[:, 12:20]
            gamma_f = weight_n[:, 20]

        verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
                                           beta_fx12=beta_fx12, alpha_fx8=alpha_fx8,
                                           gamma_f=gamma_f, training=is_training
                                           )
        return verts, faces, v_reg_loss

    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
        """
        Render one mesh with the configured renderer and return its outputs as a dict.

        Raises:
            NotImplementedError: if ``render_type`` is anything other than ``'neural_render'``.
        """
        if self.render_type != 'neural_render':
            raise NotImplementedError

        outputs = self.renderer.render_mesh(
            mesh_v_nx3.unsqueeze(dim=0),
            mesh_f_fx3.int(),
            camera_mv_bx4x4,
            mesh_v_nx3.unsqueeze(dim=0),
            resolution=resolution,
            device=self.device,
            hierarchical_mask=hierarchical_mask
        )
        tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = outputs

        return_value = dict()
        return_value['tex_pos'] = tex_pos
        return_value['mask'] = mask
        return_value['hard_mask'] = hard_mask
        return_value['rast'] = rast
        return_value['v_pos_clip'] = v_pos_clip
        return_value['mask_pyramid'] = mask_pyramid
        return_value['depth'] = depth
        return_value['normal'] = normal
        return return_value

    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
        """
        Extract and render a batch of meshes.

        Each batch element may be a different mesh, so results are returned as a dict mapping every
        render output name to a list with one entry per batch element; concatenation is left to the
        caller. An empty batch yields an empty dict.
        """
        all_render_output = []
        n_batch = v_deformed_bxnx3.shape[0]
        for i_batch in range(n_batch):
            # get_mesh returns (verts, faces, reg_loss); the regularizer is not needed for rendering.
            verts_nx3, faces_fx3, _ = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
            render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
            all_render_output.append(render_output)

        if not all_render_output:
            return dict()

        # Regroup per-sample dicts into per-key lists.
        return_value = dict()
        for k in all_render_output[0].keys():
            return_value[k] = [v[k] for v in all_render_output]
        # We can do concatenation outside of the render
        return return_value
|
core/geometry/rep_3d/tables.py
ADDED
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
dmc_table = [
|
9 |
+
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
10 |
+
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
11 |
+
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
12 |
+
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
13 |
+
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
14 |
+
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
15 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
16 |
+
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
17 |
+
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
18 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
19 |
+
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
20 |
+
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
21 |
+
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
22 |
+
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
23 |
+
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
24 |
+
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
25 |
+
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
26 |
+
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
27 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
28 |
+
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
29 |
+
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
30 |
+
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
31 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
32 |
+
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
33 |
+
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
34 |
+
[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
35 |
+
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
36 |
+
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
37 |
+
[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
38 |
+
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
39 |
+
[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
40 |
+
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
41 |
+
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
42 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
43 |
+
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
44 |
+
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
45 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
46 |
+
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
47 |
+
[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
48 |
+
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
49 |
+
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
50 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
51 |
+
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
52 |
+
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
53 |
+
[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
54 |
+
[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
55 |
+
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
56 |
+
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
57 |
+
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
58 |
+
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
59 |
+
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
60 |
+
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
61 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
62 |
+
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
63 |
+
[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
64 |
+
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
65 |
+
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
66 |
+
[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
67 |
+
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
68 |
+
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
69 |
+
[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
70 |
+
[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
71 |
+
[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
72 |
+
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
73 |
+
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
74 |
+
[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
75 |
+
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
76 |
+
[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
77 |
+
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
78 |
+
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
79 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
80 |
+
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
81 |
+
[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
82 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
83 |
+
[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
84 |
+
[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
85 |
+
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
86 |
+
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
87 |
+
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
88 |
+
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
89 |
+
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
90 |
+
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
91 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
92 |
+
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
93 |
+
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
94 |
+
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
95 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
96 |
+
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
97 |
+
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
98 |
+
[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
99 |
+
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
100 |
+
[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
101 |
+
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
102 |
+
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
103 |
+
[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
104 |
+
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
105 |
+
[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
106 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
107 |
+
[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
108 |
+
[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
109 |
+
[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
110 |
+
[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
111 |
+
[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
112 |
+
[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
113 |
+
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
114 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
|
115 |
+
[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
116 |
+
[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
117 |
+
[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
118 |
+
[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
119 |
+
[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
120 |
+
[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
121 |
+
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
122 |
+
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
123 |
+
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
124 |
+
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
125 |
+
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
126 |
+
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
127 |
+
[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
128 |
+
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
129 |
+
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
130 |
+
[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
131 |
+
[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
132 |
+
[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
133 |
+
[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
134 |
+
[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
135 |
+
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
136 |
+
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
137 |
+
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
138 |
+
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
139 |
+
[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
140 |
+
[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
141 |
+
[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
142 |
+
[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
143 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
144 |
+
[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
145 |
+
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
146 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
147 |
+
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
148 |
+
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
149 |
+
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
150 |
+
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
151 |
+
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
152 |
+
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
153 |
+
[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
154 |
+
[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
155 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
156 |
+
[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
157 |
+
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
158 |
+
[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
159 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
|
160 |
+
[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
161 |
+
[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
162 |
+
[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
163 |
+
[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
164 |
+
[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
165 |
+
[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
166 |
+
[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
167 |
+
[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
168 |
+
[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
169 |
+
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
170 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
171 |
+
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
172 |
+
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
173 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
174 |
+
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
175 |
+
[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
176 |
+
[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
177 |
+
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
178 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
179 |
+
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
180 |
+
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
181 |
+
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
182 |
+
[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
183 |
+
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
184 |
+
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
185 |
+
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
186 |
+
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
187 |
+
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
188 |
+
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
189 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
190 |
+
[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
191 |
+
[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
192 |
+
[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
193 |
+
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
194 |
+
[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
195 |
+
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
196 |
+
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
197 |
+
[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
198 |
+
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
199 |
+
[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
200 |
+
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
201 |
+
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
202 |
+
[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
203 |
+
[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
204 |
+
[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
205 |
+
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
206 |
+
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
207 |
+
[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
208 |
+
[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
209 |
+
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
210 |
+
[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
211 |
+
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
212 |
+
[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
213 |
+
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
214 |
+
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
215 |
+
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
216 |
+
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
217 |
+
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
218 |
+
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
219 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
220 |
+
[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
221 |
+
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
222 |
+
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
223 |
+
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
224 |
+
[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
225 |
+
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
226 |
+
[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
227 |
+
[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
228 |
+
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
229 |
+
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
230 |
+
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
231 |
+
[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
232 |
+
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
233 |
+
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
234 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
235 |
+
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
236 |
+
[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
237 |
+
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
238 |
+
[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
239 |
+
[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
240 |
+
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
241 |
+
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
242 |
+
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
243 |
+
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
244 |
+
[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
245 |
+
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
246 |
+
[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
247 |
+
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
248 |
+
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
249 |
+
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
250 |
+
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
251 |
+
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
252 |
+
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
253 |
+
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
254 |
+
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
255 |
+
[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
256 |
+
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
257 |
+
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
258 |
+
[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
259 |
+
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
260 |
+
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
261 |
+
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
262 |
+
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
263 |
+
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
|
264 |
+
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
|
265 |
+
]
|
266 |
+
num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
|
267 |
+
2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
|
268 |
+
1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
|
269 |
+
1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
|
270 |
+
2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
|
271 |
+
3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
|
272 |
+
2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
|
273 |
+
1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
|
274 |
+
1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
|
275 |
+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
|
276 |
+
check_table = [
|
277 |
+
[0, 0, 0, 0, 0],
|
278 |
+
[0, 0, 0, 0, 0],
|
279 |
+
[0, 0, 0, 0, 0],
|
280 |
+
[0, 0, 0, 0, 0],
|
281 |
+
[0, 0, 0, 0, 0],
|
282 |
+
[0, 0, 0, 0, 0],
|
283 |
+
[0, 0, 0, 0, 0],
|
284 |
+
[0, 0, 0, 0, 0],
|
285 |
+
[0, 0, 0, 0, 0],
|
286 |
+
[0, 0, 0, 0, 0],
|
287 |
+
[0, 0, 0, 0, 0],
|
288 |
+
[0, 0, 0, 0, 0],
|
289 |
+
[0, 0, 0, 0, 0],
|
290 |
+
[0, 0, 0, 0, 0],
|
291 |
+
[0, 0, 0, 0, 0],
|
292 |
+
[0, 0, 0, 0, 0],
|
293 |
+
[0, 0, 0, 0, 0],
|
294 |
+
[0, 0, 0, 0, 0],
|
295 |
+
[0, 0, 0, 0, 0],
|
296 |
+
[0, 0, 0, 0, 0],
|
297 |
+
[0, 0, 0, 0, 0],
|
298 |
+
[0, 0, 0, 0, 0],
|
299 |
+
[0, 0, 0, 0, 0],
|
300 |
+
[0, 0, 0, 0, 0],
|
301 |
+
[0, 0, 0, 0, 0],
|
302 |
+
[0, 0, 0, 0, 0],
|
303 |
+
[0, 0, 0, 0, 0],
|
304 |
+
[0, 0, 0, 0, 0],
|
305 |
+
[0, 0, 0, 0, 0],
|
306 |
+
[0, 0, 0, 0, 0],
|
307 |
+
[0, 0, 0, 0, 0],
|
308 |
+
[0, 0, 0, 0, 0],
|
309 |
+
[0, 0, 0, 0, 0],
|
310 |
+
[0, 0, 0, 0, 0],
|
311 |
+
[0, 0, 0, 0, 0],
|
312 |
+
[0, 0, 0, 0, 0],
|
313 |
+
[0, 0, 0, 0, 0],
|
314 |
+
[0, 0, 0, 0, 0],
|
315 |
+
[0, 0, 0, 0, 0],
|
316 |
+
[0, 0, 0, 0, 0],
|
317 |
+
[0, 0, 0, 0, 0],
|
318 |
+
[0, 0, 0, 0, 0],
|
319 |
+
[0, 0, 0, 0, 0],
|
320 |
+
[0, 0, 0, 0, 0],
|
321 |
+
[0, 0, 0, 0, 0],
|
322 |
+
[0, 0, 0, 0, 0],
|
323 |
+
[0, 0, 0, 0, 0],
|
324 |
+
[0, 0, 0, 0, 0],
|
325 |
+
[0, 0, 0, 0, 0],
|
326 |
+
[0, 0, 0, 0, 0],
|
327 |
+
[0, 0, 0, 0, 0],
|
328 |
+
[0, 0, 0, 0, 0],
|
329 |
+
[0, 0, 0, 0, 0],
|
330 |
+
[0, 0, 0, 0, 0],
|
331 |
+
[0, 0, 0, 0, 0],
|
332 |
+
[0, 0, 0, 0, 0],
|
333 |
+
[0, 0, 0, 0, 0],
|
334 |
+
[0, 0, 0, 0, 0],
|
335 |
+
[0, 0, 0, 0, 0],
|
336 |
+
[0, 0, 0, 0, 0],
|
337 |
+
[0, 0, 0, 0, 0],
|
338 |
+
[1, 1, 0, 0, 194],
|
339 |
+
[1, -1, 0, 0, 193],
|
340 |
+
[0, 0, 0, 0, 0],
|
341 |
+
[0, 0, 0, 0, 0],
|
342 |
+
[0, 0, 0, 0, 0],
|
343 |
+
[0, 0, 0, 0, 0],
|
344 |
+
[0, 0, 0, 0, 0],
|
345 |
+
[0, 0, 0, 0, 0],
|
346 |
+
[0, 0, 0, 0, 0],
|
347 |
+
[0, 0, 0, 0, 0],
|
348 |
+
[0, 0, 0, 0, 0],
|
349 |
+
[0, 0, 0, 0, 0],
|
350 |
+
[0, 0, 0, 0, 0],
|
351 |
+
[0, 0, 0, 0, 0],
|
352 |
+
[0, 0, 0, 0, 0],
|
353 |
+
[0, 0, 0, 0, 0],
|
354 |
+
[0, 0, 0, 0, 0],
|
355 |
+
[0, 0, 0, 0, 0],
|
356 |
+
[0, 0, 0, 0, 0],
|
357 |
+
[0, 0, 0, 0, 0],
|
358 |
+
[0, 0, 0, 0, 0],
|
359 |
+
[0, 0, 0, 0, 0],
|
360 |
+
[0, 0, 0, 0, 0],
|
361 |
+
[0, 0, 0, 0, 0],
|
362 |
+
[0, 0, 0, 0, 0],
|
363 |
+
[0, 0, 0, 0, 0],
|
364 |
+
[0, 0, 0, 0, 0],
|
365 |
+
[0, 0, 0, 0, 0],
|
366 |
+
[0, 0, 0, 0, 0],
|
367 |
+
[0, 0, 0, 0, 0],
|
368 |
+
[1, 0, 1, 0, 164],
|
369 |
+
[0, 0, 0, 0, 0],
|
370 |
+
[0, 0, 0, 0, 0],
|
371 |
+
[1, 0, -1, 0, 161],
|
372 |
+
[0, 0, 0, 0, 0],
|
373 |
+
[0, 0, 0, 0, 0],
|
374 |
+
[0, 0, 0, 0, 0],
|
375 |
+
[0, 0, 0, 0, 0],
|
376 |
+
[0, 0, 0, 0, 0],
|
377 |
+
[0, 0, 0, 0, 0],
|
378 |
+
[0, 0, 0, 0, 0],
|
379 |
+
[0, 0, 0, 0, 0],
|
380 |
+
[1, 0, 0, 1, 152],
|
381 |
+
[0, 0, 0, 0, 0],
|
382 |
+
[0, 0, 0, 0, 0],
|
383 |
+
[0, 0, 0, 0, 0],
|
384 |
+
[0, 0, 0, 0, 0],
|
385 |
+
[0, 0, 0, 0, 0],
|
386 |
+
[0, 0, 0, 0, 0],
|
387 |
+
[1, 0, 0, 1, 145],
|
388 |
+
[1, 0, 0, 1, 144],
|
389 |
+
[0, 0, 0, 0, 0],
|
390 |
+
[0, 0, 0, 0, 0],
|
391 |
+
[0, 0, 0, 0, 0],
|
392 |
+
[0, 0, 0, 0, 0],
|
393 |
+
[0, 0, 0, 0, 0],
|
394 |
+
[0, 0, 0, 0, 0],
|
395 |
+
[1, 0, 0, -1, 137],
|
396 |
+
[0, 0, 0, 0, 0],
|
397 |
+
[0, 0, 0, 0, 0],
|
398 |
+
[0, 0, 0, 0, 0],
|
399 |
+
[1, 0, 1, 0, 133],
|
400 |
+
[1, 0, 1, 0, 132],
|
401 |
+
[1, 1, 0, 0, 131],
|
402 |
+
[1, 1, 0, 0, 130],
|
403 |
+
[0, 0, 0, 0, 0],
|
404 |
+
[0, 0, 0, 0, 0],
|
405 |
+
[0, 0, 0, 0, 0],
|
406 |
+
[0, 0, 0, 0, 0],
|
407 |
+
[0, 0, 0, 0, 0],
|
408 |
+
[0, 0, 0, 0, 0],
|
409 |
+
[0, 0, 0, 0, 0],
|
410 |
+
[0, 0, 0, 0, 0],
|
411 |
+
[0, 0, 0, 0, 0],
|
412 |
+
[0, 0, 0, 0, 0],
|
413 |
+
[0, 0, 0, 0, 0],
|
414 |
+
[0, 0, 0, 0, 0],
|
415 |
+
[0, 0, 0, 0, 0],
|
416 |
+
[0, 0, 0, 0, 0],
|
417 |
+
[0, 0, 0, 0, 0],
|
418 |
+
[0, 0, 0, 0, 0],
|
419 |
+
[0, 0, 0, 0, 0],
|
420 |
+
[0, 0, 0, 0, 0],
|
421 |
+
[0, 0, 0, 0, 0],
|
422 |
+
[0, 0, 0, 0, 0],
|
423 |
+
[0, 0, 0, 0, 0],
|
424 |
+
[0, 0, 0, 0, 0],
|
425 |
+
[0, 0, 0, 0, 0],
|
426 |
+
[0, 0, 0, 0, 0],
|
427 |
+
[0, 0, 0, 0, 0],
|
428 |
+
[0, 0, 0, 0, 0],
|
429 |
+
[0, 0, 0, 0, 0],
|
430 |
+
[0, 0, 0, 0, 0],
|
431 |
+
[0, 0, 0, 0, 0],
|
432 |
+
[1, 0, 0, 1, 100],
|
433 |
+
[0, 0, 0, 0, 0],
|
434 |
+
[1, 0, 0, 1, 98],
|
435 |
+
[0, 0, 0, 0, 0],
|
436 |
+
[1, 0, 0, 1, 96],
|
437 |
+
[0, 0, 0, 0, 0],
|
438 |
+
[0, 0, 0, 0, 0],
|
439 |
+
[0, 0, 0, 0, 0],
|
440 |
+
[0, 0, 0, 0, 0],
|
441 |
+
[0, 0, 0, 0, 0],
|
442 |
+
[0, 0, 0, 0, 0],
|
443 |
+
[0, 0, 0, 0, 0],
|
444 |
+
[1, 0, 1, 0, 88],
|
445 |
+
[0, 0, 0, 0, 0],
|
446 |
+
[0, 0, 0, 0, 0],
|
447 |
+
[0, 0, 0, 0, 0],
|
448 |
+
[0, 0, 0, 0, 0],
|
449 |
+
[0, 0, 0, 0, 0],
|
450 |
+
[1, 0, -1, 0, 82],
|
451 |
+
[0, 0, 0, 0, 0],
|
452 |
+
[0, 0, 0, 0, 0],
|
453 |
+
[0, 0, 0, 0, 0],
|
454 |
+
[0, 0, 0, 0, 0],
|
455 |
+
[0, 0, 0, 0, 0],
|
456 |
+
[0, 0, 0, 0, 0],
|
457 |
+
[0, 0, 0, 0, 0],
|
458 |
+
[1, 0, 1, 0, 74],
|
459 |
+
[0, 0, 0, 0, 0],
|
460 |
+
[1, 0, 1, 0, 72],
|
461 |
+
[0, 0, 0, 0, 0],
|
462 |
+
[1, 0, 0, -1, 70],
|
463 |
+
[0, 0, 0, 0, 0],
|
464 |
+
[0, 0, 0, 0, 0],
|
465 |
+
[1, -1, 0, 0, 67],
|
466 |
+
[0, 0, 0, 0, 0],
|
467 |
+
[1, -1, 0, 0, 65],
|
468 |
+
[0, 0, 0, 0, 0],
|
469 |
+
[0, 0, 0, 0, 0],
|
470 |
+
[0, 0, 0, 0, 0],
|
471 |
+
[0, 0, 0, 0, 0],
|
472 |
+
[0, 0, 0, 0, 0],
|
473 |
+
[0, 0, 0, 0, 0],
|
474 |
+
[0, 0, 0, 0, 0],
|
475 |
+
[0, 0, 0, 0, 0],
|
476 |
+
[1, 1, 0, 0, 56],
|
477 |
+
[0, 0, 0, 0, 0],
|
478 |
+
[0, 0, 0, 0, 0],
|
479 |
+
[0, 0, 0, 0, 0],
|
480 |
+
[1, -1, 0, 0, 52],
|
481 |
+
[0, 0, 0, 0, 0],
|
482 |
+
[0, 0, 0, 0, 0],
|
483 |
+
[0, 0, 0, 0, 0],
|
484 |
+
[0, 0, 0, 0, 0],
|
485 |
+
[0, 0, 0, 0, 0],
|
486 |
+
[0, 0, 0, 0, 0],
|
487 |
+
[0, 0, 0, 0, 0],
|
488 |
+
[1, 1, 0, 0, 44],
|
489 |
+
[0, 0, 0, 0, 0],
|
490 |
+
[0, 0, 0, 0, 0],
|
491 |
+
[0, 0, 0, 0, 0],
|
492 |
+
[1, 1, 0, 0, 40],
|
493 |
+
[0, 0, 0, 0, 0],
|
494 |
+
[1, 0, 0, -1, 38],
|
495 |
+
[1, 0, -1, 0, 37],
|
496 |
+
[0, 0, 0, 0, 0],
|
497 |
+
[0, 0, 0, 0, 0],
|
498 |
+
[0, 0, 0, 0, 0],
|
499 |
+
[1, 0, -1, 0, 33],
|
500 |
+
[0, 0, 0, 0, 0],
|
501 |
+
[0, 0, 0, 0, 0],
|
502 |
+
[0, 0, 0, 0, 0],
|
503 |
+
[0, 0, 0, 0, 0],
|
504 |
+
[1, -1, 0, 0, 28],
|
505 |
+
[0, 0, 0, 0, 0],
|
506 |
+
[1, 0, -1, 0, 26],
|
507 |
+
[1, 0, 0, -1, 25],
|
508 |
+
[0, 0, 0, 0, 0],
|
509 |
+
[0, 0, 0, 0, 0],
|
510 |
+
[0, 0, 0, 0, 0],
|
511 |
+
[0, 0, 0, 0, 0],
|
512 |
+
[1, -1, 0, 0, 20],
|
513 |
+
[0, 0, 0, 0, 0],
|
514 |
+
[1, 0, -1, 0, 18],
|
515 |
+
[0, 0, 0, 0, 0],
|
516 |
+
[0, 0, 0, 0, 0],
|
517 |
+
[0, 0, 0, 0, 0],
|
518 |
+
[0, 0, 0, 0, 0],
|
519 |
+
[0, 0, 0, 0, 0],
|
520 |
+
[0, 0, 0, 0, 0],
|
521 |
+
[0, 0, 0, 0, 0],
|
522 |
+
[0, 0, 0, 0, 0],
|
523 |
+
[1, 0, 0, -1, 9],
|
524 |
+
[0, 0, 0, 0, 0],
|
525 |
+
[0, 0, 0, 0, 0],
|
526 |
+
[1, 0, 0, -1, 6],
|
527 |
+
[0, 0, 0, 0, 0],
|
528 |
+
[0, 0, 0, 0, 0],
|
529 |
+
[0, 0, 0, 0, 0],
|
530 |
+
[0, 0, 0, 0, 0],
|
531 |
+
[0, 0, 0, 0, 0],
|
532 |
+
[0, 0, 0, 0, 0]
|
533 |
+
]
|
534 |
+
tet_table = [
|
535 |
+
[-1, -1, -1, -1, -1, -1],
|
536 |
+
[0, 0, 0, 0, 0, 0],
|
537 |
+
[0, 0, 0, 0, 0, 0],
|
538 |
+
[1, 1, 1, 1, 1, 1],
|
539 |
+
[4, 4, 4, 4, 4, 4],
|
540 |
+
[0, 0, 0, 0, 0, 0],
|
541 |
+
[4, 0, 0, 4, 4, -1],
|
542 |
+
[1, 1, 1, 1, 1, 1],
|
543 |
+
[4, 4, 4, 4, 4, 4],
|
544 |
+
[0, 4, 0, 4, 4, -1],
|
545 |
+
[0, 0, 0, 0, 0, 0],
|
546 |
+
[1, 1, 1, 1, 1, 1],
|
547 |
+
[5, 5, 5, 5, 5, 5],
|
548 |
+
[0, 0, 0, 0, 0, 0],
|
549 |
+
[0, 0, 0, 0, 0, 0],
|
550 |
+
[1, 1, 1, 1, 1, 1],
|
551 |
+
[2, 2, 2, 2, 2, 2],
|
552 |
+
[0, 0, 0, 0, 0, 0],
|
553 |
+
[2, 0, 2, -1, 0, 2],
|
554 |
+
[1, 1, 1, 1, 1, 1],
|
555 |
+
[2, -1, 2, 4, 4, 2],
|
556 |
+
[0, 0, 0, 0, 0, 0],
|
557 |
+
[2, 0, 2, 4, 4, 2],
|
558 |
+
[1, 1, 1, 1, 1, 1],
|
559 |
+
[2, 4, 2, 4, 4, 2],
|
560 |
+
[0, 4, 0, 4, 4, 0],
|
561 |
+
[2, 0, 2, 0, 0, 2],
|
562 |
+
[1, 1, 1, 1, 1, 1],
|
563 |
+
[2, 5, 2, 5, 5, 2],
|
564 |
+
[0, 0, 0, 0, 0, 0],
|
565 |
+
[2, 0, 2, 0, 0, 2],
|
566 |
+
[1, 1, 1, 1, 1, 1],
|
567 |
+
[1, 1, 1, 1, 1, 1],
|
568 |
+
[0, 1, 1, -1, 0, 1],
|
569 |
+
[0, 0, 0, 0, 0, 0],
|
570 |
+
[2, 2, 2, 2, 2, 2],
|
571 |
+
[4, 1, 1, 4, 4, 1],
|
572 |
+
[0, 1, 1, 0, 0, 1],
|
573 |
+
[4, 0, 0, 4, 4, 0],
|
574 |
+
[2, 2, 2, 2, 2, 2],
|
575 |
+
[-1, 1, 1, 4, 4, 1],
|
576 |
+
[0, 1, 1, 4, 4, 1],
|
577 |
+
[0, 0, 0, 0, 0, 0],
|
578 |
+
[2, 2, 2, 2, 2, 2],
|
579 |
+
[5, 1, 1, 5, 5, 1],
|
580 |
+
[0, 1, 1, 0, 0, 1],
|
581 |
+
[0, 0, 0, 0, 0, 0],
|
582 |
+
[2, 2, 2, 2, 2, 2],
|
583 |
+
[1, 1, 1, 1, 1, 1],
|
584 |
+
[0, 0, 0, 0, 0, 0],
|
585 |
+
[0, 0, 0, 0, 0, 0],
|
586 |
+
[8, 8, 8, 8, 8, 8],
|
587 |
+
[1, 1, 1, 4, 4, 1],
|
588 |
+
[0, 0, 0, 0, 0, 0],
|
589 |
+
[4, 0, 0, 4, 4, 0],
|
590 |
+
[4, 4, 4, 4, 4, 4],
|
591 |
+
[1, 1, 1, 4, 4, 1],
|
592 |
+
[0, 4, 0, 4, 4, 0],
|
593 |
+
[0, 0, 0, 0, 0, 0],
|
594 |
+
[4, 4, 4, 4, 4, 4],
|
595 |
+
[1, 1, 1, 5, 5, 1],
|
596 |
+
[0, 0, 0, 0, 0, 0],
|
597 |
+
[0, 0, 0, 0, 0, 0],
|
598 |
+
[5, 5, 5, 5, 5, 5],
|
599 |
+
[6, 6, 6, 6, 6, 6],
|
600 |
+
[6, -1, 0, 6, 0, 6],
|
601 |
+
[6, 0, 0, 6, 0, 6],
|
602 |
+
[6, 1, 1, 6, 1, 6],
|
603 |
+
[4, 4, 4, 4, 4, 4],
|
604 |
+
[0, 0, 0, 0, 0, 0],
|
605 |
+
[4, 0, 0, 4, 4, 4],
|
606 |
+
[1, 1, 1, 1, 1, 1],
|
607 |
+
[6, 4, -1, 6, 4, 6],
|
608 |
+
[6, 4, 0, 6, 4, 6],
|
609 |
+
[6, 0, 0, 6, 0, 6],
|
610 |
+
[6, 1, 1, 6, 1, 6],
|
611 |
+
[5, 5, 5, 5, 5, 5],
|
612 |
+
[0, 0, 0, 0, 0, 0],
|
613 |
+
[0, 0, 0, 0, 0, 0],
|
614 |
+
[1, 1, 1, 1, 1, 1],
|
615 |
+
[2, 2, 2, 2, 2, 2],
|
616 |
+
[0, 0, 0, 0, 0, 0],
|
617 |
+
[2, 0, 2, 2, 0, 2],
|
618 |
+
[1, 1, 1, 1, 1, 1],
|
619 |
+
[2, 2, 2, 2, 2, 2],
|
620 |
+
[0, 0, 0, 0, 0, 0],
|
621 |
+
[2, 0, 2, 2, 2, 2],
|
622 |
+
[1, 1, 1, 1, 1, 1],
|
623 |
+
[2, 4, 2, 2, 4, 2],
|
624 |
+
[0, 4, 0, 4, 4, 0],
|
625 |
+
[2, 0, 2, 2, 0, 2],
|
626 |
+
[1, 1, 1, 1, 1, 1],
|
627 |
+
[2, 2, 2, 2, 2, 2],
|
628 |
+
[0, 0, 0, 0, 0, 0],
|
629 |
+
[0, 0, 0, 0, 0, 0],
|
630 |
+
[1, 1, 1, 1, 1, 1],
|
631 |
+
[6, 1, 1, 6, -1, 6],
|
632 |
+
[6, 1, 1, 6, 0, 6],
|
633 |
+
[6, 0, 0, 6, 0, 6],
|
634 |
+
[6, 2, 2, 6, 2, 6],
|
635 |
+
[4, 1, 1, 4, 4, 1],
|
636 |
+
[0, 1, 1, 0, 0, 1],
|
637 |
+
[4, 0, 0, 4, 4, 4],
|
638 |
+
[2, 2, 2, 2, 2, 2],
|
639 |
+
[6, 1, 1, 6, 4, 6],
|
640 |
+
[6, 1, 1, 6, 4, 6],
|
641 |
+
[6, 0, 0, 6, 0, 6],
|
642 |
+
[6, 2, 2, 6, 2, 6],
|
643 |
+
[5, 1, 1, 5, 5, 1],
|
644 |
+
[0, 1, 1, 0, 0, 1],
|
645 |
+
[0, 0, 0, 0, 0, 0],
|
646 |
+
[2, 2, 2, 2, 2, 2],
|
647 |
+
[1, 1, 1, 1, 1, 1],
|
648 |
+
[0, 0, 0, 0, 0, 0],
|
649 |
+
[0, 0, 0, 0, 0, 0],
|
650 |
+
[6, 6, 6, 6, 6, 6],
|
651 |
+
[1, 1, 1, 1, 1, 1],
|
652 |
+
[0, 0, 0, 0, 0, 0],
|
653 |
+
[0, 0, 0, 0, 0, 0],
|
654 |
+
[4, 4, 4, 4, 4, 4],
|
655 |
+
[1, 1, 1, 1, 4, 1],
|
656 |
+
[0, 4, 0, 4, 4, 0],
|
657 |
+
[0, 0, 0, 0, 0, 0],
|
658 |
+
[4, 4, 4, 4, 4, 4],
|
659 |
+
[1, 1, 1, 1, 1, 1],
|
660 |
+
[0, 0, 0, 0, 0, 0],
|
661 |
+
[0, 5, 0, 5, 0, 5],
|
662 |
+
[5, 5, 5, 5, 5, 5],
|
663 |
+
[5, 5, 5, 5, 5, 5],
|
664 |
+
[0, 5, 0, 5, 0, 5],
|
665 |
+
[-1, 5, 0, 5, 0, 5],
|
666 |
+
[1, 5, 1, 5, 1, 5],
|
667 |
+
[4, 5, -1, 5, 4, 5],
|
668 |
+
[0, 5, 0, 5, 0, 5],
|
669 |
+
[4, 5, 0, 5, 4, 5],
|
670 |
+
[1, 5, 1, 5, 1, 5],
|
671 |
+
[4, 4, 4, 4, 4, 4],
|
672 |
+
[0, 4, 0, 4, 4, 4],
|
673 |
+
[0, 0, 0, 0, 0, 0],
|
674 |
+
[1, 1, 1, 1, 1, 1],
|
675 |
+
[6, 6, 6, 6, 6, 6],
|
676 |
+
[0, 0, 0, 0, 0, 0],
|
677 |
+
[0, 0, 0, 0, 0, 0],
|
678 |
+
[1, 1, 1, 1, 1, 1],
|
679 |
+
[2, 5, 2, 5, -1, 5],
|
680 |
+
[0, 5, 0, 5, 0, 5],
|
681 |
+
[2, 5, 2, 5, 0, 5],
|
682 |
+
[1, 5, 1, 5, 1, 5],
|
683 |
+
[2, 5, 2, 5, 4, 5],
|
684 |
+
[0, 5, 0, 5, 0, 5],
|
685 |
+
[2, 5, 2, 5, 4, 5],
|
686 |
+
[1, 5, 1, 5, 1, 5],
|
687 |
+
[2, 4, 2, 4, 4, 2],
|
688 |
+
[0, 4, 0, 4, 4, 4],
|
689 |
+
[2, 0, 2, 0, 0, 2],
|
690 |
+
[1, 1, 1, 1, 1, 1],
|
691 |
+
[2, 6, 2, 6, 6, 2],
|
692 |
+
[0, 0, 0, 0, 0, 0],
|
693 |
+
[2, 0, 2, 0, 0, 2],
|
694 |
+
[1, 1, 1, 1, 1, 1],
|
695 |
+
[1, 1, 1, 1, 1, 1],
|
696 |
+
[0, 1, 1, 1, 0, 1],
|
697 |
+
[0, 0, 0, 0, 0, 0],
|
698 |
+
[2, 2, 2, 2, 2, 2],
|
699 |
+
[4, 1, 1, 1, 4, 1],
|
700 |
+
[0, 1, 1, 1, 0, 1],
|
701 |
+
[4, 0, 0, 4, 4, 0],
|
702 |
+
[2, 2, 2, 2, 2, 2],
|
703 |
+
[1, 1, 1, 1, 1, 1],
|
704 |
+
[0, 1, 1, 1, 1, 1],
|
705 |
+
[0, 0, 0, 0, 0, 0],
|
706 |
+
[2, 2, 2, 2, 2, 2],
|
707 |
+
[1, 1, 1, 1, 1, 1],
|
708 |
+
[0, 0, 0, 0, 0, 0],
|
709 |
+
[0, 0, 0, 0, 0, 0],
|
710 |
+
[2, 2, 2, 2, 2, 2],
|
711 |
+
[1, 1, 1, 1, 1, 1],
|
712 |
+
[0, 0, 0, 0, 0, 0],
|
713 |
+
[0, 0, 0, 0, 0, 0],
|
714 |
+
[5, 5, 5, 5, 5, 5],
|
715 |
+
[1, 1, 1, 1, 4, 1],
|
716 |
+
[0, 0, 0, 0, 0, 0],
|
717 |
+
[4, 0, 0, 4, 4, 0],
|
718 |
+
[4, 4, 4, 4, 4, 4],
|
719 |
+
[1, 1, 1, 1, 1, 1],
|
720 |
+
[0, 0, 0, 0, 0, 0],
|
721 |
+
[0, 0, 0, 0, 0, 0],
|
722 |
+
[4, 4, 4, 4, 4, 4],
|
723 |
+
[1, 1, 1, 1, 1, 1],
|
724 |
+
[6, 0, 0, 6, 0, 6],
|
725 |
+
[0, 0, 0, 0, 0, 0],
|
726 |
+
[6, 6, 6, 6, 6, 6],
|
727 |
+
[5, 5, 5, 5, 5, 5],
|
728 |
+
[5, 5, 0, 5, 0, 5],
|
729 |
+
[5, 5, 0, 5, 0, 5],
|
730 |
+
[5, 5, 1, 5, 1, 5],
|
731 |
+
[4, 4, 4, 4, 4, 4],
|
732 |
+
[0, 0, 0, 0, 0, 0],
|
733 |
+
[4, 4, 0, 4, 4, 4],
|
734 |
+
[1, 1, 1, 1, 1, 1],
|
735 |
+
[4, 4, 4, 4, 4, 4],
|
736 |
+
[4, 4, 0, 4, 4, 4],
|
737 |
+
[0, 0, 0, 0, 0, 0],
|
738 |
+
[1, 1, 1, 1, 1, 1],
|
739 |
+
[8, 8, 8, 8, 8, 8],
|
740 |
+
[0, 0, 0, 0, 0, 0],
|
741 |
+
[0, 0, 0, 0, 0, 0],
|
742 |
+
[1, 1, 1, 1, 1, 1],
|
743 |
+
[2, 2, 2, 2, 2, 2],
|
744 |
+
[0, 0, 0, 0, 0, 0],
|
745 |
+
[2, 2, 2, 2, 0, 2],
|
746 |
+
[1, 1, 1, 1, 1, 1],
|
747 |
+
[2, 2, 2, 2, 2, 2],
|
748 |
+
[0, 0, 0, 0, 0, 0],
|
749 |
+
[2, 2, 2, 2, 2, 2],
|
750 |
+
[1, 1, 1, 1, 1, 1],
|
751 |
+
[2, 2, 2, 2, 2, 2],
|
752 |
+
[0, 0, 0, 0, 0, 0],
|
753 |
+
[0, 0, 0, 0, 0, 0],
|
754 |
+
[4, 1, 1, 4, 4, 1],
|
755 |
+
[2, 2, 2, 2, 2, 2],
|
756 |
+
[0, 0, 0, 0, 0, 0],
|
757 |
+
[0, 0, 0, 0, 0, 0],
|
758 |
+
[1, 1, 1, 1, 1, 1],
|
759 |
+
[1, 1, 1, 1, 1, 1],
|
760 |
+
[1, 1, 1, 1, 0, 1],
|
761 |
+
[0, 0, 0, 0, 0, 0],
|
762 |
+
[2, 2, 2, 2, 2, 2],
|
763 |
+
[1, 1, 1, 1, 1, 1],
|
764 |
+
[0, 0, 0, 0, 0, 0],
|
765 |
+
[0, 0, 0, 0, 0, 0],
|
766 |
+
[2, 4, 2, 4, 4, 2],
|
767 |
+
[1, 1, 1, 1, 1, 1],
|
768 |
+
[1, 1, 1, 1, 1, 1],
|
769 |
+
[0, 0, 0, 0, 0, 0],
|
770 |
+
[2, 2, 2, 2, 2, 2],
|
771 |
+
[1, 1, 1, 1, 1, 1],
|
772 |
+
[0, 0, 0, 0, 0, 0],
|
773 |
+
[0, 0, 0, 0, 0, 0],
|
774 |
+
[2, 2, 2, 2, 2, 2],
|
775 |
+
[1, 1, 1, 1, 1, 1],
|
776 |
+
[0, 0, 0, 0, 0, 0],
|
777 |
+
[0, 0, 0, 0, 0, 0],
|
778 |
+
[5, 5, 5, 5, 5, 5],
|
779 |
+
[1, 1, 1, 1, 1, 1],
|
780 |
+
[0, 0, 0, 0, 0, 0],
|
781 |
+
[0, 0, 0, 0, 0, 0],
|
782 |
+
[4, 4, 4, 4, 4, 4],
|
783 |
+
[1, 1, 1, 1, 1, 1],
|
784 |
+
[0, 0, 0, 0, 0, 0],
|
785 |
+
[0, 0, 0, 0, 0, 0],
|
786 |
+
[4, 4, 4, 4, 4, 4],
|
787 |
+
[1, 1, 1, 1, 1, 1],
|
788 |
+
[0, 0, 0, 0, 0, 0],
|
789 |
+
[0, 0, 0, 0, 0, 0],
|
790 |
+
[12, 12, 12, 12, 12, 12]
|
791 |
+
]
|
core/instant_utils/__init__.py
ADDED
File without changes
|
core/instant_utils/camera_util.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def pad_camera_extrinsics_4x4(extrinsics):
    """
    Append the homogeneous row [0, 0, 0, 1] so the extrinsics become 4x4.

    extrinsics: (..., 3, 4) or (..., 4, 4), with any number of leading batch dimensions
    return: (..., 4, 4); an input whose second-to-last dimension is already 4
        is returned unchanged (same object)
    """
    if extrinsics.shape[-2] == 4:
        return extrinsics
    # `.to(extrinsics)` gives the padding row the dtype and device of the input.
    padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
    # Broadcast the (1, 4) row over every leading batch dimension instead of
    # special-casing a single batch dimension.
    padding = padding.expand(*extrinsics.shape[:-2], 1, 4)
    return torch.cat([extrinsics, padding], dim=-2)
|
14 |
+
|
15 |
+
|
16 |
+
def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
    """
    Create OpenGL camera extrinsics from camera locations and look-at position.

    camera_position: (M, 3) or (3,)
    look_at: (3)
    up_world: (3)
    return: (M, 4, 4) or (4, 4) -- the [x | y | z | position] matrix is padded
        to 4x4 by pad_camera_extrinsics_4x4 before it is returned
    """
    # by default, looking at the origin and world up is z-axis;
    # the defaults are created on the device of camera_position so that
    # non-CPU inputs do not hit a device mismatch
    device = camera_position.device
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
    if camera_position.ndim == 2:
        # replicate the single look-at / up vectors once per camera
        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)

    # OpenGL camera: z-backward, x-right, y-up
    z_axis = camera_position - look_at
    z_axis = F.normalize(z_axis, dim=-1).float()
    # NOTE: when the viewing direction is parallel to up_world the cross
    # product is the zero vector and the resulting x/y axes are degenerate.
    x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
    x_axis = F.normalize(x_axis, dim=-1).float()
    y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
    y_axis = F.normalize(y_axis, dim=-1).float()

    # columns of the pose matrix: the three camera axes, then the camera position
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    extrinsics = pad_camera_extrinsics_4x4(extrinsics)
    return extrinsics
|
45 |
+
|
46 |
+
|
47 |
+
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """
    Place cameras on a sphere around the origin and build their look-at poses.

    azimuths: azimuth angles in degrees
    elevations: elevation angles in degrees
    radius: distance of every camera from the origin
    return: the poses produced by center_looking_at_camera_pose for the
        computed camera locations
    """
    az = np.deg2rad(azimuths)
    el = np.deg2rad(elevations)

    # spherical -> Cartesian with z as the vertical axis
    planar = radius * np.cos(el)
    xs = planar * np.cos(az)
    ys = planar * np.sin(az)
    zs = radius * np.sin(el)

    locations = torch.from_numpy(np.stack([xs, ys, zs], axis=-1)).float()
    return center_looking_at_camera_pose(locations)
|
60 |
+
|
61 |
+
|
62 |
+
def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0):
    """
    Build camera poses evenly spaced on a horizontal circle around the origin.

    M: number of circular views
    radius: camera dist to center
    elevation: elevation degrees of the camera
    return: (M, 4, 4)
    """
    assert M > 0 and radius > 0

    elev_rad = np.deg2rad(elevation)
    # the circle radius and the height are shared by all views
    ring = radius * np.cos(elev_rad)
    height = radius * np.sin(elev_rad)

    def _position(step):
        # azimuth advances by one M-th of a full turn per view
        angle = 2 * np.pi * step / M
        return [ring * np.cos(angle), ring * np.sin(angle), height]

    positions = np.array([_position(step) for step in range(M)])
    positions = torch.from_numpy(positions).float()
    return center_looking_at_camera_pose(positions)
|
82 |
+
|
83 |
+
|
84 |
+
def FOV_to_intrinsics(fov, device='cpu'):
    """
    Build a 3x3 camera intrinsics matrix from a field of view given in degrees.

    The matrix is normalized by image size (not in pixel units): the focal
    length is 0.5 / tan(fov / 2) and the principal point is fixed at the image
    center (0.5, 0.5).
    """
    half_angle = np.deg2rad(fov) * 0.5
    focal_length = 0.5 / np.tan(half_angle)
    rows = [
        [focal_length, 0, 0.5],
        [0, focal_length, 0.5],
        [0, 0, 1],
    ]
    return torch.tensor(rows, device=device)
|
93 |
+
|
94 |
+
|
95 |
+
def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    """
    Get the input camera parameters.

    Six fixed views are used (azimuths 30..330 in steps of 60 degrees,
    elevations alternating 20 / -10 degrees). Each camera is encoded as the
    first 12 values of its flattened pose followed by four intrinsics values
    (flattened K entries 0, 4, 2, 5).

    return: (batch_size, 6, 16)
    """
    view_azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
    view_elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)

    poses = spherical_camera_pose(view_azimuths, view_elevations, radius)
    flat_poses = poses.float().flatten(-2)

    flat_K = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)

    # keep the top three rows of each pose (first 12 flattened values)
    extrinsics = flat_poses[:, :12]
    # flattened 3x3 indices: 0 -> K[0,0], 4 -> K[1,1], 2 -> K[0,2], 5 -> K[1,2]
    intrinsics = torch.stack(
        [flat_K[:, 0], flat_K[:, 4], flat_K[:, 2], flat_K[:, 5]],
        dim=-1,
    )
    cameras = torch.cat([extrinsics, intrinsics], dim=-1)

    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
|
core/instant_utils/infer_util.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imageio
|
3 |
+
import rembg
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import PIL.Image
|
7 |
+
from PIL import Image
|
8 |
+
from typing import Any
|
9 |
+
|
10 |
+
|
11 |
+
def remove_background(image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    """
    Run rembg background removal on the image unless it already carries
    transparency.

    An RGBA image whose minimum alpha value is below 255 is treated as already
    matted and returned untouched, unless `force` is set. Extra keyword
    arguments are forwarded to `rembg.remove`.
    """
    already_matted = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_matted:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
|
23 |
+
|
24 |
+
|
25 |
+
def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    """
    Crop an RGBA image to its foreground and re-frame it on a transparent square.

    The foreground is the bounding box of all pixels with alpha > 0. It is
    padded to a square and then surrounded by a transparent border so that the
    square's side is `ratio` times the side of the returned image.

    image: image with 4 channels (asserted)
    ratio: foreground side / output side; NOTE(review): values above 1 make the
        border negative, which np.pad rejects -- presumably callers pass (0, 1]
    return: square PIL image containing the centered foreground

    Raises:
        ValueError: if no pixel has alpha > 0.
    """
    image = np.array(image)
    assert image.shape[-1] == 4
    rows, cols = np.where(image[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("resize_foreground: image has no foreground (alpha is zero everywhere)")
    y1, y2 = rows.min(), rows.max()
    x1, x2 = cols.min(), cols.max()
    # crop the foreground; y2/x2 are inclusive indices, so the slice stops at +1
    # (stopping at y2/x2 would drop the last foreground row and column)
    fg = image[y1:y2 + 1, x1:x2 + 1]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = PIL.Image.fromarray(new_image)
    return new_image
|
64 |
+
|
65 |
+
|
66 |
+
def images_to_video(
    images: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    """
    Write a batch of images to a video file with imageio.

    images: (N, C, H, W); NOTE: values are assumed to lie in [0, 1] -- they are
        scaled by 255 and cast to uint8 without clamping
    output_path: destination file; its parent directory is created if needed
    fps: frames per second passed to imageio
    """
    video_dir = os.path.dirname(output_path)
    # dirname is '' for a bare filename, and os.makedirs('') raises
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)

    frames = []
    for image in images:
        # (C, H, W) -> (H, W, C) uint8 frame
        frame = (image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
        frames.append(frame)
    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
|
85 |
+
|
86 |
+
|
87 |
+
def save_video(
    frames: torch.Tensor,
    output_path: str,
    fps: int = 30,
) -> None:
    """
    Stream frames to a video file through an imageio writer.

    frames: (N, C, H, W); NOTE: values are assumed to lie in [0, 1] -- they are
        scaled by 255 and cast to uint8 without clamping
    output_path: destination file
    fps: frames per second passed to imageio
    """
    # convert every frame up front, before the output file is opened
    frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames]
    writer = imageio.get_writer(output_path, fps=fps)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # always release the writer, even if appending a frame fails
        writer.close()
|
core/instant_utils/mesh_util.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import xatlas
|
11 |
+
import trimesh
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import nvdiffrast.torch as dr
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
|
18 |
+
def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath):
    """
    Export a vertex-colored triangle mesh to `fpath` in OBJ format.

    The z coordinate of every vertex is negated before export; the face
    index order is left unchanged.
    """
    flip_z = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
    flipped_points = pointnp_px3 @ flip_z

    mesh = trimesh.Trimesh(
        vertices=flipped_points,
        faces=facenp_fx3,
        vertex_colors=colornp_px3,
    )
    mesh.export(fpath, 'obj')
|
29 |
+
|
30 |
+
|
31 |
+
def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath):
    """
    Export a vertex-colored triangle mesh to `fpath` in GLB format.

    The x and z coordinates of every vertex are negated before export.
    """
    flip_xz = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]])
    flipped_points = pointnp_px3 @ flip_xz

    mesh = trimesh.Trimesh(
        vertices=flipped_points,
        faces=facenp_fx3,
        vertex_colors=colornp_px3,
    )
    mesh.export(fpath, 'glb')
|
41 |
+
|
42 |
+
|
43 |
+
def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname):
    """
    Write a textured mesh as an OBJ file plus its MTL and PNG texture.

    For fname = <dir>/<name>.<ext> three files are produced: `fname` itself
    (OBJ), <dir>/<name>.mtl and <dir>/<name>.png.

    pointnp_px3: vertex positions, one xyz triple per vertex
    tcoords_px2: texture coordinates, one uv pair per entry
    facenp_fx3: zero-based vertex indices per triangle
    facetex_fx3: zero-based texture-coordinate indices per triangle
    texmap_hxwx3: texture image; values are scaled from [0, 1] to [0, 255]
    fname: path of the OBJ file to write
    """
    import os
    fol, na = os.path.split(fname)
    na, _ = os.path.splitext(na)

    # os.path.join keeps the files next to the OBJ even when `fol` is empty
    # (a plain '%s/%s' join would point at the filesystem root in that case)
    matname = os.path.join(fol, '%s.mtl' % na)
    with open(matname, 'w') as fid:
        fid.write('newmtl material_0\n')
        fid.write('Kd 1 1 1\n')
        fid.write('Ka 0 0 0\n')
        fid.write('Ks 0.4 0.4 0.4\n')
        fid.write('Ns 10\n')
        fid.write('illum 2\n')
        fid.write('map_Kd %s.png\n' % na)

    with open(fname, 'w') as fid:
        fid.write('mtllib %s.mtl\n' % na)

        for pp in pointnp_px3:
            fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2]))

        for pp in tcoords_px2:
            fid.write('vt %f %f\n' % (pp[0], pp[1]))

        fid.write('usemtl material_0\n')
        for i, f in enumerate(facenp_fx3):
            # OBJ indices are one-based
            f1 = f + 1
            f2 = facetex_fx3[i] + 1
            fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))

    # save texture map
    lo, hi = 0, 1
    img = np.asarray(texmap_hxwx3, dtype=np.float32)
    img = (img - lo) * (255 / (hi - lo))
    img = img.clip(0, 255)
    # texels whose channel sum is <= 3 are treated as empty and replaced by the
    # 3x3-dilated image, i.e. filled from their neighbourhood
    mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
    mask = (mask <= 3.0).astype(np.float32)
    kernel = np.ones((3, 3), 'uint8')
    dilate_img = cv2.dilate(img, kernel, iterations=1)
    img = img * (1 - mask) + dilate_img * mask
    img = img.clip(0, 255).astype(np.uint8)
    # the image is flipped vertically before it is saved
    Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(os.path.join(fol, f'{na}.png'))
|
90 |
+
|
91 |
+
|
92 |
+
def loadobj(meshfile):
    """
    Load vertex positions and triangle faces from an OBJ file.

    Only lines with exactly four space-separated tokens are used: `v x y z`
    and `f a b c`, where each face token may be `i`, `i/j` or `i/j/k` and only
    the vertex index `i` is kept. Everything else (comments, `vt`, quads, ...)
    is skipped.

    meshfile: path of the OBJ file
    return: (pointnp_px3, facenp_fx3) -- float32 vertex array and int64
        zero-based face index array
    """
    v = []
    f = []
    # the context manager closes the file even if a line fails to parse
    with open(meshfile, 'r') as meshfp:
        for line in meshfp:
            data = [da for da in line.strip().split(' ') if len(da) > 0]
            if len(data) != 4:
                continue
            if data[0] == 'v':
                v.append([float(d) for d in data[1:]])
            elif data[0] == 'f':
                f.append([int(da.split('/')[0]) for da in data[1:]])

    # torch need int64; OBJ indices are one-based, hence the -1
    facenp_fx3 = np.array(f, dtype=np.int64) - 1
    pointnp_px3 = np.array(v, dtype=np.float32)
    return pointnp_px3, facenp_fx3
|
112 |
+
|
113 |
+
|
114 |
+
def loadobjtex(meshfile):
    """
    Load vertices, faces and texture coordinates from an OBJ file.

    Only lines with 3, 4 or 5 space-separated tokens are considered:
      - `v x y z` (exactly three coordinates, asserted)
      - `vt u v [w]` (only u and v are kept)
      - `f a/ta b/tb c/tc [d/td]`; a quad is split into the triangles
        (a, b, c) and (a, c, d). Every face corner must carry a texture index.

    meshfile: path of the OBJ file
    return: (pointnp_px3, facenp_fx3, uvs, ftnp_fx3) -- float32 vertices,
        int64 zero-based vertex indices per triangle, float32 uv coordinates,
        int64 zero-based uv indices per triangle
    """
    v = []
    vt = []
    f = []
    ft = []
    # the context manager closes the file even if a line fails to parse
    with open(meshfile, 'r') as meshfp:
        for line in meshfp:
            data = [da for da in line.strip().split(' ') if len(da) > 0]
            if len(data) not in (3, 4, 5):
                continue
            tag = data[0]
            if tag == 'v':
                assert len(data) == 4
                v.append([float(d) for d in data[1:]])
            elif tag == 'vt':
                if len(data) in (3, 4):
                    vt.append([float(d) for d in data[1:3]])
            elif tag == 'f':
                corners = [da.split('/') for da in data[1:]]
                if len(corners) == 3:
                    triangles = [corners]
                elif len(corners) == 4:
                    # split the quad into two triangles sharing its first corner
                    triangles = [corners[0:3], [corners[0], corners[2], corners[3]]]
                else:
                    triangles = []
                for tri in triangles:
                    f.append([int(c[0]) for c in tri])
                    ft.append([int(c[1]) for c in tri])

    # torch need int64; OBJ indices are one-based, hence the -1
    facenp_fx3 = np.array(f, dtype=np.int64) - 1
    ftnp_fx3 = np.array(ft, dtype=np.int64) - 1
    pointnp_px3 = np.array(v, dtype=np.float32)
    uvs = np.array(vt, dtype=np.float32)
    return pointnp_px3, facenp_fx3, uvs, ftnp_fx3
|
154 |
+
|
155 |
+
|
156 |
+
# ==============================================================================================
|
157 |
+
def interpolate(attr, rast, attr_idx, rast_db=None):
    """
    Thin wrapper around nvdiffrast's `dr.interpolate`.

    The attribute tensor is made contiguous first. When `rast_db` is supplied,
    `diff_attrs='all'` is requested; otherwise `diff_attrs` is None.
    Returns whatever `dr.interpolate` returns.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(
        attr.contiguous(),
        rast,
        attr_idx,
        rast_db=rast_db,
        diff_attrs=diff_attrs,
    )
|
159 |
+
|
160 |
+
|
161 |
+
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
    """
    Unwrap a mesh with xatlas and rasterize it in UV space.

    ctx: nvdiffrast rasterization context passed to `dr.rasterize`
    mesh_v: vertex position tensor
    mesh_pos_idx: triangle vertex-index tensor
    resolution: side length of the square UV raster
    return: (uvs, mesh_tex_idx, gb_pos, mask) -- xatlas uv coordinates and
        per-triangle uv indices as tensors on mesh_v's device, the vertex
        positions interpolated over the UV raster, and a boolean mask of
        texels where the rasterizer's last output channel is > 0
    """
    verts_np = mesh_v.detach().cpu().numpy()
    faces_np = mesh_pos_idx.detach().cpu().numpy()
    _vmapping, indices, uvs = xatlas.parametrize(verts_np, faces_np)

    # Convert to tensors (indices are reinterpreted as int64 via a uint64 view)
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    target_device = mesh_v.device
    uvs = torch.tensor(uvs, dtype=torch.float32, device=target_device)
    mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=target_device)

    # texture-space vertices: rescale uv with u * 2 - 1 and add a batch dimension
    uv_clip = uvs[None, ...] * 2.0 - 1.0

    # pad to four component coordinate (z = 0, w = 1)
    zeros = torch.zeros_like(uv_clip[..., 0:1])
    ones = torch.ones_like(uv_clip[..., 0:1])
    uv_clip4 = torch.cat((uv_clip, zeros, ones), dim=-1)

    # rasterize
    rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))

    # Interpolate world space position
    gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
    mask = rast[..., 3:4] > 0
    return uvs, mesh_tex_idx, gb_pos, mask
|
core/instant_utils/train_util.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
|
4 |
+
def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    :param model: object exposing ``parameters()`` whose items have ``numel()``
    :param verbose: when true, also print the count in millions
    :return: the total element count
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    if not verbose:
        return total_params

    print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
|
9 |
+
|
10 |
+
|
11 |
+
def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is resolved with ``get_obj_from_str`` and called with
    ``config["params"]`` (or no arguments when that key is absent).

    :param config: mapping with a ``target`` key, or one of the sentinel
        strings ``'__is_first_stage__'`` / ``"__is_unconditional__"``
    :return: the instantiated object, or ``None`` for a sentinel string
    :raises KeyError: when ``target`` is missing and ``config`` is not a sentinel
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)

    # Without a target only the two sentinel values are accepted.
    if config == '__is_first_stage__':
        return None
    if config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
|
19 |
+
|
20 |
+
|
21 |
+
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    :param string: dotted path; everything before the last dot is imported as
        a module and the remainder is looked up as an attribute on it
    :param reload: when true, the module is reloaded before the lookup
    :return: the attribute found on the imported module
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    resolved_module = importlib.import_module(module_path, package=None)
    return getattr(resolved_module, attr_name)
|
core/lrm_reconstructor.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from typing import Tuple, Literal
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import itertools
|
10 |
+
|
11 |
+
|
12 |
+
# LRM
|
13 |
+
from .embedder import CameraEmbedder
|
14 |
+
from .transformer import TransformerDecoder
|
15 |
+
# from accelerate.logging import get_logger
|
16 |
+
|
17 |
+
# logger = get_logger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
class LRM_VSD_Mesh_Net(nn.Module):
    """
    predict VSD using transformer

    Multi-view images are encoded by a DINO / DINOv2 wrapper modulated by a
    camera embedding. A learned set of query tokens then cross-attends to the
    image features, and the resulting tokens are reshaped and upsampled into
    three plane feature maps and three line feature vectors. The outputs are
    split along the channel axis into an appearance part and a density part.
    """
    def __init__(self, camera_embed_dim: int,
                 transformer_dim: int, transformer_layers: int, transformer_heads: int,
                 triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
                 encoder_freeze: bool = True, encoder_type: str = 'dino',
                 encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768,
                 app_dim: int = 27, density_dim: int = 8, app_n_comp: int = 24,
                 density_n_comp: int = 8):
        """
        :param camera_embed_dim: output width of the camera embedder
        :param transformer_dim: token width inside the transformer
        :param transformer_layers: number of transformer layers
        :param transformer_heads: number of attention heads
        :param triplane_low_res: side length of the token grid per plane
        :param triplane_high_res: stored only; not used inside this class
        :param triplane_dim: stored only; not used inside this class
        :param encoder_freeze: forwarded to the image encoder wrapper
        :param encoder_type: ``'dino'`` or ``'dinov2'`` (case-insensitive)
        :param encoder_model_name: forwarded to the image encoder wrapper
        :param encoder_feat_dim: expected last dimension of encoder features
        :param app_dim: stored only; not used inside this class
        :param density_dim: stored only; not used inside this class
        :param app_n_comp: number of appearance components
        :param density_n_comp: number of density components
        """
        super().__init__()

        # attributes
        self.encoder_feat_dim = encoder_feat_dim
        self.camera_embed_dim = camera_embed_dim
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim
        self.transformer_dim = transformer_dim

        # modules
        self.encoder = self._encoder_fn(encoder_type)(
            model_name=encoder_model_name,
            modulation_dim=self.camera_embed_dim,  # mod camera vector
            freeze=encoder_freeze,
        )
        self.camera_embedder = CameraEmbedder(
            raw_dim=12+4, embed_dim=camera_embed_dim,
        )

        self.n_comp = app_n_comp + density_n_comp
        self.app_dim = app_dim
        self.density_dim = density_dim
        self.app_n_comp = app_n_comp
        self.density_n_comp = density_n_comp

        # One query token per low-res plane cell (3 planes) plus one per
        # low-res line sample (3 lines).
        self.pos_embed = nn.Parameter(torch.randn(1, 3*(triplane_low_res**2)+3*triplane_low_res, transformer_dim) * (1. / transformer_dim) ** 0.5)
        self.transformer = TransformerDecoder(
            block_type='cond',
            num_layers=transformer_layers, num_heads=transformer_heads,
            inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=None,
        )
        # for plane: kernel 2 / stride 2 doubles the spatial resolution
        self.upsampler = nn.ConvTranspose2d(transformer_dim, self.n_comp, kernel_size=2, stride=2, padding=0)
        # for line: project channels, then double the line length
        self.dim_map = nn.Linear(transformer_dim, self.n_comp)
        self.up_line = nn.Linear(triplane_low_res, triplane_low_res*2)

    @staticmethod
    def _encoder_fn(encoder_type: str):
        """Return the encoder wrapper class selected by ``encoder_type``.

        :raises ValueError: if ``encoder_type`` is neither ``'dino'`` nor ``'dinov2'``
        """
        encoder_type = encoder_type.lower()
        if encoder_type == 'dino':
            from .encoders.dino_wrapper import DinoWrapper
            return DinoWrapper
        if encoder_type == 'dinov2':
            from .encoders.dinov2_wrapper import Dinov2Wrapper
            return Dinov2Wrapper
        # Raise explicitly: an `assert` disappears under `python -O`, after
        # which this function would silently return None.
        raise ValueError("Unsupported encoder type")

    def forward_transformer(self, image_feats, camera_embeddings=None):
        """Run the learned query tokens through the transformer.

        :param image_feats: conditioning features; the first axis is the batch
        :param camera_embeddings: optional modulation input for the transformer
        :return: transformer output for the repeated query tokens
        """
        N = image_feats.shape[0]
        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        x = self.transformer(
            x,
            cond=image_feats,
            mod=camera_embeddings,
        )
        return x

    def reshape_upsample(self, tokens):
        """Turn transformer tokens into upsampled plane and line features.

        With ``R = triplane_low_res``, ``D = transformer_dim`` and
        ``P = n_comp``, the first ``3*R*R`` tokens become plane features of
        shape ``[N, 3*P, 2R, 2R]`` and the next ``3*R`` tokens become line
        features of shape ``[N, 3*P, 2R, 1]``.

        :param tokens: ``[N, 3*R*R + 3*R, D]``
        :return: ``(app_planes, app_lines, None, None, density_planes,
            density_lines)``; the appearance tensors take the first
            ``3*app_n_comp`` channels and the density tensors the remainder
        """
        N = tokens.shape[0]
        H = W = self.triplane_low_res
        P = self.n_comp

        # planes
        plane_tokens = tokens[:, :3*H*W, :].view(N, H, W, 3, self.transformer_dim)
        plane_tokens = torch.einsum('nhwip->inphw', plane_tokens)  # [3, N, D, H, W]
        plane_tokens = plane_tokens.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        plane_tokens = self.upsampler(plane_tokens)  # [3*N, P, H', W']
        plane_tokens = plane_tokens.view(3, N, *plane_tokens.shape[-3:])  # [3, N, P, H', W']
        plane_tokens = torch.einsum('inphw->niphw', plane_tokens)  # [N, 3, P, H', W']
        plane_tokens = plane_tokens.reshape(N, 3*P, *plane_tokens.shape[-2:])  # [N, 3*P, H', W']
        plane_tokens = plane_tokens.contiguous()

        # lines
        line_tokens = tokens[:, 3*H*W:3*H*W+3*H, :].view(N, H, 3, self.transformer_dim)
        line_tokens = self.dim_map(line_tokens)  # [N, H, 3, P]
        line_tokens = torch.einsum('nhip->npih', line_tokens)  # [N, P, 3, H]
        line_tokens = self.up_line(line_tokens)  # [N, P, 3, H']
        line_tokens = torch.einsum('npih->niph', line_tokens)  # [N, 3, P, H']
        line_tokens = line_tokens.reshape(N, 3*P, line_tokens.shape[-1], 1)
        line_tokens = line_tokens.contiguous()

        mat_tokens = None
        d_mat_tokens = None

        # NOTE(review): the channel axis is laid out plane-major (3 x P), so
        # this split is by flat channel index, not per plane. It is kept as-is
        # because trained weights presumably rely on this layout; confirm.
        n_app = self.app_n_comp * 3
        return (
            plane_tokens[:, :n_app, :, :],
            line_tokens[:, :n_app, :, :],
            mat_tokens,
            d_mat_tokens,
            plane_tokens[:, n_app:, :, :],
            line_tokens[:, n_app:, :, :],
        )

    def forward_planes(self, image, camera):
        """Encode the views and predict plane / line factors.

        :param image: ``[N, V, C_img, H_img, W_img]``
        :param camera: ``[N, V, D_cam_raw]``
        :return: the six-tuple produced by ``reshape_upsample``
        """
        N, V, C, H, W = image.shape
        # Fold the view axis into the batch for the per-image encoder. The
        # channel count comes from the input instead of being fixed at 3.
        image = image.reshape(N*V, C, H, W)
        camera = camera.reshape(N*V, -1)

        # embed camera
        camera_embeddings = self.camera_embedder(camera)
        assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
            f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"

        # encode image
        image_feats = self.encoder(image, camera_embeddings)
        assert image_feats.shape[-1] == self.encoder_feat_dim, \
            f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"

        # Concatenate the tokens of all views of one sample.
        image_feats = image_feats.reshape(N, V*image_feats.shape[-2], image_feats.shape[-1])

        # transformer generating planes (camera modulation is not passed here)
        tokens = self.forward_transformer(image_feats)

        app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines = self.reshape_upsample(tokens)

        return app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines

    def forward(self, image, source_camera):
        """Predict plane / line factors for a batch of multi-view inputs.

        :param image: ``[N, V, C_img, H_img, W_img]``
        :param source_camera: ``[N, V, D_cam_raw]``
        :return: the six-tuple produced by ``forward_planes``
        """
        assert image.shape[0] == source_camera.shape[0], "Batch size mismatch for image and source_camera"
        planes = self.forward_planes(image, source_camera)
        return planes
|
core/models.py
ADDED
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import mcubes
|
6 |
+
|
7 |
+
import kiui
|
8 |
+
from kiui.lpips import LPIPS
|
9 |
+
|
10 |
+
from core.lrm_reconstructor import LRM_VSD_Mesh_Net
|
11 |
+
from core.options import Options
|
12 |
+
from core.tensoRF import TensorVMSplit_Mesh,TensorVMSplit_NeRF
|
13 |
+
from torchvision.transforms import v2
|
14 |
+
from core.geometry.camera.perspective_camera import PerspectiveCamera
|
15 |
+
from core.geometry.render.neural_render import NeuralRender
|
16 |
+
from core.geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry
|
17 |
+
import nvdiffrast.torch as dr
|
18 |
+
from core.instant_utils.mesh_util import xatlas_uvmap
|
19 |
+
|
20 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
+
|
22 |
+
#tensorSDF + transformer + volume_rendering
|
23 |
+
class LTRFM_NeRF(nn.Module):
    """Transformer-predicted plane/line factors rendered by ``TensorVMSplit_NeRF``.

    ``vsd_net`` predicts appearance and density plane/line factors from the
    input views, and ``tensorRF`` renders them along the supplied rays.
    ``forward`` additionally computes the training loss and a PSNR metric.
    """

    def __init__(
        self,
        opt: Options,
    ):
        """
        :param opt: options object; the attributes read here are the
            ``vsd_net`` hyper-parameters, ``splat_size``, ``znear`` / ``zfar``,
            the renderer settings and ``lambda_lpips``
        """
        super().__init__()

        self.opt = opt

        # predict svd using transformer
        self.vsd_net = LRM_VSD_Mesh_Net(
            camera_embed_dim=opt.camera_embed_dim,
            transformer_dim=opt.transformer_dim,
            transformer_layers=opt.transformer_layers,
            transformer_heads=opt.transformer_heads,
            triplane_low_res=opt.triplane_low_res,
            triplane_high_res=opt.triplane_high_res,
            triplane_dim=opt.triplane_dim,
            encoder_freeze=opt.encoder_freeze,
            encoder_type=opt.encoder_type,
            encoder_model_name=opt.encoder_model_name,
            encoder_feat_dim=opt.encoder_feat_dim,
            app_dim=opt.app_dim,
            density_dim=opt.density_dim,
            app_n_comp=opt.app_n_comp,
            density_n_comp=opt.density_n_comp,
        )

        aabb = torch.tensor([[-1, -1, -1], [1, 1, 1]]).to(device)
        grid_size = torch.tensor([opt.splat_size, opt.splat_size, opt.splat_size]).to(device)
        near_far = torch.tensor([opt.znear, opt.zfar]).to(device)
        # tensorf Renderer
        self.tensorRF = TensorVMSplit_NeRF(
            aabb, grid_size,
            density_n_comp=opt.density_n_comp, appearance_n_comp=opt.app_n_comp,
            app_dim=opt.app_dim, density_dim=opt.density_dim, near_far=near_far,
            shadingMode=opt.shadingMode, pos_pe=opt.pos_pe, view_pe=opt.view_pe,
            fea_pe=opt.fea_pe)

        # LPIPS loss (frozen; only created when it is actually weighted in)
        if self.opt.lambda_lpips > 0:
            self.lpips_loss = LPIPS(net='vgg')
            self.lpips_loss.requires_grad_(False)

    def state_dict(self, *args, **kwargs):
        """Return the module state without any ``lpips_loss`` entries.

        Positional and keyword arguments are both forwarded, so every calling
        convention of ``nn.Module.state_dict`` is supported.
        """
        state_dict = super().state_dict(*args, **kwargs)
        # Collect first: a dict must not be mutated while iterating over it.
        for k in [k for k in state_dict if 'lpips_loss' in k]:
            del state_dict[k]
        return state_dict

    def set_beta(self, t):
        """Forward ``t`` to ``self.tensorRF.lap_density.set_beta``."""
        self.tensorRF.lap_density.set_beta(t)

    # predict svd_volume
    def forward_svd_volume(self, images, data):
        """Predict the plane / line factors for a batch.

        :param images: 5-D tensor; only its first two sizes (``B``, ``V``) are used
        :param data: mapping providing ``'source_camera'`` and ``'input_vit'``
        :return: dict with ``app_planes``, ``app_lines``, ``basis_mat``,
            ``d_basis_mat``, ``density_planes`` and ``density_lines``; planes
            are viewed as ``[B, 3, n_comp, splat_size, splat_size]`` and lines
            as ``[B, 3, n_comp, splat_size, 1]``
        """
        B, V, C, H, W = images.shape

        source_camera = data['source_camera']
        images_vit = data['input_vit']  # for transformer
        source_camera = source_camera.reshape(B, V, -1)  # [B, V, D_cam]
        app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines = self.vsd_net(images_vit, source_camera)

        app_planes = app_planes.view(B, 3, self.opt.app_n_comp, self.opt.splat_size, self.opt.splat_size)
        app_lines = app_lines.view(B, 3, self.opt.app_n_comp, self.opt.splat_size, 1)
        density_planes = density_planes.view(B, 3, self.opt.density_n_comp, self.opt.splat_size, self.opt.splat_size)
        density_lines = density_lines.view(B, 3, self.opt.density_n_comp, self.opt.splat_size, 1)

        results = {
            'app_planes': app_planes,
            'app_lines': app_lines,
            'basis_mat': basis_mat,
            'd_basis_mat': d_basis_mat,
            'density_planes': density_planes,
            'density_lines': density_lines
        }

        return results

    def extract_mesh(self,
                     planes: dict,
                     mesh_resolution: int = 256,
                     mesh_threshold: float = 0.009,
                     use_texture_map: bool = False,
                     texture_resolution: int = 1024,):
        """Extract a mesh from the predicted field with marching cubes.

        :param planes: dict as returned by ``forward_svd_volume``
        :param mesh_resolution: number of samples per axis of the query grid
        :param mesh_threshold: iso-value handed to marching cubes
        :param use_texture_map: bake UV textures instead of vertex colours
        :param texture_resolution: side length of the baked texture
        :return: without texture map ``(vertices, faces, [rgb, albedo,
            shading])`` as NumPy arrays with uint8 colours; with texture map
            ``(vertices, faces, uvs, mesh_tex_idx, [texture_map,
            texture_map_albedo])`` as tensors
        """
        device = planes['app_planes'].device

        grid_size = mesh_resolution
        points = torch.linspace(-1, 1, steps=grid_size).half()

        # 'ij' is the layout meshgrid used when no indexing was given; naming
        # it avoids the deprecation warning without changing the result.
        x, y, z = torch.meshgrid(points, points, points, indexing='ij')

        xyz_samples = torch.stack((x, y, z), dim=0).unsqueeze(0).to(device)
        xyz_samples = xyz_samples.permute(0, 2, 3, 4, 1)
        xyz_samples = xyz_samples.view(1, -1, 1, 3)

        grid_out = self.tensorRF.predict_sdf(planes, xyz_samples)
        grid_out['sigma'] = grid_out['sigma'].view(grid_size, grid_size, grid_size).float()

        vertices, faces = mcubes.marching_cubes(
            grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
            mesh_threshold,
        )
        # Marching cubes returns grid coordinates; map them back to [-1, 1].
        vertices = vertices / (mesh_resolution - 1) * 2 - 1

        if not use_texture_map:
            # query vertex colors (one field evaluation serves all three outputs)
            vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)
            colors = self.tensorRF.predict_color(planes, vertices_tensor)

            def to_uint8(channel):
                return (channel.squeeze(0).cpu().numpy() * 255).astype(np.uint8)

            rgb_colors = to_uint8(colors['rgb'])
            albedob_colors = to_uint8(colors['albedo'])
            shading_colors = to_uint8(colors['shading'])

            return vertices, faces, [rgb_colors, albedob_colors, shading_colors]

        # use x-atlas to get uv mapping for the mesh
        vertices = torch.tensor(vertices, dtype=torch.float32, device=device)
        faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device)

        ctx = dr.RasterizeCudaContext(device=device)
        uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
            ctx, vertices, faces, resolution=texture_resolution)
        tex_hard_mask = tex_hard_mask.float().cpu()

        # query the texture field once for every texel
        query_vertices = gb_pos.view(1, texture_resolution*texture_resolution, 3)
        colors = self.tensorRF.predict_color(planes, query_vertices)

        def bake(channel):
            texels = channel.squeeze(0).cpu()
            texels = texels.reshape(1, texture_resolution, texture_resolution, 3)
            background_feature = torch.zeros_like(texels)
            # torch.lerp needs the weight in the same dtype as its inputs.
            weight = tex_hard_mask.to(texels.dtype)
            img_feat = torch.lerp(background_feature, texels, weight)
            return img_feat.permute(0, 3, 1, 2).squeeze(0)

        texture_map = bake(colors['rgb'])
        texture_map_albedo = bake(colors['albedo'])

        return vertices, faces, uvs, mesh_tex_idx, [texture_map, texture_map_albedo]

    def forward(self, data, step_ratio=1):
        """Render the supervision views and compute the training loss.

        :param data: dataloader output; reads ``'input'``, ``'source_camera'``,
            ``'input_vit'``, ``'all_rays_o'``, ``'all_rays_d'``,
            ``'images_output'``, ``'albedos_output'`` and ``'masks_output'``
        :param step_ratio: unused
        :return: the renderer's result dict extended with ``images_pred``,
            ``alphas_pred``, ``pred_albedos``, ``pred_shading``, ``loss``,
            ``psnr`` and, when LPIPS is enabled, ``loss_lpips``
        """
        loss = 0

        images = data['input']  # input features

        # The predicted volume is not stored in the results: the renderer
        # output below replaces the result dict entirely.
        svd_volume = self.forward_svd_volume(images, data)

        # always use white bg
        bg_color = torch.ones(3, dtype=torch.float32).to(device)

        # use the other views for rendering and supervision
        results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'], is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample)
        pred_shading = results['image']  # [B, V, C, output_size, output_size]
        pred_alphas = results['alpha']  # [B, V, 1, output_size, output_size]
        pred_albedos = results['albedo']  # [B, V, C, output_size, output_size]

        # The final image is the product of shading and albedo.
        pred_images = pred_shading*pred_albedos

        results['images_pred'] = pred_images
        results['alphas_pred'] = pred_alphas
        results['pred_albedos'] = pred_albedos
        results['pred_shading'] = pred_shading

        gt_images = data['images_output']  # [B, V, 3, output_size, output_size], ground-truth novel views
        gt_albedos = data['albedos_output']  # [B, V, 3, output_size, output_size], ground-truth albedos
        gt_masks = data['masks_output']  # [B, V, 1, output_size, output_size], ground-truth masks

        # Composite the ground truth onto the same white background.
        gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
        gt_albedos = gt_albedos * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)

        loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) + F.mse_loss(pred_albedos, gt_albedos)
        loss = loss + loss_mse

        if self.opt.lambda_lpips > 0:
            # LPIPS is evaluated on [-1, 1] images resized to 256 x 256.
            loss_lpips = self.lpips_loss(
                F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
                F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
            ).mean()
            results['loss_lpips'] = loss_lpips
            loss = loss + self.opt.lambda_lpips * loss_lpips

        results['loss'] = loss

        # metric
        with torch.no_grad():
            psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
            results['psnr'] = psnr

        return results

    def render_frame(self, data):
        """Render the requested rays without computing any loss.

        :param data: dataloader output; reads ``'input_vit'``,
            ``'source_camera'``, ``'all_rays_o'`` and ``'all_rays_d'``
        :return: the renderer's result dict extended with ``images_pred``,
            ``alphas_pred``, ``pred_albedos`` and ``pred_shading``
        """
        images = data['input_vit']

        svd_volume = self.forward_svd_volume(images, data)

        # always use white bg
        bg_color = torch.ones(3, dtype=torch.float32).to(device)

        results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'], is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample)
        pred_shading = results['image']  # [B, V, C, output_size, output_size]
        pred_alphas = results['alpha']  # [B, V, 1, output_size, output_size]
        pred_albedos = results['albedo']  # [B, V, C, output_size, output_size]

        pred_images = pred_shading*pred_albedos

        results['images_pred'] = pred_images
        results['alphas_pred'] = pred_alphas
        results['pred_albedos'] = pred_albedos
        results['pred_shading'] = pred_shading

        return results
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
|
285 |
+
#tensorSDF + transformer + SDF + Mesh
|
286 |
+
class LTRFM_Mesh(nn.Module):
|
287 |
+
def __init__(
    self,
    opt: Options,
):
    """Build the FlexiCubes geometry, the factor predictor and the renderer.

    When ``opt.ckpt_nerf`` is set, weights are loaded from that checkpoint
    into ``vsd_net`` and ``tensorRF`` non-strictly.

    :param opt: options object supplying the hyper-parameters read below
    """
    super().__init__()

    self.opt = opt

    # attributes
    self.grid_res = 128  # grid_res
    self.grid_scale = 2.0  # grid_scale
    self.deformation_multiplier = 4.0

    self.init_flexicubes_geometry(device, self.opt)

    # predict svd using transformer
    vsd_kwargs = dict(
        camera_embed_dim=opt.camera_embed_dim,
        transformer_dim=opt.transformer_dim,
        transformer_layers=opt.transformer_layers,
        transformer_heads=opt.transformer_heads,
        triplane_low_res=opt.triplane_low_res,
        triplane_high_res=opt.triplane_high_res,
        triplane_dim=opt.triplane_dim,
        encoder_freeze=opt.encoder_freeze,
        encoder_type=opt.encoder_type,
        encoder_model_name=opt.encoder_model_name,
        encoder_feat_dim=opt.encoder_feat_dim,
        app_dim=opt.app_dim,
        density_dim=opt.density_dim,
        app_n_comp=opt.app_n_comp,
        density_n_comp=opt.density_n_comp,
    )
    self.vsd_net = LRM_VSD_Mesh_Net(**vsd_kwargs)

    aabb = torch.tensor([[-1, -1, -1], [1, 1, 1]]).to(device)
    grid_size = torch.tensor([opt.splat_size] * 3).to(device)
    near_far = torch.tensor([opt.znear, opt.zfar]).to(device)
    # tensorf Renderer
    self.tensorRF = TensorVMSplit_Mesh(
        aabb,
        grid_size,
        density_n_comp=opt.density_n_comp,
        appearance_n_comp=opt.app_n_comp,
        app_dim=opt.app_dim,
        density_dim=opt.density_dim,
        near_far=near_far,
        shadingMode=opt.shadingMode,
        pos_pe=opt.pos_pe,
        view_pe=opt.view_pe,
        fea_pe=opt.fea_pe,
    )

    # LPIPS loss
    if self.opt.lambda_lpips > 0:
        self.lpips_loss = LPIPS(net='vgg')
        self.lpips_loss.requires_grad_(False)

    # load ckpt
    if opt.ckpt_nerf is None:
        return

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(opt.ckpt_nerf, map_location='cpu')['model']

    # Drop the 'module.' wrapper text and skip the old render module weights.
    renamed = ((name.replace('module.', ''), value) for name, value in checkpoint.items())
    sd_fc = {name: value for name, value in renamed if not name.startswith('vsd.renderModule.')}
    # Strip both sub-module prefixes so one dict can be offered to each module.
    sd_fc = {name.replace('vsd_net.', ''): value for name, value in sd_fc.items()}
    sd_fc = {name.replace('tensorRF.', ''): value for name, value in sd_fc.items()}

    # missing `net_deformation` and `net_weight` parameters
    self.vsd_net.load_state_dict(sd_fc, strict=False)
    self.tensorRF.load_state_dict(sd_fc, strict=False)
    print(f'Loaded weights from {opt.ckpt_nerf}')
|
352 |
+
|
353 |
+
|
354 |
+
def state_dict(self, *args, **kwargs):
    """Return the module state without any ``lpips_loss`` entries.

    Positional and keyword arguments are both forwarded, so every calling
    convention of ``nn.Module.state_dict`` is supported (the original
    accepted keyword arguments only).

    :return: the state dict with every key containing ``'lpips_loss'`` removed
    """
    state_dict = super().state_dict(*args, **kwargs)
    # Collect first: a dict must not be mutated while iterating over it.
    for k in [k for k in state_dict if 'lpips_loss' in k]:
        del state_dict[k]
    return state_dict
|
361 |
+
|
362 |
+
|
363 |
+
# predict svd_volume
|
364 |
+
def forward_svd_volume(self, images, data):
    """Predict the plane / line factors for a batch.

    :param images: 5-D tensor; only its first two sizes (batch, views) are used
    :param data: mapping providing ``'source_camera'`` and ``'input_vit'``
    :return: dict with ``app_planes``, ``app_lines``, ``basis_mat``,
        ``d_basis_mat``, ``density_planes`` and ``density_lines``
    """
    n_batch, n_views, _, _, _ = images.shape
    opt = self.opt

    cameras = data['source_camera']
    vit_images = data['input_vit']  # for transformer
    cameras = cameras.reshape(n_batch, n_views, -1)
    predicted = self.vsd_net(vit_images, cameras)
    app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines = predicted

    # Separate the three planes / lines from the component axis.
    size = opt.splat_size
    plane_tail = (size, size)
    line_tail = (size, 1)

    return {
        'app_planes': app_planes.view(n_batch, 3, opt.app_n_comp, *plane_tail),
        'app_lines': app_lines.view(n_batch, 3, opt.app_n_comp, *line_tail),
        'basis_mat': basis_mat,
        'd_basis_mat': d_basis_mat,
        'density_planes': density_planes.view(n_batch, 3, opt.density_n_comp, *plane_tail),
        'density_lines': density_lines.view(n_batch, 3, opt.density_n_comp, *line_tail),
    }
|
390 |
+
|
391 |
+
|
392 |
+
def init_flexicubes_geometry(self, device, opt):
    """Create ``self.geometry``, the FlexiCubes geometry used for meshing.

    :param device: device handed to the camera, renderer and geometry
    :param opt: options object forwarded to ``PerspectiveCamera``
    """
    # The renderer wraps a perspective camera built from the options.
    neural_renderer = NeuralRender(
        device,
        camera_model=PerspectiveCamera(opt, device=device),
    )
    geometry_kwargs = {
        'grid_res': self.grid_res,
        'scale': self.grid_scale,
        'renderer': neural_renderer,
        'render_type': 'neural_render',
        'device': device,
    }
    self.geometry = FlexiCubesGeometry(**geometry_kwargs)
|
402 |
+
|
403 |
+
|
404 |
+
# query vsd for sdf weight and ...
|
405 |
+
def get_sdf_deformation_prediction(self, planes):
    '''
    Predict SDF, vertex deformation and weights at the geometry's grid vertices.

    A shape whose interior SDF values are all positive or all negative has no
    surface. Its SDF is overwritten with a fallback (positive at the centre
    indices, negative at the boundary indices), it receives a regularisation
    loss, and its SDF and deformation are detached.

    :param planes: dict of predicted plane / line features
    :return: ``(sdf, deformation, sdf_reg_loss, weight)``
    '''
    n_shapes = planes['app_lines'].shape[0]
    init_position = self.geometry.verts.unsqueeze(0).expand(n_shapes, -1, -1)

    sdf, deformation, weight = self.tensorRF.get_geometry_prediction(
        planes, init_position, self.geometry.indices)

    # Bound the deformation to a fraction of one grid cell.
    deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation)
    sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)

    # Look only at interior grid vertices when checking for a sign change.
    side = self.grid_res + 1
    sdf_grid = sdf.reshape((sdf.shape[0], side, side, side))
    interior = sdf_grid[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
    n_positive = torch.sum((interior > 0).int(), dim=-1)
    n_negative = torch.sum((interior < 0).int(), dim=-1)
    zero_surface = torch.bitwise_or(n_positive == 0, n_negative == 0)

    if torch.sum(zero_surface).item() > 0:
        fallback = torch.zeros_like(sdf[0:1])
        largest = sdf.max()
        smallest = sdf.min()
        fallback[:, self.geometry.center_indices] += (1.0 - smallest)  # greater than zero
        fallback[:, self.geometry.boundary_indices] += (-1 - largest)  # smaller than zero

        new_sdf = torch.zeros_like(sdf)
        for idx, degenerate in enumerate(zero_surface):
            if degenerate:
                new_sdf[idx:idx + 1] += fallback
        update_mask = (new_sdf == 0).float()

        # Regularization here is used to push the sdf to be a different sign
        # (make it not fully positive or fully negative)
        sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
        sdf_reg_loss = sdf_reg_loss * zero_surface.float()
        sdf = sdf * update_mask + new_sdf * (1 - update_mask)

    # Degenerate shapes contribute no gradient through sdf / deformation.
    kept_sdf = []
    kept_def = []
    for idx, degenerate in enumerate(zero_surface):
        sdf_i = sdf[idx: idx + 1]
        def_i = deformation[idx: idx + 1]
        if degenerate:
            sdf_i = sdf_i.detach()
            def_i = def_i.detach()
        kept_sdf.append(sdf_i)
        kept_def.append(def_i)
    sdf = torch.cat(kept_sdf, dim=0)
    deformation = torch.cat(kept_def, dim=0)
    return sdf, deformation, sdf_reg_loss, weight
|
452 |
+
|
453 |
+
def get_geometry_prediction(self, planes=None):
    '''
    Generate one mesh per batch element from the predicted features.

    :param planes: dict of predicted plane / line features
    :return: ``(v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss,
        flexicubes_surface_reg, flexicubes_weight_reg))``
    '''
    sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes)

    base_verts = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1)
    v_deformed = base_verts + deformation
    tets = self.geometry.indices
    n_batch = planes['app_planes'].shape[0]

    # Each batch element is meshed independently.
    meshes = [
        self.geometry.get_mesh(
            v_deformed[i_batch],
            sdf[i_batch].squeeze(dim=-1),
            with_uv=False,
            indices=tets,
            weight_n=weight[i_batch].squeeze(dim=-1),
            is_training=self.training,
        )
        for i_batch in range(n_batch)
    ]
    v_list = [mesh[0] for mesh in meshes]
    f_list = [mesh[1] for mesh in meshes]
    surface_regs = [mesh[2] for mesh in meshes]

    flexicubes_surface_reg = torch.cat(surface_regs).mean()
    flexicubes_weight_reg = (weight ** 2).mean()

    reg_terms = (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
    return v_list, f_list, sdf, deformation, v_deformed, reg_terms
|
487 |
+
|
488 |
+
def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
    '''
    Predict texture features at the given 3D positions from the triplanes.

    :param planes: the triplane feature maps; ``planes['app_planes'].shape[0]`` is the batch size
    :param tex_pos: list of position tensors (last dim 3) to query the texture field with;
        they are concatenated along dim 0
    :param hard_mask: optional 2D silhouette of the rendered image. When given, only
        the pixels with mask > 0.5 are queried (to save memory) and every other
        pixel of the result is zero.
    :return: with ``hard_mask``: features reshaped to
        ``(B, hard_mask.shape[1], hard_mask.shape[2], C)``; without it: the raw
        output of the texture field for all flattened positions.
    '''
    B = planes['app_planes'].shape[0]
    tex_pos = torch.cat(tex_pos, dim=0)
    if hard_mask is not None:
        tex_pos = tex_pos * hard_mask.float()
    batch_size = tex_pos.shape[0]
    tex_pos = tex_pos.reshape(batch_size, -1, 3)

    if hard_mask is not None:
        # Keep only the masked positions of each sample, zero-padded to a common
        # length so that they can be stacked back into one batch.
        n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
        # xrg: the hard mask may filter out every point; keep at least one
        # (zero) query point per sample so the batch is never empty.
        max_point = max(int(n_point_list.max()), 1)
        expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
        sample_tex_pose_list = []
        for i in range(batch_size):
            tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
            n_missing = max_point - tex_pos_one_shape.shape[1]
            if n_missing > 0:
                padding = torch.zeros(
                    1, n_missing, 3,
                    device=tex_pos_one_shape.device, dtype=torch.float32)
                tex_pos_one_shape = torch.cat([tex_pos_one_shape, padding], dim=1)
            sample_tex_pose_list.append(tex_pos_one_shape)
        tex_pos = torch.cat(sample_tex_pose_list, dim=0)

    # Query the texture field (the keyword name is the one the callee is invoked with).
    tex_feat = self.tensorRF.get_texture_prediction(tex_pos, vsd_vome=planes)

    if hard_mask is None:
        # No silhouette to scatter back into: previously this path crashed on
        # ``hard_mask.shape``; return the per-position features unchanged.
        return tex_feat

    # Scatter the predicted features back to their pixel locations.
    final_tex_feat = torch.zeros(
        B, hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
    expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
    for i in range(B):
        # Drop the padded entries: only the first n_point_list[i] rows are real.
        final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)

    return final_tex_feat.reshape(B, hard_mask.shape[1], hard_mask.shape[2], final_tex_feat.shape[-1])
|
533 |
+
|
534 |
+
def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
    '''
    Render every generated mesh through ``self.geometry.render_mesh`` (nvdiffrast).

    :param mesh_v: list of vertices, one entry per mesh
    :param mesh_f: list of faces, one entry per mesh (cast to int before rendering)
    :param cam_mv: camera matrices, indexed per mesh
    :param render_size: resolution forwarded to the renderer
    :return: ``(mask, hard_mask, tex_pos, depth, normal)``; every entry is
        concatenated over the meshes along dim 0 except ``tex_pos``, which stays a
        list with one element per mesh
    '''
    per_mesh = [
        self.geometry.render_mesh(
            verts,
            mesh_f[idx].int(),
            cam_mv[idx],
            resolution=render_size,
            hierarchical_mask=False,
        )
        for idx, verts in enumerate(mesh_v)
    ]

    # Turn the list of per-mesh dicts into one dict of per-mesh lists.
    gathered = {key: [rendered[key] for rendered in per_mesh] for key in per_mesh[0].keys()}

    def stacked(key):
        return torch.cat(gathered[key], dim=0)

    mask = stacked('mask')
    hard_mask = stacked('hard_mask')
    depth = stacked('depth')
    normal = stacked('normal')
    return mask, hard_mask, gathered['tex_pos'], depth, normal
|
565 |
+
|
566 |
+
def forward_geometry(self, planes, render_cameras, render_size=256):
    '''
    Main function of our Generator. It first generates the 3D meshes, then renders
    them into 2D images with the given `render_cameras` and queries the texture
    field for every rendered pixel.

    :param planes: triplane features
    :param render_cameras: cameras to render the generated 3D shape with (w2c
        matrices); the first two dimensions are read as (batch, number of views)
    :param render_size: resolution passed to the renderer and used as the
        per-view stride when the tiled views are split apart again
    :return: dict with keys 'image', 'albedo', 'mask', 'depth', 'normal', 'sdf',
        'mesh_v', 'mesh_f' and 'sdf_reg_loss'
    '''
    B, NV = render_cameras.shape[:2]

    # Generate 3D mesh first.
    # NOTE: `sdf_reg_loss` receives the whole tuple of regularisation terms that
    # get_geometry_prediction returns as its last element, not a single scalar.
    mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)

    # Render the meshes into 2D images; `tex_pos` holds the positions at which
    # the texture field is queried further below.
    cam_mv = render_cameras
    run_n_view = cam_mv.shape[1]
    antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size)

    # Tile all views of one sample next to each other along dim 2, so that each
    # sample becomes a single image for the texture query.
    # presumably dim 2 is the image width -- TODO confirm against the renderer
    tex_hard_mask = hard_mask
    tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos]
    tex_hard_mask = torch.cat(
        [torch.cat(
            [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1]
             for i_view in range(run_n_view)], dim=2)
         for i in range(B)], dim=0)

    # Query the texture field to predict the texture feature for each pixel on the image
    tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
    background_feature = torch.ones_like(tex_feat)  # white background

    # Composite the predicted features over the background using the hard mask
    img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)

    # Split the tiled views apart again: every view occupies `render_size`
    # entries along dim 2, and the views are stacked back along dim 0.
    img_feat = torch.cat(
        [torch.cat(
            [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)]
             for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0)

    # Channels-last -> channels-first, then restore the (B, NV) leading dimensions
    img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV))

    # The feature channels are split: 3:6 -> 'albedo', 0:3 -> 'image'
    albedo=img[:,:,3:6,:,:]
    img=img[:,:,0:3,:,:]

    antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV))
    depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV))  # transform negative depth to positive
    normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV))

    out = {
        'image': img,
        'albedo': albedo,
        'mask': antilias_mask,
        'depth': depth,
        'normal': normal,
        'sdf': sdf,
        'mesh_v': mesh_v,
        'mesh_f': mesh_f,
        'sdf_reg_loss': sdf_reg_loss,
    }
    return out
|
625 |
+
|
626 |
+
def forward(self, data, step_ratio=1):
    '''
    Training forward pass: predict the volume, render it and compute the losses.

    :param data: output of the dataloader. Reads 'input', 'w2c', 'images_output',
        'albedos_output' and 'masks_output' (plus whatever `forward_svd_volume` needs).
    :param step_ratio: unused here; kept for interface compatibility
    :return: the dict produced by `forward_geometry`, extended with 'svd_volume',
        'images_pred', 'alphas_pred', 'pred_albedos', 'pred_shading', 'loss',
        'psnr' and, when `opt.lambda_lpips > 0`, 'loss_lpips'
    '''
    images = data['input']  # input features

    svd_volume = self.forward_svd_volume(images, data)

    # Render the predicted volume from the supervision cameras.
    results = self.forward_geometry(svd_volume, data['w2c'], self.opt.output_size)
    # Store the volume only after `results` has been rebound; storing it before
    # (as was done previously) silently dropped it from the returned dict.
    results['svd_volume'] = svd_volume

    pred_shading = results['image']  # [B, V, C, output_size, output_size]
    pred_alphas = results['mask']  # [B, V, 1, output_size, output_size]
    pred_albedos = results['albedo']  # [B, V, C, output_size, output_size]

    # The final image is the predicted shading modulated by the predicted albedo.
    pred_images = pred_shading * pred_albedos

    results['images_pred'] = pred_images
    results['alphas_pred'] = pred_alphas
    results['pred_albedos'] = pred_albedos
    results['pred_shading'] = pred_shading

    gt_images = data['images_output']  # [B, V, 3, output_size, output_size], ground-truth novel views
    gt_albedos = data['albedos_output']  # [B, V, 3, output_size, output_size], ground-truth albedos
    gt_masks = data['masks_output']  # [B, V, 1, output_size, output_size], ground-truth masks

    # Always composite the ground truth over a white background. The colour is
    # created on the ground truth's own device instead of relying on a `device`
    # name that is not defined inside this method.
    bg_color = torch.ones(3, dtype=torch.float32, device=gt_images.device)
    gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
    gt_albedos = gt_albedos * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)

    loss = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) + F.mse_loss(pred_albedos, gt_albedos)

    if self.opt.lambda_lpips > 0:
        # LPIPS is evaluated on images rescaled to [-1, 1] and resized to 256x256.
        loss_lpips = self.lpips_loss(
            F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
            F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
        ).mean()
        results['loss_lpips'] = loss_lpips
        loss = loss + self.opt.lambda_lpips * loss_lpips

    results['loss'] = loss

    # metric
    with torch.no_grad():
        psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
        results['psnr'] = psnr

    return results
|
689 |
+
|
690 |
+
|
691 |
+
def render_frame(self, data):
    '''
    Inference pass: predict the volume from the inputs and render it.

    :param data: output of the dataloader. Reads 'input_vit' and 'w2c' (plus
        whatever `forward_svd_volume` needs).
    :return: the dict produced by `forward_geometry`, extended with 'svd_volume',
        'images_pred', 'alphas_pred', 'pred_albedos' and 'pred_shading'.
        No loss is computed here.
    '''
    images = data['input_vit']  # input features

    svd_volume = self.forward_svd_volume(images, data)

    # Render at the inference resolution.
    results = self.forward_geometry(svd_volume, data['w2c'], self.opt.infer_render_size)
    # Store the volume only after `results` has been rebound; storing it before
    # (as was done previously) silently dropped it from the returned dict.
    results['svd_volume'] = svd_volume

    pred_shading = results['image']  # [B, V, C, render_size, render_size]
    pred_alphas = results['mask']  # [B, V, 1, render_size, render_size]
    pred_albedos = results['albedo']  # [B, V, C, render_size, render_size]

    # The final image is the predicted shading modulated by the predicted albedo.
    pred_images = pred_shading * pred_albedos

    results['images_pred'] = pred_images
    results['alphas_pred'] = pred_alphas
    results['pred_albedos'] = pred_albedos
    results['pred_shading'] = pred_shading

    return results
|
724 |
+
|
725 |
+
def extract_mesh(
    self,
    planes: torch.Tensor,
    use_texture_map: bool = False,
    texture_resolution: int = 1024,
    **kwargs,
):
    '''
    Extract a 3D mesh from FlexiCubes. Only supports batch_size 1.

    :param planes: triplane features; indexed with ``'app_planes'`` to check the batch size
    :param use_texture_map: use a texture map (True) or per-vertex colors (False)
    :param texture_resolution: the resolution of the texture map
    :return: without texture map: ``(vertices, faces, [rgb, albedo, shading])`` as
        numpy arrays, the colors being uint8; with texture map:
        ``(vertices, faces, uvs, mesh_tex_idx, texture_map)``
    '''
    assert planes['app_planes'].shape[0] == 1

    # predict geometry first
    mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
    vertices, faces = mesh_v[0], mesh_f[0]

    if not use_texture_map:
        # Query the vertex colors once and read all three outputs from the same
        # prediction (previously the field was evaluated three times on
        # identical inputs).
        vertices_tensor = vertices.unsqueeze(0)
        predicted = self.tensorRF.predict_color(planes, vertices_tensor)

        def to_uint8(key):
            colors = predicted[key].clamp(0, 1).squeeze(0).cpu().numpy()
            return (colors * 255).astype(np.uint8)

        rgb_colors = to_uint8('rgb')
        albedob_colors = to_uint8('albedo')
        shading_colors = to_uint8('shading')

        return vertices.cpu().numpy(), faces.cpu().numpy(), [rgb_colors, albedob_colors, shading_colors]

    # Use x-atlas to get the uv mapping for the mesh. The renderer's existing
    # rasterization context is reused; no extra context is created.
    uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
        self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution)

    tex_hard_mask = tex_hard_mask.float().cpu()

    # query the texture field to get the RGB color for every texel
    query_vertices = gb_pos.view(1, texture_resolution * texture_resolution, 3)
    vertices_colors = self.tensorRF.predict_color(
        planes, query_vertices)['rgb'].squeeze(0).cpu()
    vertices_colors = vertices_colors.reshape(1, texture_resolution, texture_resolution, 3)

    # Texels outside the uv islands stay black.
    background_feature = torch.zeros_like(vertices_colors)
    img_feat = torch.lerp(background_feature, vertices_colors, tex_hard_mask)
    texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)

    return vertices, faces, uvs, mesh_tex_idx, texture_map
|
782 |
+
|
783 |
+
|
core/modulate.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023-2024, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class ModLN(nn.Module):
    """
    Layer normalisation followed by an adaLN-style scale/shift modulation.

    The shift and scale are regressed from a conditioning vector by a small
    SiLU + Linear head.

    References:
    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
    """
    def __init__(self, inner_dim: int, mod_dim: int, eps: float):
        super().__init__()
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        # One linear layer emits both halves (shift, scale) at once.
        modulation_layers = [nn.SiLU(), nn.Linear(mod_dim, inner_dim * 2)]
        self.mlp = nn.Sequential(*modulation_layers)

    @staticmethod
    def modulate(x, shift, scale):
        # x: [N, L, D]
        # shift, scale: [N, D] -- broadcast over the L dimension
        gain = 1 + scale.unsqueeze(1)
        offset = shift.unsqueeze(1)
        return x * gain + offset

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        shift_and_scale = self.mlp(mod)
        shift, scale = torch.chunk(shift_and_scale, 2, dim=-1)  # [N, D] each
        normed = self.norm(x)
        return self.modulate(normed, shift, scale)  # [N, L, D]
|