diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..26d29f6fd9ca4ce648be077de5751dcbb6288792 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +workspace_test +__pycache__ \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..abc00b79ae789273ded462678304d8a65389fefb --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,23 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "app", + "type": "debugpy", + "request": "launch", + "program": "./app.py", + "console": "integratedTerminal", + "env": { + "CUDA_VISIBLE_DEVICES": "2" + }, + // "args": [ + // "tiny_trf_trans_nerf",//"tiny_trf_trans_nerf" tiny_trf_trans_nerf_123plus + // "--resume", + // "pretrained/last6view060804_24.ckpt",//"pretrained/last_060302_49.ckpt",//"pretrained/last_060302_49.ckpt", + // "--output_size", + // "64" + // ], + "justMyCode": true + }, + ] +} \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f5d6762ebe0578779a5941264dc38ce0c378724b --- /dev/null +++ b/app.py @@ -0,0 +1,385 @@ +import os +import tyro +import imageio +import numpy as np +import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from safetensors.torch import load_file +import rembg +import gradio as gr + +import kiui +from kiui.op import recenter +from kiui.cam import orbit_camera +from core.utils import get_rays, grid_distortion, orbit_camera_jitter + +from core.options import AllConfigs, Options +from core.models import LTRFM_Mesh,LTRFM_NeRF +from core.instant_utils.mesh_util import save_obj, save_obj_with_mtl +from mvdream.pipeline_mvdream import MVDreamPipeline +from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler +from huggingface_hub import hf_hub_download + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +GRADIO_VIDEO_PATH = 'gradio_output.mp4' +GRADIO_OBJ_PATH = 'gradio_output_rgb.obj' +GRADIO_OBJ_ALBEDO_PATH = 'gradio_output_albedo.obj' +GRADIO_OBJ_SHADING_PATH = 'gradio_output_shading.obj' + +#opt = tyro.cli(AllConfigs) + +ckpt_path = hf_hub_download(repo_id="rgxie/LDM", filename="LDM6v01.ckpt") + +opt = Options( + input_size=512, + down_channels=(32, 64, 128, 256, 512), + down_attention=(False, False, False, False, True), + up_channels=(512, 256, 128), + up_attention=(True, False, False, False), + volume_mode='TRF_NeRF', + splat_size=64, + output_size=62, #crop patch + data_mode='s5', + num_views=8, + gradient_accumulation_steps=1, #2 + mixed_precision='bf16', + resume=ckpt_path, +) + + +# model +if opt.volume_mode == 'TRF_Mesh': + model = LTRFM_Mesh(opt) +elif opt.volume_mode == 'TRF_NeRF': + model = LTRFM_NeRF(opt) +else: + model = LGM(opt) + +# resume pretrained checkpoint +if opt.resume is not None: + if opt.resume.endswith('safetensors'): + ckpt = load_file(opt.resume, device='cpu') + else: #ckpt + ckpt_dict = torch.load(opt.resume, map_location='cpu') + ckpt=ckpt_dict["model"] + + state_dict = model.state_dict() + for k, v in ckpt.items(): + k=k.replace('module.', '') + if k in state_dict: + if state_dict[k].shape == v.shape: + state_dict[k].copy_(v) + else: + print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.') + else: + print(f'[WARN] unexpected param {k}: {v.shape}') + print(f'[INFO] load resume success!') + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model = model.half().to(device) +model.eval() + +tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) +proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device) +proj_matrix[0, 0] = 1 / tan_half_fov +proj_matrix[1, 1] = 1 / tan_half_fov +proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) +proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) +proj_matrix[2, 3] = 1 + +# load dreams +pipe_text = MVDreamPipeline.from_pretrained( + 'ashawkey/mvdream-sd2.1-diffusers', # remote weights + torch_dtype=torch.float16, + trust_remote_code=True, + # local_files_only=True, +) +pipe_text = pipe_text.to(device) + +# mvdream +pipe_image = MVDreamPipeline.from_pretrained( + "ashawkey/imagedream-ipmv-diffusers", # remote weights + torch_dtype=torch.float16, + trust_remote_code=True, + # local_files_only=True, +) +pipe_image = pipe_image.to(device) + + +print('Loading 123plus model ...') +pipe_image_plus = DiffusionPipeline.from_pretrained( + "sudo-ai/zero123plus-v1.2", + custom_pipeline="zero123plus", + torch_dtype=torch.float16, + trust_remote_code=True, + local_files_only=True, +) +pipe_image_plus.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipe_image_plus.scheduler.config, timestep_spacing='trailing' +) + +unet_path='./pretrained/diffusion_pytorch_model.bin' + +print('Loading custom white-background unet ...') +if os.path.exists(unet_path): + unet_ckpt_path = unet_path +else: + unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") +state_dict = torch.load(unet_ckpt_path, map_location='cpu') +pipe_image_plus.unet.load_state_dict(state_dict, strict=True) +pipe_image_plus = pipe_image_plus.to(device) + +# load rembg +bg_remover = rembg.new_session() + +# process function +def process(condition_input_image, prompt, prompt_neg='', input_elevation=0, input_num_steps=30, input_seed=42, mv_moedl_option=None): + + # seed + kiui.seed_everything(input_seed) + + os.makedirs(os.path.join(opt.workspace, "gradio"), exist_ok=True) + output_video_path = os.path.join(opt.workspace,"gradio", GRADIO_VIDEO_PATH) + output_obj_rgb_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_PATH) + output_obj_albedo_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_ALBEDO_PATH) + output_obj_shading_path = os.path.join(opt.workspace,"gradio", GRADIO_OBJ_SHADING_PATH) + + # text-conditioned + if condition_input_image is None: + mv_image_uint8 = pipe_text(prompt, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=7.5, elevation=input_elevation) + mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8) + # bg removal + mv_image = [] + for i in range(4): + image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4] + # to white bg + image = image.astype(np.float32) / 255 + image = recenter(image, image[..., 0] > 0, border_ratio=0.2) + image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:]) + mv_image.append(image) + + mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1) + input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) + + processed_image=None + # image-conditioned (may also input text, but no text usually works too) + else: + condition_input_image = np.array(condition_input_image) # uint8 + # bg removal + carved_image = rembg.remove(condition_input_image, session=bg_remover) # [H, W, 4] + mask = carved_image[..., -1] > 0 + image = recenter(carved_image, mask, border_ratio=0.2) + image = image.astype(np.float32) / 255.0 + processed_image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) + + if mv_moedl_option=='mvdream': + mv_image = pipe_image(prompt, processed_image, negative_prompt=prompt_neg, num_inference_steps=input_num_steps, guidance_scale=5.0, elevation=input_elevation) + + mv_image_grid = np.concatenate([mv_image[1], mv_image[2],mv_image[3], mv_image[0]],axis=1) + input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) + else: + from PIL import Image + from einops import rearrange, repeat + + # input_image=input_image* 255 + processed_image = Image.fromarray((processed_image * 255).astype(np.uint8)) + mv_image = pipe_image_plus(processed_image, num_inference_steps=input_num_steps).images[0] + mv_image = np.asarray(mv_image, dtype=np.float32) / 255.0 + mv_image = torch.from_numpy(mv_image).permute(2, 0, 1).contiguous().float() # (3, 960, 640) + mv_image_grid = rearrange(mv_image, 'c (n h) (m w) -> (m h) (n w) c', n=3, m=2).numpy() + mv_image = rearrange(mv_image, 'c (n h) (m w) -> (n m) h w c', n=3, m=2).numpy() + input_image = mv_image + + # generate gaussians + # [4, 256, 256, 3], float32 + input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256] + input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False) + + images_input_vit = F.interpolate(input_image, size=(224, 224), mode='bilinear', align_corners=False) + + data = {} + input_image = input_image.unsqueeze(0) # [1, 4, 9, H, W] + images_input_vit=images_input_vit.unsqueeze(0) + data['input_vit']=images_input_vit + + elevation = 0 + cam_poses =[] + if mv_moedl_option=='mvdream' or condition_input_image is None: + azimuth = np.arange(0, 360, 90, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + cam_poses.append(cam_pose) + else: + azimuth = np.arange(30, 360, 60, dtype=np.int32) + cnt = 0 + for azi in tqdm.tqdm(azimuth): + if (cnt+1) % 2!= 0: + elevation=-20 + else: + elevation=30 + cam_pose = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + cam_poses.append(cam_pose) + cnt=cnt+1 + + cam_poses = torch.cat(cam_poses,0) + radius = torch.norm(cam_poses[0, :3, 3]) + cam_poses[:, :3, 3] *= opt.cam_radius / radius + transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32).to(device) @ torch.inverse(cam_poses[0]) + cam_poses = transform.unsqueeze(0) @ cam_poses + + cam_poses=cam_poses.unsqueeze(0) + data['source_camera']=cam_poses + + with torch.no_grad(): + if opt.volume_mode == 'TRF_Mesh': + with torch.autocast(device_type='cuda', dtype=torch.float32): + svd_volume = model.forward_svd_volume(input_image,data) + else: + with torch.autocast(device_type='cuda', dtype=torch.float16): + svd_volume = model.forward_svd_volume(input_image,data) + + #time-consuming + export_texmap=False + + mesh_out = model.extract_mesh(svd_volume,use_texture_map=export_texmap) + + if export_texmap: + vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out + + for i in range(len(tex_map)): + mesh_path=os.path.join(opt.workspace, name + str(i) + '_'+ str(seed)+ '.obj') + save_obj_with_mtl( + vertices.data.cpu().numpy(), + uvs.data.cpu().numpy(), + faces.data.cpu().numpy(), + mesh_tex_idx.data.cpu().numpy(), + tex_map[i].permute(1, 2, 0).data.cpu().numpy(), + mesh_path, + ) + else: + vertices, faces, vertex_colors = mesh_out + + save_obj(vertices, faces, vertex_colors[0], output_obj_rgb_path) + save_obj(vertices, faces, vertex_colors[1], output_obj_albedo_path) + save_obj(vertices, faces, vertex_colors[2], output_obj_shading_path) + + + return mv_image_grid, processed_image, output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path + +# gradio UI + +_TITLE = '''LDM: Large Tensorial SDF Model for Textured Mesh Generation''' + +_DESCRIPTION = ''' + + +* Input can be text prompt, image. +* If you find the output unsatisfying, try using different seeds! +''' + +block = gr.Blocks(title=_TITLE).queue() +with block: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + gr.Markdown(_DESCRIPTION) + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + with gr.Tab("Image-to-3D"): + # input image + with gr.Row(): + condition_input_image = gr.Image( + label="Input Image", + image_mode="RGBA", + type="pil" + ) + + processed_image = gr.Image( + label="Processed Image", + image_mode="RGBA", + type="pil", + interactive=False + ) + + + with gr.Row(): + mv_moedl_option = gr.Radio([ + "zero123plus", + "mvdream" + ], value="zero123plus", + label="Multi-view Diffusion") + + with gr.Row(variant="panel"): + gr.Examples( + examples=[ + os.path.join("example", img_name) for img_name in sorted(os.listdir("example")) + ], + inputs=[condition_input_image], + fn=lambda x: process(condition_input_image=x, prompt=''), + cache_examples=False, + examples_per_page=20, + label='Image-to-3D Examples' + ) + + with gr.Tab("Text-to-3D"): + # input prompt + with gr.Row(): + input_text = gr.Textbox(label="prompt") + # negative prompt + with gr.Row(): + input_neg_text = gr.Textbox(label="negative prompt", value='ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate') + + with gr.Row(variant="panel"): + gr.Examples( + examples=[ + "a hamburger", + "a furry red fox head", + "a teddy bear", + "a motorbike", + ], + inputs=[input_text], + fn=lambda x: process(condition_input_image=None, prompt=x), + cache_examples=False, + label='Text-to-3D Examples' + ) + + # elevation + input_elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0) + # inference steps + input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=30) + # random seed + input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0) + # gen button + button_gen = gr.Button("Generate") + + + with gr.Column(scale=1): + with gr.Row(): + # multi-view results + mv_image_grid = gr.Image(interactive=False, show_label=False) + with gr.Row(): + output_obj_rgb_path = gr.Model3D( + label="RGB Model (OBJ Format)", + interactive=False, + ) + with gr.Row(): + output_obj_albedo_path = gr.Model3D( + label="Albedo Model (OBJ Format)", + interactive=False, + ) + with gr.Row(): + output_obj_shading_path = gr.Model3D( + label="Shading Model (OBJ Format)", + interactive=False, + ) + + + button_gen.click(process, inputs=[condition_input_image, input_text, input_neg_text, input_elevation, input_num_steps, input_seed,mv_moedl_option], outputs=[mv_image_grid,processed_image, output_obj_rgb_path, output_obj_albedo_path, output_obj_shading_path]) + + +block.launch(server_name="0.0.0.0", share=False) \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/core/block.py b/core/block.py new file mode 100644 index 0000000000000000000000000000000000000000..a49bc70738101b044c878bccb2feb7e182816122 --- /dev/null +++ b/core/block.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn + +from .modulate import ModLN + + +class BasicBlock(nn.Module): + """ + Transformer block that is in its simplest form. + Designed for PF-LRM architecture. + """ + # Block contains a self-attention layer and an MLP + def __init__(self, inner_dim: int, num_heads: int, eps: float, + attn_drop: float = 0., attn_bias: bool = False, + mlp_ratio: float = 4., mlp_drop: float = 0.): + super().__init__() + self.norm1 = nn.LayerNorm(inner_dim, eps=eps) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm2 = nn.LayerNorm(inner_dim, eps=eps) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x): + # x: [N, L, D] + before_sa = self.norm1(x) + x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] + x = x + self.mlp(self.norm2(x)) + return x + + +class ConditionBlock(nn.Module): + """ + Transformer block that takes in a cross-attention condition. + Designed for SparseLRM architecture. + """ + # Block contains a cross-attention layer, a self-attention layer, and an MLP + def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float, + attn_drop: float = 0., attn_bias: bool = False, + mlp_ratio: float = 4., mlp_drop: float = 0.): + super().__init__() + self.norm1 = nn.LayerNorm(inner_dim, eps=eps) + self.cross_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm2 = nn.LayerNorm(inner_dim, eps=eps) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm3 = nn.LayerNorm(inner_dim, eps=eps) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x, cond): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0] + before_sa = self.norm2(x) + x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] + x = x + self.mlp(self.norm3(x)) + return x + + +class ConditionModulationBlock(nn.Module): + """ + Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. + Designed for raw LRM architecture. + """ + # Block contains a cross-attention layer, a self-attention layer, and an MLP + def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float, + attn_drop: float = 0., attn_bias: bool = False, + mlp_ratio: float = 4., mlp_drop: float = 0.): + super().__init__() + self.norm1 = ModLN(inner_dim, mod_dim, eps) + self.cross_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm2 = ModLN(inner_dim, mod_dim, eps) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm3 = ModLN(inner_dim, mod_dim, eps) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x, cond, mod): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + # mod: [N, D_mod] + x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0] + before_sa = self.norm2(x, mod) + x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] + x = x + self.mlp(self.norm3(x, mod)) + return x diff --git a/core/embedder.py b/core/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..05219ffbab7eb1299a6cfc0589cb5b370b18c734 --- /dev/null +++ b/core/embedder.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class CameraEmbedder(nn.Module): + """ + Embed camera features to a high-dimensional vector. + + Reference: + DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27 + """ + def __init__(self, raw_dim: int, embed_dim: int): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(raw_dim, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim), + ) + + @torch.compile + def forward(self, x): + return self.mlp(x) diff --git a/core/encoders/__init__.py b/core/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e564b34377e38d6a29c5dcaa87c7de12b34ec6d --- /dev/null +++ b/core/encoders/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Empty diff --git a/core/encoders/dino_wrapper.py b/core/encoders/dino_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..dc041a3ec4c6418b251d70ade6d358ccd627a889 --- /dev/null +++ b/core/encoders/dino_wrapper.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from transformers import ViTImageProcessor, ViTModel +from accelerate.logging import get_logger + + +logger = get_logger(__name__) + + +class DinoWrapper(nn.Module): + """ + Dino v1 wrapper using huggingface transformer implementation. + """ + def __init__(self, model_name: str, freeze: bool = True): + super().__init__() + self.model, self.processor = self._build_dino(model_name) + if freeze: + self._freeze() + + @torch.compile + def forward_model(self, inputs): + return self.model(**inputs, interpolate_pos_encoding=True) + + def forward(self, image): + # image: [N, C, H, W], on cpu + # RGB image with [0,1] scale and properly sized + inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device) + # This resampling of positional embedding uses bicubic interpolation + outputs = self.forward_model(inputs) + last_hidden_states = outputs.last_hidden_state + return last_hidden_states + + def _freeze(self): + logger.warning(f"======== Freezing DinoWrapper ========") + self.model.eval() + for name, param in self.model.named_parameters(): + param.requires_grad = False + + @staticmethod + def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): + import requests + try: + model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) + processor = ViTImageProcessor.from_pretrained(model_name) + return model, processor + except requests.exceptions.ProxyError as err: + if proxy_error_retries > 0: + print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...") + import time + time.sleep(proxy_error_cooldown) + return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) + else: + raise err diff --git a/core/encoders/dinov2/__init__.py b/core/encoders/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e564b34377e38d6a29c5dcaa87c7de12b34ec6d --- /dev/null +++ b/core/encoders/dinov2/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Empty diff --git a/core/encoders/dinov2/hub/__init__.py b/core/encoders/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40afb43678d1db842a67445d79260c338a1c1ab5 --- /dev/null +++ b/core/encoders/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/core/encoders/dinov2/hub/backbones.py b/core/encoders/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..c488adb16071eb1f60f37b99399b045b456c78f7 --- /dev/null +++ b/core/encoders/dinov2/hub/backbones.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + # ********** Modified by Zexin He in 2023-2024 ********** + state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern + if vit_kwargs.get("modulation_dim") is not None: + state_dict = { + k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v + for k, v in state_dict.items() + } + model.load_state_dict(state_dict, strict=False) + else: + model.load_state_dict(state_dict, strict=True) + # ******************************************************** + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/core/encoders/dinov2/hub/classifiers.py b/core/encoders/dinov2/hub/classifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..b996f567237db149c482062496195aeed9483910 --- /dev/null +++ b/core/encoders/dinov2/hub/classifiers.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch +import torch.nn as nn + +from .backbones import _make_dinov2_model +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + IMAGENET1K = "IMAGENET1K" + + +def _make_dinov2_linear_classification_head( + *, + arch_name: str = "vit_large", + patch_size: int = 14, + embed_dim: int = 1024, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) + + if pretrained: + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + layers_str = str(layers) if layers == 4 else "" + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + linear_head.load_state_dict(state_dict, strict=True) + + return linear_head + + +class _LinearClassifierWrapper(nn.Module): + def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4): + super().__init__() + self.backbone = backbone + self.linear_head = linear_head + self.layers = layers + + def forward(self, x): + if self.layers == 1: + x = self.backbone.forward_features(x) + cls_token = x["x_norm_clstoken"] + patch_tokens = x["x_norm_patchtokens"] + # fmt: off + linear_input = torch.cat([ + cls_token, + patch_tokens.mean(dim=1), + ], dim=1) + # fmt: on + elif self.layers == 4: + x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) + # fmt: off + linear_input = torch.cat([ + x[0][1], + x[1][1], + x[2][1], + x[3][1], + x[3][0].mean(dim=1), + ], dim=1) + # fmt: on + else: + assert False, f"Unsupported number of layers: {self.layers}" + return self.linear_head(linear_input) + + +def _make_dinov2_linear_classifier( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + **kwargs, +): + backbone = _make_dinov2_model( + arch_name=arch_name, + pretrained=pretrained, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + **kwargs, + ) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + linear_head = _make_dinov2_linear_classification_head( + arch_name=arch_name, + patch_size=patch_size, + embed_dim=embed_dim, + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=num_register_tokens, + ) + + return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) + + +def dinov2_vits14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitb14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitl14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitg14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vits14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/core/encoders/dinov2/hub/depth/__init__.py b/core/encoders/dinov2/hub/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e344877b36fee8e6a48b4b83daa04c3c99441069 --- /dev/null +++ b/core/encoders/dinov2/hub/depth/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .decode_heads import BNHead, DPTHead +from .encoder_decoder import DepthEncoderDecoder diff --git a/core/encoders/dinov2/hub/depth/decode_heads.py b/core/encoders/dinov2/hub/depth/decode_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..1e0b84314a0ff213b711b7141821949ba86c9ed8 --- /dev/null +++ b/core/encoders/dinov2/hub/depth/decode_heads.py @@ -0,0 +1,747 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy +from functools import partial +import math +import warnings + +import torch +import torch.nn as nn + +from .ops import resize + + +# XXX: (Untested) replacement for mmcv.imdenormalize() +def _imdenormalize(img, mean, std, to_bgr=True): + import numpy as np + + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = (img * std) + mean + if to_bgr: + img = img[::-1] + return img + + +class DepthBaseDecodeHead(nn.Module): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_layer (nn.Module): Conv layers. Default: None. + act_layer (nn.Module): Activation layers. Default: nn.ReLU. + loss_decode (dict): Config of decode loss. + Default: (). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. + norm_layer (dict|None): Norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. + """ + + def __init__( + self, + in_channels, + conv_layer=None, + act_layer=nn.ReLU, + channels=96, + loss_decode=(), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_layer=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conf_layer = conv_layer + self.act_layer = act_layer + self.loss_decode = loss_decode + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_layer = norm_layer + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): GT depth + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + + Returns: + Tensor: Output depth map. + """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + import numpy as np + + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = _imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} + + +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + return output + + +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_layer. Bias will be set as True if `norm_layer` is None, otherwise + False. Default: "auto". + conv_layer (nn.Module): Convolution layer. Default: None, + which means using conv2d. + norm_layer (nn.Module): Normalization layer. Default: None. + act_layer (nn.Module): Activation layer. Default: nn.ReLU. + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = "conv_block" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias="auto", + conv_layer=nn.Conv2d, + norm_layer=None, + act_layer=nn.ReLU, + inplace=True, + with_spectral_norm=False, + padding_mode="zeros", + order=("conv", "norm", "act"), + ): + super(ConvModule, self).__init__() + official_padding_mode = ["zeros", "circular"] + self.conv_layer = conv_layer + self.norm_layer = norm_layer + self.act_layer = act_layer + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == set(["conv", "norm", "act"]) + + self.with_norm = norm_layer is not None + self.with_activation = act_layer is not None + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == "auto": + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + if padding_mode == "zeros": + padding_layer = nn.ZeroPad2d + else: + raise AssertionError(f"Unsupported padding mode: {padding_mode}") + self.pad = padding_layer(padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = self.conv_layer( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index("norm") > order.index("conv"): + norm_channels = out_channels + else: + norm_channels = in_channels + norm = partial(norm_layer, num_features=norm_channels) + self.add_module("norm", norm) + if self.with_bias: + from torch.nnModules.batchnorm import _BatchNorm + from torch.nnModules.instancenorm import _InstanceNorm + + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn("Unnecessary conv bias before batch/instance norm") + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + # nn.Tanh has no 'inplace' argument + # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU) + if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)): + act_layer = partial(act_layer, inplace=inplace) + self.activate = act_layer() + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, "init_weights"): + if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU): + nonlinearity = "leaky_relu" + a = 0.01 # XXX: default negative_slope + else: + nonlinearity = "relu" + a = 0 + if hasattr(self.conv, "weight") and self.conv.weight is not None: + nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity) + if hasattr(self.conv, "bias") and self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + if self.with_norm: + if hasattr(self.norm, "weight") and self.norm.weight is not None: + nn.init.constant_(self.norm.weight, 1) + if hasattr(self.norm, "bias") and self.norm.bias is not None: + nn.init.constant_(self.norm.bias, 0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == "conv": + if self.with_explicit_padding: + x = self.pad(x) + x = self.conv(x) + elif layer == "norm" and norm and self.with_norm: + x = self.norm(x) + elif layer == "act" and activate and self.with_activation: + x = self.activate(x) + return x + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(nn.Module): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + """ + + def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16): + super(ReassembleBlocks, self).__init__() + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_layer=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(nn.Module): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_layer (nn.Module): activation layer. + norm_layer (nn.Module): norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + """ + + def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1): + super(PreActResidualConvUnit, self).__init__() + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(nn.Module): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_layer (nn.Module): activation layer for ResidualConvUnit. + norm_layer (nn.Module): normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + """ + + def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True): + super(FeatureFusionBlock, self).__init__() + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs, + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git a/core/encoders/dinov2/hub/depth/encoder_decoder.py b/core/encoders/dinov2/hub/depth/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..88f169cd8fa29393212a084077105edeb320ab62 --- /dev/null +++ b/core/encoders/dinov2/hub/depth/encoder_decoder.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ops import resize + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +class DepthEncoderDecoder(nn.Module): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone and decode_head. + """ + + def __init__(self, backbone, decode_head): + super(DepthEncoderDecoder, self).__init__() + + self.backbone = backbone + self.decode_head = decode_head + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + return self.backbone(img) + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + return self.encode_decode(img, img_meta, rescale, size=size) + + def slide_inference(self, img, img_meta, rescale, stride, crop_size): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = stride + h_crop, w_crop = crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None, mode="whole"): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. + """ + + assert mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + import torch.distributed as dist + + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/core/encoders/dinov2/hub/depth/ops.py b/core/encoders/dinov2/hub/depth/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0abb2a7040f487371b2f89875ef5fa180a87f4ae --- /dev/null +++ b/core/encoders/dinov2/hub/depth/ops.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/core/encoders/dinov2/hub/depthers.py b/core/encoders/dinov2/hub/depthers.py new file mode 100644 index 0000000000000000000000000000000000000000..a08f71c838f4235811d6a68cb6a5f5cec2204438 --- /dev/null +++ b/core/encoders/dinov2/hub/depthers.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from functools import partial +from typing import Optional, Tuple, Union + +import torch + +from .backbones import _make_dinov2_model +from .depth import BNHead, DepthEncoderDecoder, DPTHead +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding + + +class Weights(Enum): + NYU = "NYU" + KITTI = "KITTI" + + +def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: + if not pretrained: # Default + return (0.001, 10.0) + + # Pretrained, set according to the training dataset for the provided weights + if weights == Weights.KITTI: + return (0.001, 80.0) + + if weights == Weights.NYU: + return (0.001, 10.0) + + return (0.001, 10.0) + + +def _make_dinov2_linear_depth_head( + *, + embed_dim: int, + layers: int, + min_depth: float, + max_depth: float, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + + if layers == 1: + in_index = [0] + else: + assert layers == 4 + in_index = [0, 1, 2, 3] + + return BNHead( + classify=True, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + upsample=4, + in_channels=[embed_dim] * len(in_index), + in_index=in_index, + input_transform="resize_concat", + channels=embed_dim * len(in_index) * 2, + align_corners=False, + min_depth=0.001, + max_depth=80, + loss_decode=(), + ) + + +def _make_dinov2_linear_depther( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + model_name = _make_dinov2_model_name(arch_name, patch_size) + linear_depth_head = _make_dinov2_linear_depth_head( + embed_dim=embed_dim, + layers=layers, + min_depth=min_depth, + max_depth=max_depth, + ) + + layer_count = { + "vit_small": 12, + "vit_base": 12, + "vit_large": 24, + "vit_giant2": 40, + }[arch_name] + + if layers == 4: + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + else: + assert layers == 1 + out_index = [layer_count - 1] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0])) + + if pretrained: + layers_str = str(layers) if layers == 4 else "" + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) + + +def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float): + return DPTHead( + in_channels=[embed_dim] * 4, + channels=256, + embed_dims=embed_dim, + post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)], + readout_type="project", + min_depth=min_depth, + max_depth=max_depth, + loss_decode=(), + ) + + +def _make_dinov2_dpt_depther( + *, + arch_name: str = "vit_large", + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + model_name = _make_dinov2_model_name(arch_name, backbone.patch_size) + dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth) + + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0])) + + if pretrained: + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther( + arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) diff --git a/core/encoders/dinov2/hub/utils.py b/core/encoders/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7afea3273713518e891d1e6b8e86d58b4700fddc --- /dev/null +++ b/core/encoders/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/core/encoders/dinov2/layers/__init__.py b/core/encoders/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5afbe658132f5dc86d923302922a6a2317be3d2 --- /dev/null +++ b/core/encoders/dinov2/layers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# ****************************************************************************** +# Code modified by Zexin He in 2023-2024. +# Modifications are marked with clearly visible comments +# licensed under the Apache License, Version 2.0. +# ****************************************************************************** + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +# ********** Modified by Zexin He in 2023-2024 ********** +# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock +from .block import Block, BlockWithModulation +# ******************************************************** +from .attention import MemEffAttention diff --git a/core/encoders/dinov2/layers/attention.py b/core/encoders/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f62db59d697e151d4aaec272b897fd07c8a8ab --- /dev/null +++ b/core/encoders/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/core/encoders/dinov2/layers/block.py b/core/encoders/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..53baf6dd4ba9688cccf1cbc9025442e76f861df4 --- /dev/null +++ b/core/encoders/dinov2/layers/block.py @@ -0,0 +1,296 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +# ****************************************************************************** +# Code modified by Zexin He in 2023-2024. +# Modifications are marked with clearly visible comments +# licensed under the Apache License, Version 2.0. +# ****************************************************************************** + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +# ********** Modified by Zexin He in 2023-2024 ********** +# Override forward with modulation input +class BlockWithModulation(Block): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, mod: Tensor) -> Tensor: + def attn_residual_func(x: Tensor, mod: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x, mod))) + + def ffn_residual_func(x: Tensor, mod: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x, mod))) + + if self.training and self.sample_drop_ratio > 0.1: + raise NotImplementedError("Modulation with drop path ratio larger than 0.1 is not supported yet") + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, mod)) + x = x + self.drop_path1(ffn_residual_func(x, mod)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, mod) + x = x + ffn_residual_func(x, mod) + return x +# ******************************************************** + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + + # ********** Modified by Zexin He in 2023-2024 ********** + warnings.warn("NestedTensorBlock is deprecated for now!", DeprecationWarning) + # ******************************************************** + + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/core/encoders/dinov2/layers/dino_head.py b/core/encoders/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ccca59999e1d686e1341281c61e8961f1b0e6545 --- /dev/null +++ b/core/encoders/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/core/encoders/dinov2/layers/drop_path.py b/core/encoders/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb1487b0eed4cb14dc0d5d1ee57a2acc78de34a --- /dev/null +++ b/core/encoders/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/core/encoders/dinov2/layers/layer_scale.py b/core/encoders/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..5468ee2dce0a9446c028791de5cff1ff068a4fe5 --- /dev/null +++ b/core/encoders/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/core/encoders/dinov2/layers/mlp.py b/core/encoders/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..0965768a9aef04ac6b81322f4dd60cf035159e91 --- /dev/null +++ b/core/encoders/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/core/encoders/dinov2/layers/patch_embed.py b/core/encoders/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3aaf46c523ab1ae27430419187bbad11e302ab --- /dev/null +++ b/core/encoders/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/core/encoders/dinov2/layers/swiglu_ffn.py b/core/encoders/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..3765d5def655f0a23f3803f4c7f79c33d3ecfd55 --- /dev/null +++ b/core/encoders/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/core/encoders/dinov2/models/__init__.py b/core/encoders/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01c40a2a10ce2b1e39f666328132c2a80111072d --- /dev/null +++ b/core/encoders/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/core/encoders/dinov2/models/vision_transformer.py b/core/encoders/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd35912b543832d597997728694e4bd239498cb6 --- /dev/null +++ b/core/encoders/dinov2/models/vision_transformer.py @@ -0,0 +1,443 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +# ****************************************************************************** +# Code modified by Zexin He in 2023-2024. +# Modifications are marked with clearly visible comments +# licensed under the Apache License, Version 2.0. +# ****************************************************************************** + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +# ********** Modified by Zexin He in 2023-2024 ********** +# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, Block, BlockWithModulation +# ******************************************************** + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + # ********** Modified by Zexin He in 2023-2024 ********** + modulation_dim: int = None, + # ******************************************************** + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + + # ********** Modified by Zexin He in 2023-2024 ********** + block_norm_layer = None + if modulation_dim is not None: + from ....modulate import ModLN + block_norm_layer = partial(ModLN, mod_dim=modulation_dim) + else: + block_norm_layer = nn.LayerNorm + block_norm_layer = partial(block_norm_layer, eps=1e-6) + # ******************************************************** + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + # ********** Modified by Zexin He in 2023-2024 ********** + norm_layer=block_norm_layer, + # ******************************************************** + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + # ********** Modified by Zexin He in 2023-2024 ********** + # hacking unused mask_token for better DDP + # self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + # ******************************************************** + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + # ********** Modified by Zexin He in 2023-2024 ********** + raise NotImplementedError("Masking is not supported in hacked DINOv2") + # x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + # ******************************************************** + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + # ********** Modified by Zexin He in 2023-2024 ********** + def forward_features(self, x, masks=None, mod=None): + if isinstance(x, list): + raise DeprecationWarning("forward_features_list is deprecated, use forward_features") + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + if mod is None: + for blk in self.blocks: + x = blk(x) + else: + for blk in self.blocks: + x = blk(x, mod) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + # ******************************************************** + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +# ********** Modified by Zexin He in 2023-2024 ********** +# block class selected from Block and BlockWithModulation + +def _block_cls(**kwargs): + modulation_dim = kwargs.get("modulation_dim", None) + if modulation_dim is None: + block_cls = Block + else: + block_cls = BlockWithModulation + return block_cls + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + +# ******************************************************** diff --git a/core/encoders/dinov2_wrapper.py b/core/encoders/dinov2_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..d223a8a3c37bc2d1c6d448ea3ed54475bc241ef8 --- /dev/null +++ b/core/encoders/dinov2_wrapper.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +# from accelerate.logging import get_logger + + +# logger = get_logger(__name__) + + +class Dinov2Wrapper(nn.Module): + """ + Dino v2 wrapper using original implementation, hacked with modulation. + """ + def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True): + super().__init__() + self.modulation_dim = modulation_dim + self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim) + if freeze: + if modulation_dim is not None: + raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.") + self._freeze() + + def _freeze(self): + #logger.warning(f"======== Freezing Dinov2Wrapper ========") + self.model.eval() + for name, param in self.model.named_parameters(): + param.requires_grad = False + + @staticmethod + def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True): + from importlib import import_module + dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__) + model_fn = getattr(dinov2_hub, model_name) + #logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.") + model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained) + return model + + #@torch.compile + def forward(self, image: torch.Tensor, mod: torch.Tensor = None): + # image: [N, C, H, W] + # mod: [N, D] or None + # RGB image with [0,1] scale and properly sized + if self.modulation_dim is None: + assert mod is None, "Unexpected modulation input in dinov2 forward." + outs = self.model(image, is_training=True) + else: + assert mod is not None, "Modulation input is required in modulated dinov2 forward." + outs = self.model(image, mod=mod, is_training=True) + ret = torch.cat([ + outs["x_norm_clstoken"].unsqueeze(dim=1), + outs["x_norm_patchtokens"], + ], dim=1) + return ret diff --git a/core/geometry/__init__.py b/core/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66a94c737c98087a954fcc1bbe77028b2851a5af --- /dev/null +++ b/core/geometry/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. diff --git a/core/geometry/camera/__init__.py b/core/geometry/camera/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecbcd2912674b8085531509c27ccfec64e09cc06 --- /dev/null +++ b/core/geometry/camera/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from torch import nn + + +class Camera(nn.Module): + def __init__(self): + super(Camera, self).__init__() + pass diff --git a/core/geometry/camera/perspective_camera.py b/core/geometry/camera/perspective_camera.py new file mode 100644 index 0000000000000000000000000000000000000000..bb31b8ff917ca0a30e1264cfbbb4d53d391a6ba3 --- /dev/null +++ b/core/geometry/camera/perspective_camera.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from . import Camera +import numpy as np + + + +def projection(fovy, n=1.0, f=50.0, near_plane=None): + focal = np.tan(fovy / 180.0 * np.pi * 0.5) + if near_plane is None: + near_plane = n + return np.array( + [[n / focal, 0, 0, 0], + [0, n / -focal, 0, 0], + [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], + [0, 0, -1, 0]]).astype(np.float32) + +def projection_2(opt): + zfar= opt.zfar + znear= opt.znear + tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) + proj_matrix = torch.zeros(4, 4, dtype=torch.float32) + proj_matrix[0, 0] = 1 / tan_half_fov + proj_matrix[1, 1] = 1 / tan_half_fov + proj_matrix[2, 2] = (zfar + znear) / (zfar - znear) + proj_matrix[3, 2] = - (zfar * znear) / (zfar - znear) + proj_matrix[2, 3] = 1 + + return proj_matrix + + +class PerspectiveCamera(Camera): + def __init__(self, opt, device='cuda'): + super(PerspectiveCamera, self).__init__() + self.device = device + self.proj_mtx = torch.from_numpy(projection(opt.fovy, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) + #self.proj_mtx= projection_2(opt).to(self.device).unsqueeze(dim=0) + + + def project(self, points_bxnx4): + out = torch.matmul( + points_bxnx4, + torch.transpose(self.proj_mtx, 1, 2)) + return out diff --git a/core/geometry/render/__init__.py b/core/geometry/render/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79c6cdfaaaa4a9c66ddac51c126bba1d021a7e18 --- /dev/null +++ b/core/geometry/render/__init__.py @@ -0,0 +1,8 @@ +import torch + +class Renderer(): + def __init__(self): + pass + + def forward(self): + pass \ No newline at end of file diff --git a/core/geometry/render/neural_render.py b/core/geometry/render/neural_render.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb4bd48e11c1842a72786a85481cb87e698f04b --- /dev/null +++ b/core/geometry/render/neural_render.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import torch.nn.functional as F +import nvdiffrast.torch as dr +from . import Renderer + +_FG_LUT = None + + +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate( + attr.contiguous(), rast, attr_idx, rast_db=rast_db, + diff_attrs=None if rast_db is None else 'all') + + +def xfm_points(points, matrix, use_python=True): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def compute_vertex_normal(v_pos, t_pos_idx): + i0 = t_pos_idx[:, 0] + i1 = t_pos_idx[:, 1] + i2 = t_pos_idx[:, 2] + + v0 = v_pos[i0, :] + v1 = v_pos[i1, :] + v2 = v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + +class NeuralRender(Renderer): + def __init__(self, device='cuda', camera_model=None): + super(NeuralRender, self).__init__() + self.device = device + self.ctx = dr.RasterizeCudaContext(device=device) + self.projection_mtx = None + self.camera = camera_model + + def render_mesh( + self, + mesh_v_pos_bxnx3, + mesh_t_pos_idx_fx3, + camera_mv_bx4x4, + mesh_v_feat_bxnxd, + resolution=256, + spp=1, + device='cuda', + hierarchical_mask=False + ): + assert not hierarchical_mask + + mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 + v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates + v_pos_clip = self.camera.project(v_pos) # Projection in the camera + + v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates + + # Render the image, + # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render + num_layers = 1 + mask_pyramid = None + assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes + mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos + + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) + + hard_mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias( + hard_mask.clone().contiguous(), rast, v_pos_clip, + mesh_t_pos_idx_fx3) + + depth = gb_feat[..., -2:-1] + ori_mesh_feature = gb_feat[..., :-4] + + normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) + normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) + normal = F.normalize(normal, dim=-1) + normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background + + return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal diff --git a/core/geometry/rep_3d/__init__.py b/core/geometry/rep_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e142b8dc4c7fe951e77225e9c70d517ea001be --- /dev/null +++ b/core/geometry/rep_3d/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np + + +class Geometry(): + def __init__(self): + pass + + def forward(self): + pass diff --git a/core/geometry/rep_3d/dmtet.py b/core/geometry/rep_3d/dmtet.py new file mode 100644 index 0000000000000000000000000000000000000000..578e5f53dec25f82f579e395016462b84cd05839 --- /dev/null +++ b/core/geometry/rep_3d/dmtet.py @@ -0,0 +1,504 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +from . import Geometry +from .dmtet_utils import get_center_boundary_index +import torch.nn.functional as F + + +############################################################################### +# DMTet utility functions +############################################################################### +def create_mt_variable(device): + triangle_table = torch.tensor( + [ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device=device) + + num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device) + base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device) + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device)) + return triangle_table, num_triangles_table, base_tet_edges, v_id + + +def sort_edges(edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() + order = order.unsqueeze(dim=1) + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1 - order, dim=1) + return torch.stack([a, b], -1) + + +############################################################################### +# marching tetrahedrons (differentiable) +############################################################################### + +def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + return verts, faces + + +def create_tetmesh_variables(device='cuda'): + tet_table = torch.tensor( + [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8], + [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9], + [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9], + [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9], + [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9], + [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9], + [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9], + [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8], + [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8], + [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6], + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device) + num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device) + return tet_table, num_tets_table + + +def marching_tets_tetmesh( + pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id, + return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + if not return_tet_mesh: + return verts, faces + occupied_verts = ori_v[occ_n] + mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda") + tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4)) + + idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10 + tet_verts = torch.cat([verts, occupied_verts], 0) + num_tets = num_tets_table[tetindex] + + tets = torch.cat( + ( + torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape( + -1, + 4), + torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape( + -1, + 4), + ), dim=0) + # add fully occupied tets + fully_occupied = occ_fx4.sum(-1) == 4 + tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0] + tets = torch.cat([tets, tet_fully_occupied]) + + return verts, faces, tet_verts, tets + + +############################################################################### +# Compact tet grid +############################################################################### + +def compact_tets(pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + # Find surface tets + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets + + valid_vtx = tet_fx4[valid_tets].reshape(-1) + unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True) + new_pos = pos_nx3[unique_vtx] + new_sdf = sdf_n[unique_vtx] + new_tets = idx_map.reshape(-1, 4) + return new_pos, new_sdf, new_tets + + +############################################################################### +# Subdivide volume +############################################################################### + +def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + device = tet_pos_bxnx3.device + # get new verts + tet_fx4 = tet_bxfx4[0] + edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3] + all_edges = tet_fx4[:, edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + idx_map = idx_map + tet_pos_bxnx3.shape[1] + all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1) + mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape( + all_values.shape[0], -1, 2, + all_values.shape[-1]).mean(2) + new_v = torch.cat([all_values, mid_points_pos], 1) + new_v, new_sdf = new_v[..., :3], new_v[..., 3] + + # get new tets + + idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3] + idx_ab = idx_map[0::6] + idx_ac = idx_map[1::6] + idx_ad = idx_map[2::6] + idx_bc = idx_map[3::6] + idx_bd = idx_map[4::6] + idx_cd = idx_map[5::6] + + tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1) + tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1) + tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1) + tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1) + tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1) + tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1) + tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1) + tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1) + + tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0) + tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1) + tet = tet_np.long().to(device) + + return new_v, tet, new_sdf + + +############################################################################### +# Adjacency +############################################################################### +def tet_to_tet_adj_sparse(tet_tx4): + # include self connection!!!!!!!!!!!!!!!!!!! + with torch.no_grad(): + t = tet_tx4.shape[0] + device = tet_tx4.device + idx_array = torch.LongTensor( + [0, 1, 2, + 1, 0, 3, + 2, 3, 0, + 3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3) + + # get all faces + all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape( + -1, + 3) # (tx4, 3) + all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1) + # sort and group + all_faces_sorted, _ = torch.sort(all_faces, dim=1) + + all_faces_unique, inverse_indices, counts = torch.unique( + all_faces_sorted, dim=0, return_counts=True, + return_inverse=True) + tet_face_fx3 = all_faces_unique[counts == 2] + counts = counts[inverse_indices] # tx4 + valid = (counts == 2) + + group = inverse_indices[valid] + # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape) + _, indices = torch.sort(group) + all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices] + tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1) + + tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])]) + adj_self = torch.arange(t, device=tet_tx4.device) + adj_self = torch.stack([adj_self, adj_self], -1) + tet_adj_idx = torch.cat([tet_adj_idx, adj_self]) + + tet_adj_idx = torch.unique(tet_adj_idx, dim=0) + values = torch.ones( + tet_adj_idx.shape[0], device=tet_tx4.device).float() + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + + # normalization + neighbor_num = 1.0 / torch.sparse.sum( + adj_sparse, dim=1).to_dense() + values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0]) + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + return adj_sparse + + +############################################################################### +# Compact grid +############################################################################### + +def get_tet_bxfx4x3(bxnxz, bxfx4): + n_batch, z = bxnxz.shape[0], bxnxz.shape[2] + gather_input = bxnxz.unsqueeze(2).expand( + n_batch, bxnxz.shape[1], 4, z) + gather_index = bxfx4.unsqueeze(-1).expand( + n_batch, bxfx4.shape[1], 4, z).long() + tet_bxfx4xz = torch.gather( + input=gather_input, dim=1, index=gather_index) + + return tet_bxfx4xz + + +def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + with torch.no_grad(): + assert tet_pos_bxnx3.shape[0] == 1 + + occ = grid_sdf[0] > 0 + occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1) + mask = (occ_sum > 0) & (occ_sum < 4) + + # build connectivity graph + adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0]) + mask = mask.float().unsqueeze(-1) + + # Include a one ring of neighbors + for i in range(1): + mask = torch.sparse.mm(adj_matrix, mask) + mask = mask.squeeze(-1) > 0 + + mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long) + new_tet_bxfx4 = tet_bxfx4[:, mask].long() + selected_verts_idx = torch.unique(new_tet_bxfx4) + new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx] + mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device) + new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape) + new_grid_sdf = grid_sdf[:, selected_verts_idx] + return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf + + +############################################################################### +# Regularizer +############################################################################### + +def sdf_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], + (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], + (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +def sdf_reg_loss_batch(sdf, all_edges): + sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +############################################################################### +# Geometry interface +############################################################################### +class DMTetGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(DMTetGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + tets = np.load('data/tets/%d_compress.npz' % (grid_res)) + self.verts = torch.from_numpy(tets['vertices']).float().to(self.device) + # Make sure the tet is zero-centered and length is equal to 1 + length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0] + length = length.max() + mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0 + self.verts = (self.verts - mid.unsqueeze(dim=0)) / length + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + self.indices = torch.from_numpy(tets['tets']).long().to(self.device) + self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device) + self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device) + # Parameters for regularization computation + edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) + all_edges = self.indices[:, edges].reshape(-1, 2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces = marching_tets( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces + + def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces, tet_verts, tets = marching_tets_tetmesh( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True, + num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces, tet_verts, tets + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/core/geometry/rep_3d/dmtet_utils.py b/core/geometry/rep_3d/dmtet_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb6d81d039b9bc925f65a0beddba2e690d71b60 --- /dev/null +++ b/core/geometry/rep_3d/dmtet_utils.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch + + +def get_center_boundary_index(verts): + length_ = torch.sum(verts ** 2, dim=-1) + center_idx = torch.argmin(length_) + boundary_neg = verts == verts.max() + boundary_pos = verts == verts.min() + boundary = torch.bitwise_or(boundary_pos, boundary_neg) + boundary = torch.sum(boundary.float(), dim=-1) + boundary_idx = torch.nonzero(boundary) + return center_idx, boundary_idx.squeeze(dim=-1) diff --git a/core/geometry/rep_3d/extract_texture_map.py b/core/geometry/rep_3d/extract_texture_map.py new file mode 100644 index 0000000000000000000000000000000000000000..cd0c60833490830036be391fbe127bb2d5136f8f --- /dev/null +++ b/core/geometry/rep_3d/extract_texture_map.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import numpy as np +import nvdiffrast.torch as dr + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/core/geometry/rep_3d/flexicubes.py b/core/geometry/rep_3d/flexicubes.py new file mode 100644 index 0000000000000000000000000000000000000000..5d7f9f9c6bdc5449f5cb57b0c22ea27d025ffd33 --- /dev/null +++ b/core/geometry/rep_3d/flexicubes.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from .tables import * + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + """ + This class implements the FlexiCubes method for extracting meshes from scalar fields. + It maintains a series of lookup tables and indices to support the mesh extraction process. + FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances + the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting + the surface representation through gradient-based optimization. + + During instantiation, the class loads DMC tables from a file and transforms them into + PyTorch tensors on the specified device. + + Attributes: + device (str): Specifies the computational device (default is "cuda"). + dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges + associated with each dual vertex in 256 Marching Cubes (MC) configurations. + num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of + the 256 MC configurations. + check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 + of the DMC configurations. + tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. + quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles + along one diagonal. + quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into + two triangles along the other diagonal. + quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles + during training by connecting all edges to their midpoints. + cube_corners (torch.Tensor): Defines the positions of a standard unit cube's + eight corners in 3D space, ordered starting from the origin (0,0,0), + moving along the x-axis, then y-axis, and finally z-axis. + Used as a blueprint for generating a voxel grid. + cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used + to retrieve the case id. + cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. + Used to retrieve edge vertices in DMC. + edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with + their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the + first edge is oriented along the x-axis. + dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges + across four adjacent cubes to the shared faces of these cubes. For instance, + dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along + the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. + This tensor is only utilized during isosurface tetrahedralization. + adj_pairs (torch.Tensor): + A tensor containing index pairs that correspond to neighboring cubes that share the same edge. + qef_reg_scale (float): + The scaling factor applied to the regularization loss to prevent issues with singularity + when solving the QEF. This parameter is only used when a 'grad_func' is specified. + weight_scale (float): + The scale of weights in FlexiCubes. Should be between 0 and 1. + """ + + def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + self.qef_reg_scale = qef_reg_scale + self.weight_scale = weight_scale + + def construct_voxel_grid(self, res): + """ + Generates a voxel grid based on the specified resolution. + + Args: + res (int or list[int]): The resolution of the voxel grid. If an integer + is provided, it is used for all three dimensions. If a list or tuple + of 3 integers is provided, they define the resolution for the x, + y, and z dimensions respectively. + + Returns: + (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the + cube corners (index into vertices) of the constructed voxel grid. + The vertices are centered at the origin, with the length of each + dimension in the grid being one. + """ + base_cube_f = torch.arange(8).to(self.device) + if isinstance(res, int): + res = (res, res, res) + voxel_grid_template = torch.ones(res, device=self.device) + + res = torch.tensor([res], dtype=torch.float, device=self.device) + coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 + verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) + cubes = (base_cube_f.unsqueeze(0) + + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) + + verts_rounded = torch.round(verts * 10**5) / (10**5) + verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) + cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) + + return verts_unique - 0.5, cubes + + def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, + gamma_f=None, training=False, output_tetmesh=False, grad_func=None): + r""" + Main function for mesh extraction from scalar field using FlexiCubes. This function converts + discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, + to triangle or tetrahedral meshes using a differentiable operation as described in + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances + mesh quality and geometric fidelity by adjusting the surface representation based on gradient + optimization. The output surface is differentiable with respect to the input vertex positions, + scalar field values, and weight parameters. + + If you intend to extract a surface mesh from a fixed Signed Distance Field without the + optimization of parameters, it is suggested to provide the "grad_func" which should + return the surface gradient at any given 3D position. When grad_func is provided, the process + to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as + described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. + Please note, this approach is non-differentiable. + + For more details and example usage in optimization, refer to the + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. + + Args: + x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. + s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values + denote that the corresponding vertex resides inside the isosurface. This affects + the directions of the extracted triangle faces and volume to be tetrahedralized. + cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. + res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it + is used for all three dimensions. If a list or tuple of 3 integers is provided, they + specify the resolution for the x, y, and z dimensions respectively. + beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual + vertices positioning. Defaults to uniform value for all edges. + alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual + vertices positioning. Defaults to uniform value for all vertices. + gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of + quadrilaterals into triangles. Defaults to uniform value for all cubes. + training (bool, optional): If set to True, applies differentiable quad splitting for + training. Defaults to False. + output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, + outputs a triangular mesh. Defaults to False. + grad_func (callable, optional): A function to compute the surface gradient at specified + 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 + tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. + + Returns: + (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: + - Vertices for the extracted triangular/tetrahedral mesh. + - Faces for the extracted triangular/tetrahedral mesh. + - Regularizer L_dev, computed per dual vertex. + + .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: + https://research.nvidia.com/labs/toronto-ai/flexicubes/ + .. _Manifold Dual Contouring: + https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf + """ + + surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) + if surf_cubes.sum() == 0: + return torch.zeros( + (0, 3), + device=self.device), torch.zeros( + (0, 4), + dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros( + (0, 3), + dtype=torch.long, device=self.device), torch.zeros( + (0), + device=self.device) + beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, res) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) + + vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( + x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) + vertices, faces, s_edges, edge_indices = self._triangulate( + s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) + if not output_tetmesh: + return vertices, faces, L_dev + else: + vertices, tets = self._tetrahedralize( + x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training) + return vertices, tets, L_dev + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta_fx12 is not None: + beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) + else: + beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha_fx8 is not None: + alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) + else: + alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = s_n < 0 + all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, s_n, cube_fx8): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. + """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + if grad_func is not None: + normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) + vd = [] + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + + if grad_func is not None: + with torch.no_grad(): + cube_e_verts_idx = idx_map[cur_cubes] + curr_edge_group[~curr_mask] = 0 + + verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) + verts_group_idx[verts_group_idx == -1] = 0 + verts_group_pos = torch.index_select( + input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) + v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1) + curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) + verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) + + normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( + -1, num.item(), 7, + 3) + curr_mask = curr_mask.squeeze(2) + vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, + verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + + if grad_func is not None: + vd = torch.cat(vd) + L_dev = torch.zeros([1], device=self.device) + else: + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map + + def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + if grad_func is not None: + # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. + with torch.no_grad(): + vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) + gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) + else: + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( + 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) + gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( + 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) + if not training: + mask = (gamma_02 > gamma_13).squeeze(1) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 + vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / + weight_sum.unsqueeze(-1)).squeeze(1) + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices + + def _tetrahedralize( + self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training): + """ + Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + occ_sum = torch.sum(occ_fx8, -1) + + inside_verts = x_nx3[occ_n] + mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] + """ + For each grid edge connecting two grid vertices with different + signs, we first form a four-sided pyramid by connecting one + of the grid vertices with four mesh vertices that correspond + to the grid edge and then subdivide the pyramid into two tetrahedra + """ + inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ + s_edges < 0]] + if not training: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) + else: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) + + tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) + """ + For each grid edge connecting two grid vertices with the + same sign, the tetrahedron is formed by the two grid vertices + and two vertices in consecutive adjacent cells + """ + inside_cubes = (occ_sum == 8) + inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) + inside_cubes_center_idx = torch.arange( + inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] + + surface_n_inside_cubes = surf_cubes | inside_cubes + edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), + dtype=torch.long, device=x_nx3.device) * -1 + surf_cubes = surf_cubes[surface_n_inside_cubes] + inside_cubes = inside_cubes[surface_n_inside_cubes] + edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) + edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx + + all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 + mask = mask_edges[_idx_map] + counts = counts[_idx_map] + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) + idx_map = mapping[_idx_map] + + group_mask = (counts == 4) & mask + group = idx_map.reshape(-1)[group_mask] + edge_indices, indices = torch.sort(group) + cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, + device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] + edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( + 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] + # Identify the face shared by the adjacent cells. + cube_idx_4 = cube_idx[indices].reshape(-1, 4) + edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] + shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) + cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) + # Identify an edge of the face with different signs and + # select the mesh vertex corresponding to the identified edge. + case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 + case_ids_expand[surf_cubes] = case_ids + cases = case_ids_expand[cube_idx_4x2] + quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) + mask = (quad_edge == -1).sum(-1) == 0 + inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) + tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] + + tets = torch.cat([tets_surface, tets_inside]) + vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) + return vertices, tets diff --git a/core/geometry/rep_3d/flexicubes_geometry.py b/core/geometry/rep_3d/flexicubes_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..702b9e257f6e2393ae226b9cc442e56fa339d232 --- /dev/null +++ b/core/geometry/rep_3d/flexicubes_geometry.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +from . import Geometry +from .flexicubes import FlexiCubes # replace later +from .dmtet import sdf_reg_loss_batch +import torch.nn.functional as F + +def get_center_boundary_index(grid_res, device): + v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True + center_indices = torch.nonzero(v.reshape(-1)) + + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False + v[:2, ...] = True + v[-2:, ...] = True + v[:, :2, ...] = True + v[:, -2:, ...] = True + v[:, :, :2] = True + v[:, :, -2:] = True + boundary_indices = torch.nonzero(v.reshape(-1)) + return center_indices, boundary_indices + +############################################################################### +# Geometry interface +############################################################################### +class FlexiCubesGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(FlexiCubesGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + self.fc = FlexiCubes(device, weight_scale=0.5) + self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + + all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) + self.all_edges = torch.unique(all_edges, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): + if indices is None: + indices = self.indices + + verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, + beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], + gamma_f=weight_n[:, 20], training=is_training + ) + return verts, faces, v_reg_loss + + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + return_value['normal'] = normal + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/core/geometry/rep_3d/tables.py b/core/geometry/rep_3d/tables.py new file mode 100644 index 0000000000000000000000000000000000000000..936a4bc5e2f95891f72651f2c42272e01a3a2bc3 --- /dev/null +++ b/core/geometry/rep_3d/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] diff --git a/core/instant_utils/__init__.py b/core/instant_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/core/instant_utils/camera_util.py b/core/instant_utils/camera_util.py new file mode 100644 index 0000000000000000000000000000000000000000..046cbd7159776df8ca2494818baa72fd955b36c4 --- /dev/null +++ b/core/instant_utils/camera_util.py @@ -0,0 +1,111 @@ +import torch +import torch.nn.functional as F +import numpy as np + + +def pad_camera_extrinsics_4x4(extrinsics): + if extrinsics.shape[-2] == 4: + return extrinsics + padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics) + if extrinsics.ndim == 3: + padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1) + extrinsics = torch.cat([extrinsics, padding], dim=-2) + return extrinsics + + +def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None): + """ + Create OpenGL camera extrinsics from camera locations and look-at position. + + camera_position: (M, 3) or (3,) + look_at: (3) + up_world: (3) + return: (M, 3, 4) or (3, 4) + """ + # by default, looking at the origin and world up is z-axis + if look_at is None: + look_at = torch.tensor([0, 0, 0], dtype=torch.float32) + if up_world is None: + up_world = torch.tensor([0, 0, 1], dtype=torch.float32) + if camera_position.ndim == 2: + look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1) + up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1) + + # OpenGL camera: z-backward, x-right, y-up + z_axis = camera_position - look_at + z_axis = F.normalize(z_axis, dim=-1).float() + x_axis = torch.linalg.cross(up_world, z_axis, dim=-1) + x_axis = F.normalize(x_axis, dim=-1).float() + y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1) + y_axis = F.normalize(y_axis, dim=-1).float() + + extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1) + extrinsics = pad_camera_extrinsics_4x4(extrinsics) + return extrinsics + + +def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5): + azimuths = np.deg2rad(azimuths) + elevations = np.deg2rad(elevations) + + xs = radius * np.cos(elevations) * np.cos(azimuths) + ys = radius * np.cos(elevations) * np.sin(azimuths) + zs = radius * np.sin(elevations) + + cam_locations = np.stack([xs, ys, zs], axis=-1) + cam_locations = torch.from_numpy(cam_locations).float() + + c2ws = center_looking_at_camera_pose(cam_locations) + return c2ws + + +def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0): + # M: number of circular views + # radius: camera dist to center + # elevation: elevation degrees of the camera + # return: (M, 4, 4) + assert M > 0 and radius > 0 + + elevation = np.deg2rad(elevation) + + camera_positions = [] + for i in range(M): + azimuth = 2 * np.pi * i / M + x = radius * np.cos(elevation) * np.cos(azimuth) + y = radius * np.cos(elevation) * np.sin(azimuth) + z = radius * np.sin(elevation) + camera_positions.append([x, y, z]) + camera_positions = np.array(camera_positions) + camera_positions = torch.from_numpy(camera_positions).float() + extrinsics = center_looking_at_camera_pose(camera_positions) + return extrinsics + + +def FOV_to_intrinsics(fov, device='cpu'): + """ + Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. + Note the intrinsics are returned as normalized by image size, rather than in pixel units. + Assumes principal point is at image center. + """ + focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + return intrinsics + + +def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): + """ + Get the input camera parameters. + """ + azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) + elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) + + c2ws = spherical_camera_pose(azimuths, elevations, radius) + c2ws = c2ws.float().flatten(-2) + + Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2) + + extrinsics = c2ws[:, :12] + intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + + return cameras.unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/core/instant_utils/infer_util.py b/core/instant_utils/infer_util.py new file mode 100644 index 0000000000000000000000000000000000000000..18b28a9453da113d679d15cc5149ac0330a97c65 --- /dev/null +++ b/core/instant_utils/infer_util.py @@ -0,0 +1,97 @@ +import os +import imageio +import rembg +import torch +import numpy as np +import PIL.Image +from PIL import Image +from typing import Any + + +def remove_background(image: PIL.Image.Image, + rembg_session: Any = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + + +def resize_foreground( + image: PIL.Image.Image, + ratio: float, +) -> PIL.Image.Image: + image = np.array(image) + assert image.shape[-1] == 4 + alpha = np.where(image[..., 3] > 0) + y1, y2, x1, x2 = ( + alpha[0].min(), + alpha[0].max(), + alpha[1].min(), + alpha[1].max(), + ) + # crop the foreground + fg = image[y1:y2, x1:x2] + # pad to square + size = max(fg.shape[0], fg.shape[1]) + ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 + ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 + new_image = np.pad( + fg, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + + # compute padding according to the ratio + new_size = int(new_image.shape[0] / ratio) + # pad to size, double side + ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 + ph1, pw1 = new_size - size - ph0, new_size - size - pw0 + new_image = np.pad( + new_image, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + new_image = PIL.Image.fromarray(new_image) + return new_image + + +def images_to_video( + images: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + video_dir = os.path.dirname(output_path) + video_name = os.path.basename(output_path) + os.makedirs(video_dir, exist_ok=True) + + frames = [] + for i in range(len(images)): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ + f"Frame shape mismatch: {frame.shape} vs {images.shape}" + assert frame.min() >= 0 and frame.max() <= 255, \ + f"Frame value out of range: {frame.min()} ~ {frame.max()}" + frames.append(frame) + imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) + + +def save_video( + frames: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames] + writer = imageio.get_writer(output_path, fps=fps) + for frame in frames: + writer.append_data(frame) + writer.close() \ No newline at end of file diff --git a/core/instant_utils/mesh_util.py b/core/instant_utils/mesh_util.py new file mode 100644 index 0000000000000000000000000000000000000000..75338bd43d7a30e3463fb6741de069fdff46af70 --- /dev/null +++ b/core/instant_utils/mesh_util.py @@ -0,0 +1,181 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import trimesh +import cv2 +import numpy as np +import nvdiffrast.torch as dr +from PIL import Image + + +def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath): + + pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + #facenp_fx3 = facenp_fx3[:, [2, 1, 0]] + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, 'obj') + + +def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath): + + pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]]) + + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fpath, 'glb') + + +def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): + import os + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = '%s/%s.mtl' % (fol, na) + fid = open(matname, 'w') + fid.write('newmtl material_0\n') + fid.write('Kd 1 1 1\n') + fid.write('Ka 0 0 0\n') + fid.write('Ks 0.4 0.4 0.4\n') + fid.write('Ns 10\n') + fid.write('illum 2\n') + fid.write('map_Kd %s.png\n' % na) + fid.close() + #### + + fid = open(fname, 'w') + fid.write('mtllib %s.mtl\n' % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write('vt %f %f\n' % (pp[0], pp[1])) + + fid.write('usemtl material_0\n') + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + # save texture map + lo, hi = 0, 1 + img = np.asarray(texmap_hxwx3, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = img.clip(0, 255) + mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) + mask = (mask <= 3.0).astype(np.float32) + kernel = np.ones((3, 3), 'uint8') + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') + + +def loadobj(meshfile): + v = [] + f = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if len(data) != 4: + continue + if data[0] == 'v': + v.append([float(d) for d in data[1:]]) + if data[0] == 'f': + data = [da.split('/')[0] for da in data] + f.append([int(d) for d in data[1:]]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + return pointnp_px3, facenp_fx3 + + +def loadobjtex(meshfile): + v = [] + vt = [] + f = [] + ft = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): + continue + if data[0] == 'v': + assert len(data) == 4 + + v.append([float(d) for d in data[1:]]) + if data[0] == 'vt': + if len(data) == 3 or len(data) == 4: + vt.append([float(d) for d in data[1:3]]) + if data[0] == 'f': + data = [da.split('/') for da in data] + if len(data) == 4: + f.append([int(d[0]) for d in data[1:]]) + ft.append([int(d[1]) for d in data[1:]]) + elif len(data) == 5: + idx1 = [1, 2, 3] + data1 = [data[i] for i in idx1] + f.append([int(d[0]) for d in data1]) + ft.append([int(d[1]) for d in data1]) + idx2 = [1, 3, 4] + data2 = [data[i] for i in idx2] + f.append([int(d[0]) for d in data2]) + ft.append([int(d[1]) for d in data2]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + uvs = np.array(vt, dtype=np.float32) + return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/core/instant_utils/train_util.py b/core/instant_utils/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..350349aa66cd25b4177238f96ac779d4b56faaa6 --- /dev/null +++ b/core/instant_utils/train_util.py @@ -0,0 +1,26 @@ +import importlib + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) diff --git a/core/lrm_reconstructor.py b/core/lrm_reconstructor.py new file mode 100644 index 0000000000000000000000000000000000000000..bc51c22e9ed69a2230b00dc1c10e2ea72cd64c3a --- /dev/null +++ b/core/lrm_reconstructor.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +from typing import Tuple, Literal +from functools import partial + +import itertools + + +# LRM +from .embedder import CameraEmbedder +from .transformer import TransformerDecoder +# from accelerate.logging import get_logger + +# logger = get_logger(__name__) + + +class LRM_VSD_Mesh_Net(nn.Module): + """ + predict VSD using transformer + """ + def __init__(self, camera_embed_dim: int, + transformer_dim: int, transformer_layers: int, transformer_heads: int, + triplane_low_res: int, triplane_high_res: int, triplane_dim: int, + encoder_freeze: bool = True, encoder_type: str = 'dino', + encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768, app_dim = 27, density_dim = 8, app_n_comp=24, + density_n_comp=8): + super().__init__() + + # attributes + self.encoder_feat_dim = encoder_feat_dim + self.camera_embed_dim = camera_embed_dim + self.triplane_low_res = triplane_low_res + self.triplane_high_res = triplane_high_res + self.triplane_dim = triplane_dim + self.transformer_dim=transformer_dim + + # modules + self.encoder = self._encoder_fn(encoder_type)( + model_name=encoder_model_name, + modulation_dim=self.camera_embed_dim, #mod camera vector + freeze=encoder_freeze, + ) + self.camera_embedder = CameraEmbedder( + raw_dim=12+4, embed_dim=camera_embed_dim, + ) + + self.n_comp=app_n_comp+density_n_comp + self.app_dim=app_dim + self.density_dim=density_dim + self.app_n_comp=app_n_comp + self.density_n_comp=density_n_comp + + self.pos_embed = nn.Parameter(torch.randn(1, 3*(triplane_low_res**2)+3*triplane_low_res, transformer_dim) * (1. / transformer_dim) ** 0.5) + self.transformer = TransformerDecoder( + block_type='cond', + num_layers=transformer_layers, num_heads=transformer_heads, + inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=None, + ) + # for plane + self.upsampler = nn.ConvTranspose2d(transformer_dim, self.n_comp, kernel_size=2, stride=2, padding=0) + self.dim_map = nn.Linear(transformer_dim,self.n_comp) + self.up_line = nn.Linear(triplane_low_res,triplane_low_res*2) + + + @staticmethod + def _encoder_fn(encoder_type: str): + encoder_type = encoder_type.lower() + assert encoder_type in ['dino', 'dinov2'], "Unsupported encoder type" + if encoder_type == 'dino': + from .encoders.dino_wrapper import DinoWrapper + #logger.info("Using DINO as the encoder") + return DinoWrapper + elif encoder_type == 'dinov2': + from .encoders.dinov2_wrapper import Dinov2Wrapper + #logger.info("Using DINOv2 as the encoder") + return Dinov2Wrapper + + def forward_transformer(self, image_feats, camera_embeddings=None): + N = image_feats.shape[0] + x = self.pos_embed.repeat(N, 1, 1) # [N, L, D] + x = self.transformer( + x, + cond=image_feats, + mod=camera_embeddings, + ) + return x + def reshape_upsample(self, tokens): + #B,_,3*ncomp + N = tokens.shape[0] + H = W = self.triplane_low_res + P=self.n_comp + + offset=3*H*W + + # planes + plane_tokens= tokens[:,:3*H*W,:].view(N,H,W,3,self.transformer_dim) + plane_tokens = torch.einsum('nhwip->inphw', plane_tokens) # [3, N, P, H, W] + plane_tokens = plane_tokens.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] + plane_tokens = self.upsampler(plane_tokens) # [3*N, P, H', W'] + plane_tokens = plane_tokens.view(3, N, *plane_tokens.shape[-3:]) # [3, N, P, H', W'] + plane_tokens = torch.einsum('inphw->niphw', plane_tokens) # [N, 3, P, H', W'] + plane_tokens = plane_tokens.reshape(N, 3*P, *plane_tokens.shape[-2:]) # # [N, 3*P, H', W'] + plane_tokens = plane_tokens.contiguous() + + #lines + line_tokens= tokens[:,3*H*W:3*H*W+3*H,:].view(N,H,3,self.transformer_dim) + line_tokens= self.dim_map(line_tokens) + line_tokens = torch.einsum('nhip->npih', line_tokens) # [ N, P, 3, H] + line_tokens=self.up_line(line_tokens) + line_tokens = torch.einsum('npih->niph', line_tokens) # [ N, 3, P, H] + line_tokens=line_tokens.reshape(N,3*P,line_tokens.shape[-1],1) + line_tokens = line_tokens.contiguous() + + mat_tokens=None + + d_mat_tokens=None + + return plane_tokens[:,:self.app_n_comp*3,:,:],line_tokens[:,:self.app_n_comp*3,:,:],mat_tokens,d_mat_tokens,plane_tokens[:,self.app_n_comp*3:,:,:],line_tokens[:,self.app_n_comp*3:,:,:] + + def forward_planes(self, image, camera): + # image: [N, V, C_img, H_img, W_img] + # camera: [N,V, D_cam_raw] + N,V,_,H,W = image.shape + image=image.reshape(N*V,3,H,W) + camera=camera.reshape(N*V,-1) + + + # embed camera + camera_embeddings = self.camera_embedder(camera) + assert camera_embeddings.shape[-1] == self.camera_embed_dim, \ + f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}" + + # encode image + image_feats = self.encoder(image, camera_embeddings) + assert image_feats.shape[-1] == self.encoder_feat_dim, \ + f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}" + + image_feats=image_feats.reshape(N,V*image_feats.shape[-2],image_feats.shape[-1]) + + # transformer generating planes + tokens = self.forward_transformer(image_feats) + + app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines = self.reshape_upsample(tokens) + + return app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines + + def forward(self, image,source_camera): + # image: [N,V, C_img, H_img, W_img] + # source_camera: [N, V, D_cam_raw] + + assert image.shape[0] == source_camera.shape[0], "Batch size mismatch for image and source_camera" + planes = self.forward_planes(image, source_camera) + + #B,3,dim,H,W + return planes \ No newline at end of file diff --git a/core/models.py b/core/models.py new file mode 100644 index 0000000000000000000000000000000000000000..0b57075ad554c5f199829d61cfde39c3e46d8fc6 --- /dev/null +++ b/core/models.py @@ -0,0 +1,783 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import mcubes + +import kiui +from kiui.lpips import LPIPS + +from core.lrm_reconstructor import LRM_VSD_Mesh_Net +from core.options import Options +from core.tensoRF import TensorVMSplit_Mesh,TensorVMSplit_NeRF +from torchvision.transforms import v2 +from core.geometry.camera.perspective_camera import PerspectiveCamera +from core.geometry.render.neural_render import NeuralRender +from core.geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry +import nvdiffrast.torch as dr +from core.instant_utils.mesh_util import xatlas_uvmap + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +#tensorSDF + transformer + volume_rendering +class LTRFM_NeRF(nn.Module): + def __init__( + self, + opt: Options, + ): + super().__init__() + + self.opt = opt + + + #predict svd using transformer + self.vsd_net = LRM_VSD_Mesh_Net( + camera_embed_dim=opt.camera_embed_dim, + transformer_dim=opt.transformer_dim, + transformer_layers=opt.transformer_layers, + transformer_heads=opt.transformer_heads, + triplane_low_res=opt.triplane_low_res, + triplane_high_res=opt.triplane_high_res, + triplane_dim=opt.triplane_dim, + encoder_freeze=opt.encoder_freeze, + encoder_type=opt.encoder_type, + encoder_model_name=opt.encoder_model_name, + encoder_feat_dim=opt.encoder_feat_dim, + app_dim=opt.app_dim, + density_dim=opt.density_dim, + app_n_comp=opt.app_n_comp, + density_n_comp=opt.density_n_comp, + ) + + aabb = torch.tensor([[-1, -1, -1], [1, 1, 1]]).to(device) + grid_size = torch.tensor([opt.splat_size, opt.splat_size, opt.splat_size]).to(device) + near_far =torch.tensor([opt.znear, opt.zfar]).to(device) + # tensorf Renderer + self.tensorRF = TensorVMSplit_NeRF(aabb, grid_size, density_n_comp=opt.density_n_comp,appearance_n_comp=opt.app_n_comp,app_dim=opt.app_dim,\ + density_dim=opt.density_dim,near_far=near_far, shadingMode=opt.shadingMode, pos_pe=opt.pos_pe, view_pe=opt.view_pe, fea_pe=opt.fea_pe) + + # LPIPS loss + if self.opt.lambda_lpips > 0: + self.lpips_loss = LPIPS(net='vgg') + self.lpips_loss.requires_grad_(False) + + + def state_dict(self, **kwargs): + # remove lpips_loss + state_dict = super().state_dict(**kwargs) + for k in list(state_dict.keys()): + if 'lpips_loss' in k: + del state_dict[k] + return state_dict + + def set_beta(self,t): + self.tensorRF.lap_density.set_beta(t) + + + + # predict svd_volume + def forward_svd_volume(self, images, data): + # images: [B, 4, 9, H, W] + # return: Gaussians: [B, dim_t] + B, V, C, H, W = images.shape + + + source_camera=data['source_camera'] + images_vit=data['input_vit'] # for transformer + source_camera=source_camera.reshape(B,V,-1) # [B*V, 16] + app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines = self.vsd_net(images_vit,source_camera) + + + app_planes=app_planes.view(B,3,self.opt.app_n_comp,self.opt.splat_size,self.opt.splat_size) + app_lines=app_lines.view(B,3,self.opt.app_n_comp,self.opt.splat_size,1) + density_planes=density_planes.view(B,3,self.opt.density_n_comp,self.opt.splat_size,self.opt.splat_size) + density_lines=density_lines.view(B,3,self.opt.density_n_comp,self.opt.splat_size,1) + + results = { + 'app_planes': app_planes, + 'app_lines': app_lines, + 'basis_mat':basis_mat, + 'd_basis_mat':d_basis_mat, + 'density_planes':density_planes, + 'density_lines':density_lines + } + + return results + + def extract_mesh(self, + planes: torch.Tensor, + mesh_resolution: int = 256, + mesh_threshold: int = 0.009, + use_texture_map: bool = False, + texture_resolution: int = 1024,): + + device = planes['app_planes'].device + + grid_size = mesh_resolution + points = torch.linspace(-1, 1, steps=grid_size).half() + + x, y, z = torch.meshgrid(points, points, points) + + xyz_samples = torch.stack((x, y, z), dim=0).unsqueeze(0).to(device) + xyz_samples=xyz_samples.permute(0,2,3,4,1) + xyz_samples=xyz_samples.view(1,-1,1,3) + + + grid_out = self.tensorRF.predict_sdf(planes,xyz_samples) + grid_out['sigma']=grid_out['sigma'].view(grid_size,grid_size,grid_size).float() + + vertices, faces = mcubes.marching_cubes( + grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), + mesh_threshold, + ) + vertices = vertices / (mesh_resolution - 1) * 2 - 1 + + if not use_texture_map: + # query vertex colors + vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) + rgb_colors = self.tensorRF.predict_color( + planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy() + rgb_colors = (rgb_colors * 255).astype(np.uint8) + + albedob_colors = self.tensorRF.predict_color( + planes, vertices_tensor)['albedo'].squeeze(0).cpu().numpy() + albedob_colors = (albedob_colors * 255).astype(np.uint8) + + shading_colors = self.tensorRF.predict_color( + planes, vertices_tensor)['shading'].squeeze(0).cpu().numpy() + shading_colors = (shading_colors * 255).astype(np.uint8) + + return vertices, faces, [rgb_colors,albedob_colors,shading_colors] + + # use x-atlas to get uv mapping for the mesh + vertices = torch.tensor(vertices, dtype=torch.float32, device=device) + faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device) + + ctx = dr.RasterizeCudaContext(device=device) + uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( + ctx, vertices, faces, resolution=texture_resolution) + tex_hard_mask = tex_hard_mask.float().cpu() + + # query the texture field to get the RGB color for texture map + #TBD here + query_vertices=gb_pos.view(1,texture_resolution*texture_resolution,3) + + vertices_colors = self.tensorRF.predict_color( + planes, query_vertices)['rgb'].squeeze(0).cpu() + + vertices_colors=vertices_colors.reshape(1,texture_resolution,texture_resolution,3) + + background_feature = torch.zeros_like(vertices_colors) + img_feat = torch.lerp(background_feature, vertices_colors, tex_hard_mask.half()) + texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) + #albedo + vertices_colors_albedo = self.tensorRF.predict_color( + planes, query_vertices)['albedo'].squeeze(0).cpu() + + vertices_colors_albedo=vertices_colors_albedo.reshape(1,texture_resolution,texture_resolution,3) + + background_feature = torch.zeros_like(vertices_colors_albedo) + img_feat = torch.lerp(background_feature, vertices_colors_albedo, tex_hard_mask.half()) + texture_map_albedo = img_feat.permute(0, 3, 1, 2).squeeze(0) + + return vertices, faces, uvs, mesh_tex_idx, [texture_map,texture_map_albedo] + + + def forward(self, data, step_ratio=1): + # data: output of the dataloader + # return: loss + #self.set_beta(data['t']) + results = {} + loss = 0 + + images = data['input'] # [B, 4, 9, h, W], input features + + # use the first view to predict gaussians + svd_volume = self.forward_svd_volume(images,data) # [B, N, 14] + + results['svd_volume'] = svd_volume + + # always use white bg + bg_color = torch.ones(3, dtype=torch.float32).to(device) + + # use the other views for rendering and supervision + results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'],is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample) + pred_shading = results['image'] # [B, V, C, output_size, output_size] + pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size] + pred_albedos = results['albedo'] # [B, V, C, output_size, output_size] + + pred_images = pred_shading*pred_albedos + + results['images_pred'] = pred_images + results['alphas_pred'] = pred_alphas + results['pred_albedos'] = pred_albedos + results['pred_shading'] = pred_shading + + + gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views + gt_albedos = data['albedos_output'] # [B, V, 3, output_size, output_size], ground-truth novel views + gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks + + gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) + gt_albedos = gt_albedos * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) + + loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) + F.mse_loss(pred_albedos, gt_albedos) + loss = loss + loss_mse + + # eikonal_loss = ((results['eik_grads'].norm(2, dim=1) - 1) ** 2).mean() + # loss = loss+ 0.1*eikonal_loss + + if self.opt.lambda_lpips > 0: + loss_lpips = self.lpips_loss( + F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), + F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), + ).mean() + results['loss_lpips'] = loss_lpips + loss = loss + self.opt.lambda_lpips * loss_lpips + + results['loss'] = loss + + # metric + with torch.no_grad(): + psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) + results['psnr'] = psnr + + return results + + + def render_frame(self, data): + # data: output of the dataloader + # return: loss + #self.set_beta(data['t']) + results = {} + loss = 0 + + images = data['input_vit'] + + # use the first view to predict gaussians + svd_volume = self.forward_svd_volume(images,data) # [B, N, 14] + + results['svd_volume'] = svd_volume + + # always use white bg + bg_color = torch.ones(3, dtype=torch.float32).to(device) + + # use the other views for rendering and supervision + results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'],is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample) + pred_shading = results['image'] # [B, V, C, output_size, output_size] + pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size] + pred_albedos = results['albedo'] # [B, V, C, output_size, output_size] + + pred_images = pred_shading*pred_albedos + + results['images_pred'] = pred_images + results['alphas_pred'] = pred_alphas + results['pred_albedos'] = pred_albedos + results['pred_shading'] = pred_shading + + + return results + + + + + +#tensorSDF + transformer + SDF + Mesh +class LTRFM_Mesh(nn.Module): + def __init__( + self, + opt: Options, + ): + super().__init__() + + self.opt = opt + + # attributes + self.grid_res = 128 #grid_res + self.grid_scale = 2.0 #grid_scale + self.deformation_multiplier = 4.0 + + + self.init_flexicubes_geometry(device, self.opt) + + #predict svd using transformer + self.vsd_net = LRM_VSD_Mesh_Net( + camera_embed_dim=opt.camera_embed_dim, + transformer_dim=opt.transformer_dim, + transformer_layers=opt.transformer_layers, + transformer_heads=opt.transformer_heads, + triplane_low_res=opt.triplane_low_res, + triplane_high_res=opt.triplane_high_res, + triplane_dim=opt.triplane_dim, + encoder_freeze=opt.encoder_freeze, + encoder_type=opt.encoder_type, + encoder_model_name=opt.encoder_model_name, + encoder_feat_dim=opt.encoder_feat_dim, + app_dim=opt.app_dim, + density_dim=opt.density_dim, + app_n_comp=opt.app_n_comp, + density_n_comp=opt.density_n_comp, + ) + + aabb = torch.tensor([[-1, -1, -1], [1, 1, 1]]).to(device) + grid_size = torch.tensor([opt.splat_size, opt.splat_size, opt.splat_size]).to(device) + near_far =torch.tensor([opt.znear, opt.zfar]).to(device) + # tensorf Renderer + self.tensorRF = TensorVMSplit_Mesh(aabb, grid_size, density_n_comp=opt.density_n_comp,appearance_n_comp=opt.app_n_comp,app_dim=opt.app_dim,\ + density_dim=opt.density_dim, near_far=near_far, shadingMode=opt.shadingMode, pos_pe=opt.pos_pe, view_pe=opt.view_pe, fea_pe=opt.fea_pe) + + # LPIPS loss + if self.opt.lambda_lpips > 0: + self.lpips_loss = LPIPS(net='vgg') + self.lpips_loss.requires_grad_(False) + + + # load ckpt + if opt.ckpt_nerf is not None: + sd = torch.load(opt.ckpt_nerf, map_location='cpu')['model'] + #sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')} + sd_fc = {} + for k, v in sd.items(): + k=k.replace('module.', '') + if k.startswith('vsd.renderModule.'): + continue + else: + sd_fc[k] = v + sd_fc = {k.replace('vsd_net.', ''): v for k, v in sd_fc.items()} + sd_fc = {k.replace('tensorRF.', ''): v for k, v in sd_fc.items()} + # missing `net_deformation` and `net_weight` parameters + self.vsd_net.load_state_dict(sd_fc, strict=False) + self.tensorRF.load_state_dict(sd_fc, strict=False) + print(f'Loaded weights from {opt.ckpt_nerf}') + + + def state_dict(self, **kwargs): + # remove lpips_loss + state_dict = super().state_dict(**kwargs) + for k in list(state_dict.keys()): + if 'lpips_loss' in k: + del state_dict[k] + return state_dict + + + # predict svd_volume + def forward_svd_volume(self, images, data): + # images: [B, 4, 9, H, W] + # return: Gaussians: [B, dim_t] + B, V, C, H, W = images.shape + + source_camera=data['source_camera'] + images_vit=data['input_vit'] # for transformer + source_camera=source_camera.reshape(B,V,-1) # [B*V, 16] + app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines = self.vsd_net(images_vit,source_camera) + + + app_planes=app_planes.view(B,3,self.opt.app_n_comp,self.opt.splat_size,self.opt.splat_size) + app_lines=app_lines.view(B,3,self.opt.app_n_comp,self.opt.splat_size,1) + density_planes=density_planes.view(B,3,self.opt.density_n_comp,self.opt.splat_size,self.opt.splat_size) + density_lines=density_lines.view(B,3,self.opt.density_n_comp,self.opt.splat_size,1) + + results = { + 'app_planes': app_planes, + 'app_lines': app_lines, + 'basis_mat':basis_mat, + 'd_basis_mat':d_basis_mat, + 'density_planes':density_planes, + 'density_lines':density_lines + } + + return results + + + def init_flexicubes_geometry(self, device, opt): + camera = PerspectiveCamera(opt, device=device) + renderer = NeuralRender(device, camera_model=camera) + self.geometry = FlexiCubesGeometry( + grid_res=self.grid_res, + scale=self.grid_scale, + renderer=renderer, + render_type='neural_render', + device=device, + ) + + + # query vsd for sdf weight and ... + def get_sdf_deformation_prediction(self, planes): + ''' + Predict SDF and deformation for tetrahedron vertices + :param planes: triplane feature map for the geometry + ''' + B = planes['app_lines'].shape[0] + init_position = self.geometry.verts.unsqueeze(0).expand(B, -1, -1) + + + sdf, deformation, weight = self.tensorRF.get_geometry_prediction(planes,init_position,self.geometry.indices) + + deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation) + sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32) + + sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1)) + sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1) + pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1) + neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1) + zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0) + if torch.sum(zero_surface).item() > 0: + update_sdf = torch.zeros_like(sdf[0:1]) + max_sdf = sdf.max() + min_sdf = sdf.min() + update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero + update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero + new_sdf = torch.zeros_like(sdf) + for i_batch in range(zero_surface.shape[0]): + if zero_surface[i_batch]: + new_sdf[i_batch:i_batch + 1] += update_sdf + update_mask = (new_sdf == 0).float() + # Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative) + sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1) + sdf_reg_loss = sdf_reg_loss * zero_surface.float() + sdf = sdf * update_mask + new_sdf * (1 - update_mask) + + final_sdf = [] + final_def = [] + for i_batch in range(zero_surface.shape[0]): + if zero_surface[i_batch]: + final_sdf.append(sdf[i_batch: i_batch + 1].detach()) + final_def.append(deformation[i_batch: i_batch + 1].detach()) + else: + final_sdf.append(sdf[i_batch: i_batch + 1]) + final_def.append(deformation[i_batch: i_batch + 1]) + sdf = torch.cat(final_sdf, dim=0) + deformation = torch.cat(final_def, dim=0) + return sdf, deformation, sdf_reg_loss, weight + + def get_geometry_prediction(self, planes=None): + ''' + Function to generate mesh with give triplanes + :param planes: triplane features + ''' + + sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes) + + + v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation + tets = self.geometry.indices + n_batch = planes['app_planes'].shape[0] + v_list = [] + f_list = [] + flexicubes_surface_reg_list = [] + + + for i_batch in range(n_batch): + verts, faces, flexicubes_surface_reg = self.geometry.get_mesh( + v_deformed[i_batch], + sdf[i_batch].squeeze(dim=-1), + with_uv=False, + indices=tets, + weight_n=weight[i_batch].squeeze(dim=-1), + is_training=self.training, + ) + flexicubes_surface_reg_list.append(flexicubes_surface_reg) + v_list.append(verts) + f_list.append(faces) + + flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean() + flexicubes_weight_reg = (weight ** 2).mean() + + return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg) + + def get_texture_prediction(self, planes, tex_pos, hard_mask=None): + ''' + Predict Texture given triplanes + :param planes: the triplane feature map + :param tex_pos: Position we want to query the texture field + :param hard_mask: 2D silhoueete of the rendered image + ''' + B = planes['app_planes'].shape[0] + tex_pos = torch.cat(tex_pos, dim=0) + if not hard_mask is None: + tex_pos = tex_pos * hard_mask.float() + batch_size = tex_pos.shape[0] + tex_pos = tex_pos.reshape(batch_size, -1, 3) + ################### + # We use mask to get the texture location (to save the memory) + if hard_mask is not None: + n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) + sample_tex_pose_list = [] + max_point = n_point_list.max() + if max_point==0: # xrg: hard mask may filter all points, and don not left any point + max_point=max_point+1 + expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 + for i in range(tex_pos.shape[0]): + tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) + if tex_pos_one_shape.shape[1] < max_point: + tex_pos_one_shape = torch.cat( + [tex_pos_one_shape, torch.zeros( + 1, max_point - tex_pos_one_shape.shape[1], 3, + device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) + sample_tex_pose_list.append(tex_pos_one_shape) + tex_pos = torch.cat(sample_tex_pose_list, dim=0) + + + #return texture rgb + tex_feat = self.tensorRF.get_texture_prediction(tex_pos,vsd_vome=planes) + + if hard_mask is not None: + final_tex_feat = torch.zeros( + B, hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) + expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 + for i in range(B): + final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) + tex_feat = final_tex_feat + + return tex_feat.reshape(B, hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) + + def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256): + ''' + Function to render a generated mesh with nvdiffrast + :param mesh_v: List of vertices for the mesh + :param mesh_f: List of faces for the mesh + :param cam_mv: 4x4 rotation matrix + :return: + ''' + return_value_list = [] + for i_mesh in range(len(mesh_v)): + return_value = self.geometry.render_mesh( + mesh_v[i_mesh], + mesh_f[i_mesh].int(), + cam_mv[i_mesh], + resolution=render_size, + hierarchical_mask=False + ) + return_value_list.append(return_value) + + return_keys = return_value_list[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in return_value_list] + return_value[k] = value + + mask = torch.cat(return_value['mask'], dim=0) + hard_mask = torch.cat(return_value['hard_mask'], dim=0) + tex_pos = return_value['tex_pos'] + depth = torch.cat(return_value['depth'], dim=0) + normal = torch.cat(return_value['normal'], dim=0) + return mask, hard_mask, tex_pos, depth, normal + + def forward_geometry(self, planes, render_cameras, render_size=256): + ''' + Main function of our Generator. It first generate 3D mesh, then render it into 2D image + with given `render_cameras`. + :param planes: triplane features + :param render_cameras: cameras to render generated 3D shape, a w2c matrix + ''' + B, NV = render_cameras.shape[:2] + + # Generate 3D mesh first + mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) + + # Render the mesh into 2D image (get 3d position of each image plane) continue for here + cam_mv = render_cameras + run_n_view = cam_mv.shape[1] + antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size) + + tex_hard_mask = hard_mask + tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos] + tex_hard_mask = torch.cat( + [torch.cat( + [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1] + for i_view in range(run_n_view)], dim=2) + for i in range(B)], dim=0) + + # Querying the texture field to predict the texture feature for each pixel on the image + tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask) + background_feature = torch.ones_like(tex_feat) # white background + + # Merge them together + img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) + + # We should split it back to the original image shape + img_feat = torch.cat( + [torch.cat( + [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)] + for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0) + + img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + + albedo=img[:,:,3:6,:,:] + img=img[:,:,0:3,:,:] + + antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive + normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + + out = { + 'image': img, + 'albedo': albedo, + 'mask': antilias_mask, + 'depth': depth, + 'normal': normal, + 'sdf': sdf, + 'mesh_v': mesh_v, + 'mesh_f': mesh_f, + 'sdf_reg_loss': sdf_reg_loss, + } + return out + + def forward(self, data, step_ratio=1): + # data: output of the dataloader + # return: loss + + results = {} + loss = 0 + + images = data['input'] # [B, 4, 9, h, W], input features + + # use the first view to predict gaussians + svd_volume = self.forward_svd_volume(images,data) # [B, N, 14] + + results['svd_volume'] = svd_volume + + # return the rendered images + results = self.forward_geometry(svd_volume, data['w2c'], self.opt.output_size) + + + # always use white bg + bg_color = torch.ones(3, dtype=torch.float32).to(device) + + # use the other views for rendering and supervision + #results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'],is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample) + + + pred_shading = results['image'] # [B, V, C, output_size, output_size] + pred_alphas = results['mask'] # [B, V, 1, output_size, output_size] + pred_albedos = results['albedo'] # [B, V, C, output_size, output_size] + + pred_images=pred_shading*pred_albedos + + results['images_pred'] = pred_images + results['alphas_pred'] = pred_alphas + results['pred_albedos'] = pred_albedos + results['pred_shading'] = pred_shading + + + gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views + gt_albedos = data['albedos_output'] # [B, V, 3, output_size, output_size], ground-truth novel views + gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks + + gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) + gt_albedos = gt_albedos * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) + + loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) + F.mse_loss(pred_albedos, gt_albedos) + loss = loss + loss_mse + + if self.opt.lambda_lpips > 0: + loss_lpips = self.lpips_loss( + F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), + F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), + ).mean() + results['loss_lpips'] = loss_lpips + loss = loss + self.opt.lambda_lpips * loss_lpips + + results['loss'] = loss + + # metric + with torch.no_grad(): + psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) + results['psnr'] = psnr + + return results + + + def render_frame(self, data): + # data: output of the dataloader + # return: loss + + results = {} + + images = data['input_vit'] # [B, 4, 9, h, W], input features + + # use the first view to predict gaussians + svd_volume = self.forward_svd_volume(images,data) # [B, N, 14] + + results['svd_volume'] = svd_volume + + # return the rendered images + results = self.forward_geometry(svd_volume, data['w2c'], self.opt.infer_render_size) + + + # always use white bg + bg_color = torch.ones(3, dtype=torch.float32).to(device) + + + pred_shading = results['image'] # [B, V, C, output_size, output_size] + pred_alphas = results['mask'] # [B, V, 1, output_size, output_size] + pred_albedos = results['albedo'] # [B, V, C, output_size, output_size] + + pred_images=pred_shading*pred_albedos + + results['images_pred'] = pred_images + results['alphas_pred'] = pred_alphas + results['pred_albedos'] = pred_albedos + results['pred_shading'] = pred_shading + + return results + + def extract_mesh( + self, + planes: torch.Tensor, + use_texture_map: bool = False, + texture_resolution: int = 1024, + **kwargs, + ): + ''' + Extract a 3D mesh from FlexiCubes. Only support batch_size 1. + :param planes: triplane features + :param use_texture_map: use texture map or vertex color + :param texture_resolution: the resolution of texure map + ''' + assert planes['app_planes'].shape[0] == 1 + device = planes['app_planes'].device + + + # predict geometry first + mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) + vertices, faces = mesh_v[0], mesh_f[0] + + if not use_texture_map: + # query vertex colors + vertices_tensor = vertices.unsqueeze(0) + rgb_colors = self.tensorRF.predict_color(planes, vertices_tensor)['rgb'].clamp(0, 1).squeeze(0).cpu().numpy() + rgb_colors = (rgb_colors * 255).astype(np.uint8) + + albedob_colors = self.tensorRF.predict_color(planes, vertices_tensor)['albedo'].clamp(0, 1).squeeze(0).cpu().numpy() + albedob_colors = (albedob_colors * 255).astype(np.uint8) + + shading_colors = self.tensorRF.predict_color(planes, vertices_tensor)['shading'].clamp(0, 1).squeeze(0).cpu().numpy() + shading_colors = (shading_colors * 255).astype(np.uint8) + + + return vertices.cpu().numpy(), faces.cpu().numpy(), [rgb_colors,albedob_colors,shading_colors] + + # use x-atlas to get uv mapping for the mesh + ctx = dr.RasterizeCudaContext(device=device) + uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( + self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution) + + tex_hard_mask = tex_hard_mask.float().cpu() + + # query the texture field to get the RGB color for texture map + #TBD here + query_vertices=gb_pos.view(1,texture_resolution*texture_resolution,3) + + vertices_colors = self.tensorRF.predict_color( + planes, query_vertices)['rgb'].squeeze(0).cpu() + + vertices_colors=vertices_colors.reshape(1,texture_resolution,texture_resolution,3) + + background_feature = torch.zeros_like(vertices_colors) + img_feat = torch.lerp(background_feature, vertices_colors, tex_hard_mask) + texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) + + return vertices, faces, uvs, mesh_tex_idx, texture_map + + diff --git a/core/modulate.py b/core/modulate.py new file mode 100644 index 0000000000000000000000000000000000000000..a15524d2de9c456575a054f29afaf780053c0f0f --- /dev/null +++ b/core/modulate.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class ModLN(nn.Module): + """ + Modulation with adaLN. + + References: + DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101 + """ + def __init__(self, inner_dim: int, mod_dim: int, eps: float): + super().__init__() + self.norm = nn.LayerNorm(inner_dim, eps=eps) + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(mod_dim, inner_dim * 2), + ) + + @staticmethod + def modulate(x, shift, scale): + # x: [N, L, D] + # shift, scale: [N, D] + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + shift, scale = self.mlp(mod).chunk(2, dim=-1) # [N, D] + return self.modulate(self.norm(x), shift, scale) # [N, L, D] diff --git a/core/options.py b/core/options.py new file mode 100644 index 0000000000000000000000000000000000000000..9023dd387fe1a5b561e893c45cf8b60373dfb981 --- /dev/null +++ b/core/options.py @@ -0,0 +1,232 @@ +import tyro +from dataclasses import dataclass +from typing import Tuple, Literal, Dict, Optional + + +@dataclass +class Options: + seed: Optional[int] = None + is_crop: bool = True + is_fix_views: bool = False + specific_demo: Optional[str] = None + txt_or_image: Optional[bool] = False #True=text prompts + infer_render_size: int = 256 + mvdream_or_zero123: Optional[bool] = True # True for mvdream False for zero123plus + #true for rar + rar_data: bool = True + ### model + # Unet image input size + input_size: int = 512 + # Unet definition + down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) + down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) + mid_attention: bool = True + up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) + up_attention: Tuple[bool, ...] = (True, True, True, False) + # Unet output size, dependent on the input_size and U-Net structure! + splat_size: int = 64 + # svd render size + output_size: Optional[int] = 128 + + #for tensor + density_n_comp: int = 8 + app_n_comp: int = 32 + app_dim: int = 27 + density_dim: int = 8 + shadingMode: Literal['MLP_Fea']='MLP_Fea' #'MLP_Fea' + view_pe: int = 2 + fea_pe: int = 2 + pos_pe: int = 6 + # points number sampled per ray + n_sample: int = 64 + + # model type TRF for vsd+nerf TRF_GS for vsd+gs TRI_GS for tri+gs + volume_mode: Literal['TRF_Mesh','TRF_NeRF'] = 'TRF_NeRF' + + + # for LRM_Net + camera_embed_dim: int=1024 + transformer_dim: int=1024 + transformer_layers: int=16 + transformer_heads: int=16 + triplane_low_res: int=32 + triplane_high_res: int=64 + triplane_dim: int=32 + encoder_type: str ='dinov2' + encoder_model_name: str = 'dinov2_vitb14_reg'#'dinov2_vits14_reg' #'dinov2_vitb14_reg' + encoder_feat_dim: int = 768 #768 + encoder_freeze: bool = False + + #training + over_fit: Optional[bool] = False + is_grid_sample: bool = False + + ### dataset + # data mode (only support s3 now) + data_mode: Literal['s3','s4','s5'] = 's4' + data_path: str = 'train_data' + data_debug_list: str = 'dataset_debug/gobj_merged_debug.json' + data_list_path: str = 'dataset_debug/gobj_merged_debug_selected.json' #dataset_debug/gobj_merged_debug.json' + # fovy of the dataset + fovy: float = 39.6 #49.1 + # camera near plane + znear: float = 0.5 + # camera far plane + zfar: float = 2.5 + # number of all views (input + output) + num_views: int = 12 + # number of views + num_input_views: int = 4 + # camera radius + cam_radius: float = 1.5 # to better use [-1, 1]^3 space + # num workers + num_workers: int = 8 #8 + # 是否考虑单个视角的view + training_view_plane: bool = False + is_certainty: bool = False + + ### training + # workspace + workspace: str = './workspace_test' + # resume + resume: Optional[str] = None + ckpt_nerf: Optional[str] = None + # batch size (per-GPU) + batch_size: int = 8 + # gradient accumulation + gradient_accumulation_steps: Optional[int] = 1 + # training epochs + num_epochs: int = 50 + # lpips loss weight + lambda_lpips: float = 2.0 + # gradient clip + gradient_clip: float = 1.0 + # mixed precision + mixed_precision: str = 'bf16' + # learning rate + lr: Optional[float] = 4e-4 + lr_scheduler: str = 'OneCycleLR' + warmup_real_iters: int = 3000 + + # augmentation prob for grid distortion + prob_grid_distortion: float = 0.5 + # augmentation prob for camera jitter + prob_cam_jitter: float = 0.5 + + ### testing + # test image path + test_path: Optional[str] = None + + ### misc + # nvdiffrast backend setting + force_cuda_rast: bool = False + # render fancy video with gaussian scaling effect + fancy_video: bool = False + + +# all the default settings +config_defaults: Dict[str, Options] = {} +config_doc: Dict[str, str] = {} + +config_doc['lrm'] = 'the default settings for LGM' +config_defaults['lrm'] = Options() + +config_doc['small'] = 'small model with lower resolution Gaussians' +config_defaults['small'] = Options( + input_size=256, + splat_size=64, + output_size=256, + batch_size=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +config_doc['big'] = 'big model with higher resolution Gaussians' +config_defaults['big'] = Options( + input_size=256, + up_channels=(1024, 1024, 512, 256, 128), # one more decoder + up_attention=(True, True, True, False, False), + splat_size=128, + output_size=512, # render & supervise Gaussians at a higher resolution. + batch_size=8, + num_views=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + + +config_doc['tiny_trf_trans_mesh'] = 'tiny model for ablation' +config_defaults['tiny_trf_trans_mesh'] = Options( + input_size=512, + down_channels=(32, 64, 128, 256, 512), + down_attention=(False, False, False, False, True), + up_channels=(512, 256, 128), + up_attention=(True, False, False, False), + volume_mode='TRF_Mesh', + # ckpt_nerf='workspace_debug/0428_02/last.ckpt', + splat_size=64, + output_size=512, + data_mode='s6', + batch_size=1, #8 + num_views=8, + gradient_accumulation_steps=1, #2 + mixed_precision='no', +) + +config_doc['tiny_trf_trans_nerf'] = 'tiny model for ablation' +config_defaults['tiny_trf_trans_nerf'] = Options( + input_size=512, + down_channels=(32, 64, 128, 256, 512), + down_attention=(False, False, False, False, True), + up_channels=(512, 256, 128), + up_attention=(True, False, False, False), + volume_mode='TRF_NeRF', + splat_size=64, + output_size=62, #crop patch + data_mode='s5', + batch_size=4, #8 + num_views=8, + gradient_accumulation_steps=1, #2 + mixed_precision='bf16', +) + +config_doc['tiny_trf_trans_nerf_123plus'] = 'tiny model for ablation' +config_defaults['tiny_trf_trans_nerf_123plus'] = Options( + input_size=512, + down_channels=(32, 64, 128, 256, 512), + down_attention=(False, False, False, False, True), + up_channels=(512, 256, 128), + up_attention=(True, False, False, False), + volume_mode='TRF_NeRF', + splat_size=64, + output_size=116, #crop patch + data_mode='s5', + mvdream_or_zero123=False, + batch_size=1, #8 + num_views=10, + num_input_views=6, + gradient_accumulation_steps=1, #2 + mixed_precision='bf16', +) + + +config_doc['tiny_trf_trans_nerf_nocrop'] = 'tiny model for ablation' +config_defaults['tiny_trf_trans_nerf_nocrop'] = Options( + input_size=512, + down_channels=(32, 64, 128, 256, 512), + down_attention=(False, False, False, False, True), + up_channels=(512, 256, 128), + up_attention=(True, False, False, False), + volume_mode='TRF_NeRF', + splat_size=64, + output_size=62, #crop patch + data_mode='s5', + batch_size=4, #8 + is_crop=False, + num_views=8, + gradient_accumulation_steps=1, #2 + mixed_precision='bf16', +) + + +AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) diff --git a/core/provider_gobjaverse_crop.py b/core/provider_gobjaverse_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7ec55aaa0497e72a9c0016a4b14a38c2d2f5f6 --- /dev/null +++ b/core/provider_gobjaverse_crop.py @@ -0,0 +1,392 @@ +import os +import cv2 +import random +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset +from PIL import Image +import json +from torchvision.transforms import v2 +import tarfile + +import kiui +from core.options import Options +from core.utils import get_rays, grid_distortion, orbit_camera_jitter + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" + +class GobjaverseDataset(Dataset): + + def _warn(self): + raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)') + + def __init__(self, opt: Options, training=True): + + self.total_epoch = 30 + self.cur_epoch = 0 + self.cur_itrs = 0 + + # 不切片的比例,原始尺寸可以保持稳定训练 + self.original_scale = 0.1 + self.bata_line_scale = self.original_scale * 0.5 + self.beta_line_ites = 3000 + + self.opt = opt + self.training = training + + if opt.over_fit: + data_list_path=opt.data_debug_list + else: + data_list_path=opt.data_list_path + + # TODO: load the list of objects for training + self.items = [] + with open(data_list_path, 'r') as f: + data = json.load(f) + for item in data: + self.items.append(item) + + # naive split + if not opt.over_fit: + if self.training: + self.items = self.items[:-self.opt.batch_size] + else: + self.items = self.items[-self.opt.batch_size:] + else: + self.opt.batch_size=len(self.items) + self.opt.num_workers=0 + + # default camera intrinsics + self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear) + self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear) + self.proj_matrix[2, 3] = 1 + + + def __len__(self): + return len(self.items) + + def get_random_crop(self, batch_masks, minsize): + n, h, w = batch_masks.shape + # 初始化一个全为-1的张量,用于存储随机裁剪区域的左上角坐标 + crop_topleft_points = torch.full((n, 4), -1, dtype=torch.int) + + for i, mask in enumerate(batch_masks): + # 获取非零坐标 + nonzero_coords = torch.nonzero(mask, as_tuple=False) + if nonzero_coords.size(0) == 0: + crop_topleft_points[i] = torch.tensor([0, 0, minsize, minsize]) + continue # 如果没有非零元素,保留初始化时的-1值 + # 计算最小和最大坐标 + min_coords = torch.min(nonzero_coords, dim=0)[0] + max_coords = torch.max(nonzero_coords, dim=0)[0] + y_min, x_min = min_coords + y_max, x_max = max_coords + + # 确保包围盒不小于 minsize * minsize + y_center = (y_min + y_max) // 2 + x_center = (x_min + x_max) // 2 + + y_min = max(0, y_center - (minsize // 2)) + y_max = min(h - 1, y_center + (minsize // 2)) + x_min = max(0, x_center - (minsize // 2)) + x_max = min(w - 1, x_center + (minsize // 2)) + + # 如果计算后仍然小于 minsize,则调整 + if (y_max - y_min + 1) < minsize: + y_min = max(0, y_max - minsize + 1) + y_max = y_min + minsize - 1 + if (x_max - x_min + 1) < minsize: + x_min = max(0, x_max - minsize + 1) + x_max = x_min + minsize - 1 + + # 随机选择左上角点 + top_y = torch.randint(y_min, y_max - minsize + 2, (1,)).item() # 确保裁剪区域在包围盒内 + top_x = torch.randint(x_min, x_max - minsize + 2, (1,)).item() + + crop_topleft_points[i] = torch.tensor([top_x, top_y, minsize, minsize]) + + return crop_topleft_points + + def __getitem__(self, idx): + + uid = self.items[idx] + results = {} + + # load num_views images + images = [] + albedos = [] + normals = [] + depths = [] + masks = [] + cam_poses = [] + + vid_cnt = 0 + + # TODO: choose views, based on your rendering settings + if self.training: + if self.opt.is_fix_views: + if self.opt.mvdream_or_zero123: + vids = [0,30,12,36,27,6,33,18][:self.opt.num_input_views] + np.random.permutation(24).tolist() + else: + vids = [0,29,8,33,16,37,2,10,18,28][:self.opt.num_input_views] + np.random.permutation(24).tolist() + else: + vids = np.random.permutation(np.arange(0, 36))[:self.opt.num_input_views].tolist() + np.random.permutation(36).tolist() + + else: + #fixed views + # if self.opt.mvdream_or_zero123: + # vids = np.arange(0, 40, 6).tolist() + np.arange(100).tolist() + # else: + # vids = np.arange(0, 40, 4).tolist() + np.arange(100).tolist() + if self.opt.mvdream_or_zero123: + vids = [0,30,12,36,27,6,33,18]#np.arange(0, 24, 6).tolist() + np.arange(27, 40, 3).tolist() + else: + vids = [0,29,8,33,16,37,2,10,18,28] + + for vid in vids: + + + #try: + uid_last = uid.split('/')[1] + + if self.opt.rar_data: + tar_path = os.path.join(self.opt.data_path, f"{uid}.tar") + image_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}.png") + meta_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}.json") + albedo_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_albedo.png") # black bg... + # mr_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_mr.png") + nd_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_nd.exr") + + with tarfile.open(tar_path, 'r') as tar: + with tar.extractfile(image_path) as f: + image = np.frombuffer(f.read(), np.uint8) + with tar.extractfile(albedo_path) as f: + albedo = np.frombuffer(f.read(), np.uint8) + with tar.extractfile(meta_path) as f: + meta = json.loads(f.read().decode()) + with tar.extractfile(nd_path) as f: + nd = np.frombuffer(f.read(), np.uint8) + + image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] + albedo = torch.from_numpy(cv2.imdecode(albedo, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] + else: + image_path = os.path.join(self.opt.data_path,uid, f"{vid:05d}/{vid:05d}.png") + meta_path = os.path.join(self.opt.data_path,uid, f"{vid:05d}/{vid:05d}.json") + # albedo_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_albedo.png") # black bg... + # mr_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_mr.png") + nd_path = os.path.join(self.opt.data_path,uid, f"{vid:05d}/{vid:05d}_nd.exr") + + albedo_path = os.path.join(self.opt.data_path,uid, f"{vid:05d}/{vid:05d}_albedo.png") + + # 读取图片并转换为np.uint8类型的数组 + with open(image_path, 'rb') as f: + image = np.frombuffer(f.read(), dtype=np.uint8) + + with open(albedo_path, 'rb') as f: + albedo = np.frombuffer(f.read(), dtype=np.uint8) + + # 读取JSON文件作为元数据 + with open(meta_path, 'r') as f: + meta = json.load(f) + + # 读取图片并转换为np.uint8类型的数组 + with open(nd_path, 'rb') as f: + nd = np.frombuffer(f.read(), np.uint8) + + image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] + albedo = torch.from_numpy(cv2.imdecode(albedo, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) + + c2w = np.eye(4) + c2w[:3, 0] = np.array(meta['x']) + c2w[:3, 1] = np.array(meta['y']) + c2w[:3, 2] = np.array(meta['z']) + c2w[:3, 3] = np.array(meta['origin']) + c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4) + + nd = cv2.imdecode(nd, cv2.IMREAD_UNCHANGED).astype(np.float32) # [512, 512, 4] in [-1, 1] + normal = nd[..., :3] # in [-1, 1], bg is [0, 0, 1] + depth = nd[..., 3] # in [0, +?), bg is 0 + + # rectify normal directions + normal = normal[..., ::-1] + normal[..., 0] *= -1 + normal = torch.from_numpy(normal.astype(np.float32)).nan_to_num_(0) # there are nans in gt normal... + depth = torch.from_numpy(depth.astype(np.float32)).nan_to_num_(0) + + # except Exception as e: + # # print(f'[WARN] dataset {uid} {vid}: {e}') + # continue + + # blender world + opencv cam --> opengl world & cam + c2w[1] *= -1 + c2w[[1, 2]] = c2w[[2, 1]] + c2w[:3, 1:3] *= -1 # invert up and forward direction + + image = image.permute(2, 0, 1) # [4, 512, 512] + mask = image[3:4] # [1, 512, 512] + image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg + image = image[[2,1,0]].contiguous() # bgr to rgb + + # albdeo + albedo = albedo.permute(2, 0, 1) # [4, 512, 512] + albedo = albedo[:3] * mask + (1 - mask) # [3, 512, 512], to white bg + albedo = albedo[[2,1,0]].contiguous() # bgr to rgb + + + normal = normal.permute(2, 0, 1) # [3, 512, 512] + normal = normal * mask # to [0, 0, 0] bg + + images.append(image) + albedos.append(albedo) + normals.append(normal) + depths.append(depth) + masks.append(mask.squeeze(0)) + cam_poses.append(c2w) + + vid_cnt += 1 + if vid_cnt == self.opt.num_views: + break + + if vid_cnt < self.opt.num_views: + print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!') + n = self.opt.num_views - vid_cnt + images = images + [images[-1]] * n + normals = normals + [normals[-1]] * n + depths = depths + [depths[-1]] * n + masks = masks + [masks[-1]] * n + cam_poses = cam_poses + [cam_poses[-1]] * n + + images = torch.stack(images, dim=0) # [V, 3, H, W] + albedos = torch.stack(albedos, dim=0) # [V, 3, H, W] + normals = torch.stack(normals, dim=0) # [V, 3, H, W] + depths = torch.stack(depths, dim=0) # [V, H, W] + masks = torch.stack(masks, dim=0) # [V, H, W] + cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4] + + # normalized camera feats as in paper (transform the first pose to a fixed position) + radius = torch.norm(cam_poses[0, :3, 3]) + cam_poses[:, :3, 3] *= self.opt.cam_radius / radius + transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0]) + cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4] + cam_poses_input = cam_poses[:self.opt.num_input_views].clone() + + # 模拟的设定input size,原图512可以模拟输入320 + images = F.interpolate(images, size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W] + albedos = F.interpolate(albedos, size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) + + # increase_size= np.maximum((self.cur_epoch/self.total_epoch-self.original_scale),0)/(1-self.original_scale) * (self.opt.input_size-self.opt.output_size) + # max_scale_input_size = int(self.opt.output_size + increase_size) + + if self.opt.is_crop and self.training: + #max_scale_input_size=self.opt.input_size + increase_size= np.maximum((self.cur_epoch/self.total_epoch-self.original_scale),0)/(1-self.original_scale) * (self.opt.input_size-self.opt.output_size) + increase_size= np.maximum(self.opt.output_size*0.5,increase_size) + max_scale_input_size = int(self.opt.output_size + increase_size) + else: + max_scale_input_size=self.opt.output_size + + # random crop, 先随机一个目标尺寸,再从中裁剪一块固定尺寸作为目标 + if max_scale_input_size > self.opt.output_size: + scaled_input_size = np.random.randint(self.opt.output_size, max_scale_input_size+1) + else: + scaled_input_size = self.opt.output_size + + target_images = v2.functional.resize( + images, scaled_input_size, interpolation=3, antialias=True).clamp(0, 1) + + target_albedos = v2.functional.resize( + albedos, scaled_input_size, interpolation=3, antialias=True).clamp(0, 1) + # target_depths = v2.functional.resize( + # target_depths, render_size, interpolation=0, antialias=True) + target_alphas = v2.functional.resize( + masks.unsqueeze(1), scaled_input_size, interpolation=0, antialias=True) + + # crop_params = v2.RandomCrop.get_params( + # target_images, output_size=(self.opt.output_size, self.opt.output_size)) + + # 拿 mask的包围盒,并且保证包围盒大于crop patch + crop_params = self.get_random_crop(target_alphas[:,0], self.opt.output_size ) + + target_images = torch.stack([v2.functional.crop(target_images[i], *crop_params[i]) for i in range(target_images.shape[0])],0) + target_albedos = torch.stack([v2.functional.crop(target_albedos[i], *crop_params[i]) for i in range(target_albedos.shape[0])],0) + target_alphas = torch.stack([v2.functional.crop(target_alphas[i], *crop_params[i]) for i in range(target_alphas.shape[0])],0) + + #target gt + results['images_output']=target_images + results['albedos_output']=target_albedos + results['masks_output']=target_alphas + #bake sdf bata schedule + #results['t']=torch.tensor(self.cur_epoch/(self.opt.num_epochs*self.bata_line_scale), dtype=torch.float32).clamp(0, 1) + #results['t']=torch.tensor(self.cur_itrs/self.beta_line_ites, dtype=torch.float32).clamp(0, 1) + + # data augmentation condition input image + images_input = images[:self.opt.num_input_views].clone() + if self.training: + # apply random grid distortion to simulate 3D inconsistency + if random.random() < self.opt.prob_grid_distortion: + images_input[1:] = grid_distortion(images_input[1:]) + # apply camera jittering (only to input!) + if random.random() < self.opt.prob_cam_jitter: + cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:]) + #images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + results['input']=images_input #input view images, unused for tranformer based + #results['input'] = None # for gs based mesh + + #for transformer hard code size + images_input_vit = F.interpolate(images_input, size=(224, 224), mode='bilinear', align_corners=False) + #images_input_vit = TF.normalize(images_input_vit, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + results['input_vit']=images_input_vit + + #if self.opt.volume_mode=='TRF': + all_rays_o=[] + all_rays_d=[] + for i in range(vid_cnt): + rays_o, rays_d = get_rays(cam_poses[i], scaled_input_size, scaled_input_size, self.opt.fovy) # [h, w, 3] + all_rays_o.append(rays_o) + all_rays_d.append(rays_d) + all_rays_o=torch.stack(all_rays_o, dim=0) + all_rays_d=torch.stack(all_rays_d, dim=0) + + if crop_params is not None: + all_rays_o_crop=[] + all_rays_d_crop=[] + for k in range(all_rays_o.shape[0]): + i, j, h, w = crop_params[k] + all_rays_o_crop.append(all_rays_o[k][i:i+h, j:j+w, :]) + all_rays_d_crop.append(all_rays_d[k][i:i+h, j:j+w, :]) + + all_rays_o=torch.stack(all_rays_o_crop, dim=0) + all_rays_d=torch.stack(all_rays_d_crop, dim=0) + + results['all_rays_o']=all_rays_o + results['all_rays_d']=all_rays_d + + # 相机外参,c2w + # opengl to colmap camera for gaussian renderer + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # c2w的逆,w2c*投影内参,等于mvp矩阵 + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] #相机位子 + + results['cam_view'] = cam_view + results['cam_view_proj'] = cam_view_proj + results['cam_pos'] = cam_pos + + #lrm用的是内参和外参的混合,这里先直接用外参试下, 实验可行 + results['source_camera']=cam_poses_input + + return results \ No newline at end of file diff --git a/core/provider_gobjaverse_mesh.py b/core/provider_gobjaverse_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..2aadba79567a1fe8eb498682953bc1fdf3f08214 --- /dev/null +++ b/core/provider_gobjaverse_mesh.py @@ -0,0 +1,280 @@ +import os +import cv2 +import random +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset +from PIL import Image +import json +from torchvision.transforms import v2 +import tarfile + +import kiui +from core.options import Options +from core.utils import get_rays, grid_distortion, orbit_camera_jitter + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" + +class GobjaverseDataset(Dataset): + + def _warn(self): + raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)') + + def __init__(self, opt: Options, training=True): + + self.opt = opt + self.training = training + + if opt.over_fit: + data_list_path=opt.data_debug_list + else: + data_list_path=opt.data_list_path + + # TODO: load the list of objects for training + self.items = [] + with open(data_list_path, 'r') as f: + data = json.load(f) + for item in data: + self.items.append(item) + + # naive split + if not opt.over_fit: + if self.training: + self.items = self.items[:-self.opt.batch_size] + else: + self.items = self.items[-self.opt.batch_size:] + else: + self.opt.batch_size=len(self.items) + self.opt.num_workers=0 + + # default camera intrinsics + self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear) + self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear) + self.proj_matrix[2, 3] = 1 + + + def __len__(self): + return len(self.items) + + def __getitem__(self, idx): + + uid = self.items[idx] + results = {} + + # load num_views images + images = [] + albedos = [] + normals = [] + depths = [] + masks = [] + cam_poses = [] + + vid_cnt = 0 + + # TODO: choose views, based on your rendering settings + if self.training: + # input views are in (36, 72), other views are randomly selected + if self.opt.mvdream_or_zero123: + vids = [0,30,12,36,27,6,33,18][:self.opt.num_input_views] + np.random.permutation(24).tolist() + else: + vids = [0,29,8,33,16,37,2,10][:self.opt.num_input_views] + np.random.permutation(24).tolist() + else: + # fixed views + if self.opt.mvdream_or_zero123: + vids = [0,30,12,36,27,6,33,18]#np.arange(0, 24, 6).tolist() + np.arange(27, 40, 3).tolist() + else: + vids = [0,29,8,33,16,37,2,10,18,28] + + + for vid in vids: + + + #try: + uid_last = uid.split('/')[1] + + if self.opt.rar_data: + tar_path = os.path.join(self.opt.data_path, f"{uid}.tar") + image_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}.png") + meta_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}.json") + albedo_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_albedo.png") # black bg... + # mr_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_mr.png") + nd_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_nd.exr") + + with tarfile.open(tar_path, 'r') as tar: + with tar.extractfile(image_path) as f: + image = np.frombuffer(f.read(), np.uint8) + with tar.extractfile(albedo_path) as f: + albedo = np.frombuffer(f.read(), np.uint8) + with tar.extractfile(meta_path) as f: + meta = json.loads(f.read().decode()) + with tar.extractfile(nd_path) as f: + nd = np.frombuffer(f.read(), np.uint8) + + image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] + albedo = torch.from_numpy(cv2.imdecode(albedo, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] + else: + image_path = os.path.join(self.opt.data_path,uid, f"{vid:05d}/{vid:05d}.png") + meta_path = os.path.join(self.opt.data_path,uid, f"{vid:05d}/{vid:05d}.json") + # albedo_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_albedo.png") # black bg... + # mr_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_mr.png") + nd_path = os.path.join(self.opt.data_path,uid, f"{vid:05d}/{vid:05d}_nd.exr") + + albedo_path = os.path.join(self.opt.data_path,uid, f"{vid:05d}/{vid:05d}_albedo.png") + + # 读取图片并转换为np.uint8类型的数组 + with open(image_path, 'rb') as f: + image = np.frombuffer(f.read(), dtype=np.uint8) + + with open(albedo_path, 'rb') as f: + albedo = np.frombuffer(f.read(), dtype=np.uint8) + + # 读取JSON文件作为元数据 + with open(meta_path, 'r') as f: + meta = json.load(f) + + # 读取图片并转换为np.uint8类型的数组 + with open(nd_path, 'rb') as f: + nd = np.frombuffer(f.read(), np.uint8) + + image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] + albedo = torch.from_numpy(cv2.imdecode(albedo, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) + + c2w = np.eye(4) + c2w[:3, 0] = np.array(meta['x']) + c2w[:3, 1] = np.array(meta['y']) + c2w[:3, 2] = np.array(meta['z']) + c2w[:3, 3] = np.array(meta['origin']) + c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4) + + nd = cv2.imdecode(nd, cv2.IMREAD_UNCHANGED).astype(np.float32) # [512, 512, 4] in [-1, 1] + normal = nd[..., :3] # in [-1, 1], bg is [0, 0, 1] + depth = nd[..., 3] # in [0, +?), bg is 0 + + # rectify normal directions + normal = normal[..., ::-1] + normal[..., 0] *= -1 + normal = torch.from_numpy(normal.astype(np.float32)).nan_to_num_(0) # there are nans in gt normal... + depth = torch.from_numpy(depth.astype(np.float32)).nan_to_num_(0) + + # except Exception as e: + # # print(f'[WARN] dataset {uid} {vid}: {e}') + # continue + + # blender world + opencv cam --> opengl world & cam + # world transform, 只要坐标系手系相同,不转不影响画图,会影响normal的着色 + c2w[1] *= -1 + c2w[[1, 2]] = c2w[[2, 1]] + + # cam transform + c2w[:3, 1:3] *= -1 # invert up and forward direction + + image = image.permute(2, 0, 1) # [4, 512, 512] + mask = image[3:4] # [1, 512, 512] + + image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg + + image = image[[2,1,0]].contiguous() # bgr to rgb + + # albdeo + albedo = albedo.permute(2, 0, 1) # [4, 512, 512] + albedo = albedo[:3] * mask + (1 - mask) # [3, 512, 512], to white bg + albedo = albedo[[2,1,0]].contiguous() # bgr to rgb + + normal = normal.permute(2, 0, 1) # [3, 512, 512] + normal = normal * mask # to [0, 0, 0] bg + + images.append(image) + albedos.append(albedo) + normals.append(normal) + depths.append(depth) + masks.append(mask.squeeze(0)) + cam_poses.append(c2w) + + vid_cnt += 1 + if vid_cnt == self.opt.num_views: + break + + if vid_cnt < self.opt.num_views: + print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!') + n = self.opt.num_views - vid_cnt + images = images + [images[-1]] * n + normals = normals + [normals[-1]] * n + depths = depths + [depths[-1]] * n + masks = masks + [masks[-1]] * n + cam_poses = cam_poses + [cam_poses[-1]] * n + + images = torch.stack(images, dim=0) # [V, 3, H, W] + albedos = torch.stack(albedos, dim=0) # [V, 3, H, W] + normals = torch.stack(normals, dim=0) # [V, 3, H, W] + depths = torch.stack(depths, dim=0) # [V, H, W] + masks = torch.stack(masks, dim=0) # [V, H, W] + cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4] + + # normalized camera feats as in paper (transform the first pose to a fixed position) + radius = torch.norm(cam_poses[0, :3, 3]) + cam_poses[:, :3, 3] *= self.opt.cam_radius / radius + transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0]) + cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4] + cam_poses_input = cam_poses[:self.opt.num_input_views].clone() + + # 模拟的设定input size,原图512可以模拟输入320 + images = F.interpolate(images, size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W] + albedos = F.interpolate(albedos, size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) + + target_images = v2.functional.resize( + images, self.opt.output_size, interpolation=3, antialias=True).clamp(0, 1) + + target_albedos = v2.functional.resize( + albedos, self.opt.output_size, interpolation=3, antialias=True).clamp(0, 1) + # target_depths = v2.functional.resize( + # target_depths, render_size, interpolation=0, antialias=True) + target_alphas = v2.functional.resize( + masks.unsqueeze(1), self.opt.output_size, interpolation=0, antialias=True) + + + #target gt + results['images_output']=target_images + results['albedos_output']=target_albedos + results['masks_output']=target_alphas + + # data augmentation condition input image + images_input = images[:self.opt.num_input_views].clone() + if self.training: + # apply random grid distortion to simulate 3D inconsistency + if random.random() < self.opt.prob_grid_distortion: + images_input[1:] = grid_distortion(images_input[1:]) + # apply camera jittering (only to input!) + if random.random() < self.opt.prob_cam_jitter: + cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:]) + #images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + results['input']=images_input #input view images, unused for tranformer based + #results['input'] = None # for gs based mesh + + #for transformer hard code size + images_input_vit = F.interpolate(images_input, size=(224, 224), mode='bilinear', align_corners=False) + #images_input_vit = TF.normalize(images_input_vit, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + results['input_vit']=images_input_vit + + cam_view = torch.inverse(cam_poses)#.transpose(1, 2) #w2c + + cam_pos = - cam_poses[:, :3, 3] + + results['w2c'] = cam_view + + results['cam_pos'] = cam_pos + + #lrm用的是内参和外参的混合,这里先直接用外参试下, 实验可行 + results['source_camera']=cam_poses_input + + return results \ No newline at end of file diff --git a/core/scheduler.py b/core/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..47f33288595d55947df057f9cd9345dfd6a5646d --- /dev/null +++ b/core/scheduler.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from torch.optim.lr_scheduler import LRScheduler +from accelerate.logging import get_logger + + +logger = get_logger(__name__) + + +class CosineWarmupScheduler(LRScheduler): + def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1): + self.warmup_iters = warmup_iters + self.max_iters = max_iters + self.initial_lr = initial_lr + super().__init__(optimizer, last_iter) + + def get_lr(self): + logger.debug(f"step count: {self._step_count} | warmup iters: {self.warmup_iters} | max iters: {self.max_iters}") + if self._step_count <= self.warmup_iters: + return [ + self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters + for base_lr in self.base_lrs] + else: + cos_iter = self._step_count - self.warmup_iters + cos_max_iter = self.max_iters - self.warmup_iters + cos_theta = cos_iter / cos_max_iter * math.pi + cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs] + return cos_lr diff --git a/core/sh.py b/core/sh.py new file mode 100644 index 0000000000000000000000000000000000000000..b17429dd3328bd7a72c2912fcc5e8e4fc6272b2f --- /dev/null +++ b/core/sh.py @@ -0,0 +1,133 @@ +import torch + +################## sh function ################## +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + :param deg: int SH max degree. Currently, 0-4 supported + :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2) + :param dirs: torch.Tensor unit directions (..., 3) + :return: (..., C) + """ + assert deg <= 4 and deg >= 0 + assert (deg + 1) ** 2 == sh.shape[-1] + C = sh.shape[-2] + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def eval_sh_bases(deg, dirs): + """ + Evaluate spherical harmonics bases at unit directions, + without taking linear combination. + At each point, the final result may the be + obtained through simple multiplication. + :param deg: int SH max degree. Currently, 0-4 supported + :param dirs: torch.Tensor (..., 3) unit directions + :return: torch.Tensor (..., (deg+1) ** 2) + """ + assert deg <= 4 and deg >= 0 + result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device) + result[..., 0] = C0 + if deg > 0: + x, y, z = dirs.unbind(-1) + result[..., 1] = -C1 * y; + result[..., 2] = C1 * z; + result[..., 3] = -C1 * x; + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result[..., 4] = C2[0] * xy; + result[..., 5] = C2[1] * yz; + result[..., 6] = C2[2] * (2.0 * zz - xx - yy); + result[..., 7] = C2[3] * xz; + result[..., 8] = C2[4] * (xx - yy); + + if deg > 2: + result[..., 9] = C3[0] * y * (3 * xx - yy); + result[..., 10] = C3[1] * xy * z; + result[..., 11] = C3[2] * y * (4 * zz - xx - yy); + result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy); + result[..., 13] = C3[4] * x * (4 * zz - xx - yy); + result[..., 14] = C3[5] * z * (xx - yy); + result[..., 15] = C3[6] * x * (xx - 3 * yy); + + if deg > 3: + result[..., 16] = C4[0] * xy * (xx - yy); + result[..., 17] = C4[1] * yz * (3 * xx - yy); + result[..., 18] = C4[2] * xy * (7 * zz - 1); + result[..., 19] = C4[3] * yz * (7 * zz - 3); + result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3); + result[..., 21] = C4[5] * xz * (7 * zz - 3); + result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1); + result[..., 23] = C4[7] * xz * (xx - 3 * yy); + result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)); + return result diff --git a/core/tensoRF.py b/core/tensoRF.py new file mode 100644 index 0000000000000000000000000000000000000000..d347586b69d19a061ca6c362f9cd1ce75d80b915 --- /dev/null +++ b/core/tensoRF.py @@ -0,0 +1,485 @@ +from .tensorBase import * +import torch.nn as nn +import itertools + +class Density(nn.Module): + def __init__(self, params_init={}): + super().__init__() + for p in params_init: + param = nn.Parameter(torch.tensor(params_init[p])) + setattr(self, p, param) + + # self.beta0=0.1 + # self.beta1=0.001 + # self.beta=self.beta0 + + def forward(self, sdf, beta=None): + return self.density_func(sdf, beta=beta) + + +class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf) + #params_init{ beta = 0.1 } beta_min = 0.0001 + def __init__(self, params_init={}, beta_min=0.0001): + super().__init__(params_init=params_init) + self.beta_min = torch.tensor(beta_min).cuda() + + def density_func(self, sdf, beta=None): + if beta is None: + beta = self.get_beta() + + alpha = 1 / beta + return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta)) + + def get_beta(self): + beta = self.beta.abs() + self.beta_min + return self.beta + + # t for 0-1 + def set_beta(self,t): + + self.beta = self.beta0 * (1 + ((self.beta0 - self.beta1) / self.beta1) * (t**0.8)) ** -1 + return self.beta + + + +class TensorVMSplit_Mesh(TensorBase): + def __init__(self, aabb, gridSize, **kargs): + super(TensorVMSplit_Mesh, self).__init__(aabb, gridSize, **kargs) + + hidden_dim = 64 + num_layers = 5 + activation = nn.ReLU + + n_comp=self.density_n_comp+self.app_n_comp + + self.decoder = nn.Sequential( + nn.Linear(n_comp*3, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 7), + ) + + # self.net_sdf = nn.Sequential( + # nn.Linear(n_comp*3, hidden_dim), + # activation(), + # *itertools.chain(*[[ + # nn.Linear(hidden_dim, hidden_dim), + # activation(), + # ] for _ in range(num_layers - 2)]), + # nn.Linear(hidden_dim, 1), + # ) + + hidden_dim_min = 64 + num_layers_min = 2 + + self.net_deformation = nn.Sequential( + nn.Linear(n_comp*3, hidden_dim_min), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim_min, hidden_dim_min), + activation(), + ] for _ in range(num_layers_min - 2)]), + nn.Linear(hidden_dim_min, 3), + ) + + self.net_weight = nn.Sequential( + nn.Linear(n_comp*3*8, hidden_dim_min), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim_min, hidden_dim_min), + activation(), + ] for _ in range(num_layers_min - 2)]), + nn.Linear(hidden_dim_min, 21), + ) + + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def init_render_func(self,shadingMode, pos_pe, view_pe, fea_pe, featureC): + pass + + + def compute_densityfeature(self, xyz_sampled): + + B, N_point, _=xyz_sampled.shape + + # plane + line basis + coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, B, -1, 1, 2) + coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) + coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, B, -1, 1, 2) + + plane_coef_point,line_coef_point = [],[] + for idx_plane in range(3): + + density_plane=self.density_plane[:,idx_plane]#.contiguous() + density_line=self.density_line[:,idx_plane]#.contiguous() + + plane_coef_point.append(F.grid_sample(density_plane, coordinate_plane[idx_plane], + align_corners=True).view(B, -1, N_point)) + line_coef_point.append(F.grid_sample(density_line, coordinate_line[idx_plane], + align_corners=True).view(B, -1, N_point)) + + plane_coef_point, line_coef_point = torch.cat(plane_coef_point,dim=1), torch.cat(line_coef_point,dim=1) + plane_coef=plane_coef_point * line_coef_point + plane_coef=plane_coef.permute(0,2,1) + + result = torch.matmul(plane_coef, self.d_basis_mat) + + return result + + + def compute_appfeature(self, xyz_sampled): + + B, N_point, _=xyz_sampled.shape + # plane + line basis + coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, B, -1, 1, 2) + coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) + coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, B, -1, 1, 2) + + plane_coef_point,line_coef_point = [],[] + for idx_plane in range(3): + + app_plane=self.app_plane[:,idx_plane] + app_line=self.app_line[:,idx_plane] + + plane_coef_point.append(F.grid_sample(app_plane, coordinate_plane[idx_plane], + align_corners=True).view(B, -1, N_point)) + line_coef_point.append(F.grid_sample(app_line, coordinate_line[idx_plane], + align_corners=True).view(B, -1, N_point)) + plane_coef_point, line_coef_point = torch.cat(plane_coef_point,dim=1), torch.cat(line_coef_point,dim=1) + plane_coef=plane_coef_point * line_coef_point + plane_coef=plane_coef.permute(0,2,1) + + # result = torch.matmul(plane_coef, self.basis_mat) + + return plane_coef + + + def geometry_feature_decode(self, sampled_features, flexicubes_indices): + + sdf = self.decoder(sampled_features)[...,-1:] + deformation = self.net_deformation(sampled_features) + + grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) + grid_features = grid_features.reshape( + sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) + weight = self.net_weight(grid_features) * 0.1 + + return sdf, deformation, weight + + + def get_geometry_prediction(self, svd_volume, sample_coordinates, flexicubes_indices): + + self.svd_volume=svd_volume + self.app_plane=svd_volume['app_planes'] + self.app_line=svd_volume['app_lines'] + self.basis_mat=svd_volume['basis_mat'] + self.density_plane=svd_volume['density_planes'] + self.density_line=svd_volume['density_lines'] + self.d_basis_mat=svd_volume['d_basis_mat'] + + self.app_plane=torch.cat([self.app_plane,self.density_plane],2) + self.app_line=torch.cat([self.app_line,self.density_line],2) + + sampled_features = self.compute_appfeature(sample_coordinates) + + sdf, deformation, weight = self.geometry_feature_decode(sampled_features, flexicubes_indices) + + return sdf, deformation, weight + + def get_texture_prediction(self,texture_pos, vsd_vome=None):\ + + app_features = self.compute_appfeature(texture_pos) + + texture_rgb=self.decoder(app_features)[...,0:-1] + + texture_rgb = torch.sigmoid(texture_rgb)*(1 + 2*0.001) - 0.001 + + return texture_rgb + + + + def predict_color(self, svd_volume, xyz_sampled, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): + + self.svd_volume=svd_volume + self.app_plane=svd_volume['app_planes'] + self.app_line=svd_volume['app_lines'] + self.basis_mat=svd_volume['basis_mat'] + self.d_basis_mat=svd_volume['d_basis_mat'] + self.density_plane=svd_volume['density_planes'] + self.density_line=svd_volume['density_lines'] + + self.app_plane=torch.cat([self.app_plane,self.density_plane],2) + self.app_line=torch.cat([self.app_line,self.density_line],2) + + #xyz_sampled=xyz_sampled.unsqueeze(2) + + chunk_size: int = 2**20 + outs = [] + for i in range(0, xyz_sampled.shape[2], chunk_size): + xyz_sampled_chunk = self.normalize_coord(xyz_sampled[:,i:i+chunk_size]) + #xyz_sampled.requires_grad_(True) + + app_features = self.compute_appfeature(xyz_sampled_chunk) + + chunk_out = self.decoder(app_features)[...,0:-1] + + chunk_out = torch.sigmoid(chunk_out)*(1 + 2*0.001) - 0.001 + + rgbs = chunk_out.clamp(0,1) + outs.append(chunk_out) + + rgbs=torch.cat(outs,1) + + albedo=rgbs[:,:,3:6] + rgb=rgbs[:,:,0:3] + + results = { + 'shading':rgb, + 'albedo':albedo, + 'rgb':rgb*albedo, + } + return results # rgb, sigma, alpha, weight, bg_weight + + + + +# special nerf for mesh +class TensorVMSplit_NeRF(TensorBase): + def __init__(self, aabb, gridSize, **kargs): + super(TensorVMSplit_NeRF, self).__init__(aabb, gridSize, **kargs) + + hidden_dim = 64 + num_layers = 4 + activation = nn.ReLU + + self.lap_density = LaplaceDensity(params_init={ 'beta' : 0.1}) + + n_comp=self.density_n_comp+self.app_n_comp + + self.net_sdf = nn.Sequential( + nn.Linear(n_comp*3, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1), + ) + + self.decoder = nn.Sequential( + nn.Linear(n_comp*3, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 6), + ) + + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def init_render_func(self,shadingMode, pos_pe, view_pe, fea_pe, featureC): + pass + + + def compute_densityfeature(self, xyz_sampled): + + B, N_pixel, N_sample, _=xyz_sampled.shape + + # plane + line basis + coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, B, -1, 1, 2) + coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) + coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, B, -1, 1, 2) + + plane_coef_point,line_coef_point = [],[] + for idx_plane in range(3): + + density_plane=self.density_plane[:,idx_plane]#.contiguous() + density_line=self.density_line[:,idx_plane]#.contiguous() + + plane_coef_point.append(F.grid_sample(density_plane, coordinate_plane[idx_plane], + align_corners=True).view(B, -1, N_pixel, N_sample)) + line_coef_point.append(F.grid_sample(density_line, coordinate_line[idx_plane], + align_corners=True).view(B, -1, N_pixel, N_sample)) + + plane_coef_point, line_coef_point = torch.cat(plane_coef_point,dim=1), torch.cat(line_coef_point,dim=1) + plane_coef=plane_coef_point * line_coef_point + plane_coef=plane_coef.permute(0,2,3,1) + + result = torch.matmul(plane_coef, self.d_basis_mat.unsqueeze(1)) + + return result + + + def compute_appfeature(self, xyz_sampled): + + B, N_pixel, N_sample, _=xyz_sampled.shape + # plane + line basis + coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, B, -1, 1, 2) + coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) + coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, B, -1, 1, 2) + + plane_coef_point,line_coef_point = [],[] + for idx_plane in range(3): + + app_plane=self.app_plane[:,idx_plane] + app_line=self.app_line[:,idx_plane] + + plane_coef_point.append(F.grid_sample(app_plane, coordinate_plane[idx_plane], + align_corners=True).view(B, -1, N_pixel, N_sample)) + line_coef_point.append(F.grid_sample(app_line, coordinate_line[idx_plane], + align_corners=True).view(B, -1, N_pixel, N_sample)) + plane_coef_point, line_coef_point = torch.cat(plane_coef_point,dim=1), torch.cat(line_coef_point,dim=1) + plane_coef=plane_coef_point * line_coef_point + plane_coef=plane_coef.permute(0,2,3,1) + + return plane_coef + + def forward(self, svd_volume, rays_o, rays_d, bg_color, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): + self.svd_volume=svd_volume + self.app_plane=svd_volume['app_planes'] + self.app_line=svd_volume['app_lines'] + self.basis_mat=svd_volume['basis_mat'] + self.d_basis_mat=svd_volume['d_basis_mat'] + self.density_plane=svd_volume['density_planes'] + self.density_line=svd_volume['density_lines'] + + self.app_plane=torch.cat([self.app_plane,self.density_plane],2) + self.app_line=torch.cat([self.app_line,self.density_line],2) + + B,V,H,W,_=rays_o.shape + rays_o=rays_o.reshape(B,-1, 3) + rays_d=rays_d.reshape(B,-1, 3) + if ndc_ray: + pass + else: + #B,H*W*V,sample_num,3 + xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_o, rays_d, is_train=is_train,N_samples=N_samples) + dists = torch.cat((z_vals[..., 1:] - z_vals[..., :-1], torch.zeros_like(z_vals[..., :1])), dim=-1) + rays_d = rays_d.unsqueeze(-2).expand(xyz_sampled.shape) + + xyz_sampled = self.normalize_coord(xyz_sampled) + + mix_feature = self.compute_appfeature(xyz_sampled) + + sdf = self.net_sdf(mix_feature) + + sigma= self.lap_density(sdf) + sigma=sigma[...,0] + alpha, weight, bg_weight = raw2alpha(sigma, dists) + + rgbs = self.decoder(mix_feature) + rgbs = torch.sigmoid(rgbs)*(1 + 2*0.001) - 0.001 + #rgb[app_mask] = valid_rgbs + + acc_map = torch.sum(weight, -1) + rgb_map = torch.sum(weight[..., None] * rgbs, -2) + + if white_bg or (is_train and torch.rand((1,))<0.5): + rgb_map = rgb_map + (1. - acc_map[..., None]) + + + rgb_map = rgb_map.clamp(0,1) + rgb_map=rgb_map.view(B,V,H,W,6).permute(0,1,4,2,3) + + albedo_map=rgb_map[:,:,3:6,:,:] + rgb_map=rgb_map[:,:,0:3,:,:] + + with torch.no_grad(): + depth_map = torch.sum(weight * z_vals, -1) + depth_map=depth_map.view(B,V,H,W,1).permute(0,1,4,2,3) + acc_map=acc_map.view(B,V,H,W,1).permute(0,1,4,2,3) + + results = { + 'image':rgb_map, + 'albedo':albedo_map, + 'alpha':acc_map, + 'depth_map':depth_map + } + + return results # rgb, sigma, alpha, weight, bg_weight + + + def predict_sdf(self, svd_volume, xyz_sampled, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): + + self.svd_volume=svd_volume + self.app_plane=svd_volume['app_planes'] + self.app_line=svd_volume['app_lines'] + self.basis_mat=svd_volume['basis_mat'] + self.d_basis_mat=svd_volume['d_basis_mat'] + self.density_plane=svd_volume['density_planes'] + self.density_line=svd_volume['density_lines'] + + self.app_plane=torch.cat([self.app_plane,self.density_plane],2) + self.app_line=torch.cat([self.app_line,self.density_line],2) + + chunk_size: int = 2**20 + outs = [] + for i in range(0, xyz_sampled.shape[1], chunk_size): + xyz_sampled_chunk = self.normalize_coord(xyz_sampled[:,i:i+chunk_size]).half() + + sigma_feature = self.compute_appfeature(xyz_sampled_chunk) + chunk_out = self.net_sdf(sigma_feature) + + outs.append(chunk_out) + sdf=torch.cat(outs,1) + results = { + 'sigma':sdf + } + return results # rgb, sigma, alpha, weight, bg_weight + + + def predict_color(self, svd_volume, xyz_sampled, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): + + self.svd_volume=svd_volume + self.app_plane=svd_volume['app_planes'] + self.app_line=svd_volume['app_lines'] + self.basis_mat=svd_volume['basis_mat'] + self.d_basis_mat=svd_volume['d_basis_mat'] + self.density_plane=svd_volume['density_planes'] + self.density_line=svd_volume['density_lines'] + + self.app_plane=torch.cat([self.app_plane,self.density_plane],2) + self.app_line=torch.cat([self.app_line,self.density_line],2) + + xyz_sampled=xyz_sampled.unsqueeze(2) + + chunk_size: int = 2**20 + outs = [] + for i in range(0, xyz_sampled.shape[2], chunk_size): + xyz_sampled_chunk = self.normalize_coord(xyz_sampled[:,i:i+chunk_size]).half() + #xyz_sampled.requires_grad_(True) + + app_features = self.compute_appfeature(xyz_sampled_chunk) + + chunk_out = self.decoder(app_features) + + chunk_out = torch.sigmoid(chunk_out)*(1 + 2*0.001) - 0.001 + + rgbs = chunk_out.clamp(0,1) + outs.append(chunk_out) + + rgbs=torch.cat(outs,1) + rgbs=rgbs[:,:,0,:] + + albedo=rgbs[:,:,3:6] + rgb=rgbs[:,:,0:3] + + results = { + 'shading':rgb, + 'albedo':albedo, + 'rgb':rgb*albedo, + } + return results # rgb, sigma, alpha, weight, bg_weight + + + \ No newline at end of file diff --git a/core/tensorBase.py b/core/tensorBase.py new file mode 100644 index 0000000000000000000000000000000000000000..a2660bc22fad5ff7d891b78c74cae7c9e0eb2475 --- /dev/null +++ b/core/tensorBase.py @@ -0,0 +1,410 @@ +import torch +import torch.nn +import torch.nn.functional as F +from .sh import eval_sh_bases +import numpy as np +import time + + +def get_ray_directions_blender(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5 + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)], + -1) # (H, W, 3) + + return directions + + +def get_rays(directions, c2w): + """ + Get ray origin and normalized directions in world coordinate for all pixels in one image. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + directions: (H, W, 3) precomputed ray directions in camera coordinate + c2w: (3, 4) transformation matrix from camera coordinate to world coordinate + Outputs: + rays_o: (H*W, 3), the origin of the rays in world coordinate + rays_d: (H*W, 3), the normalized direction of the rays in world coordinate + """ + # Rotate ray directions from camera coordinate to the world coordinate + rays_d = directions @ c2w[:3, :3].T # (H, W, 3) + # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) + # The origin of all rays is the camera origin in world coordinate + rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) + + rays_d = rays_d.view(-1, 3) + rays_o = rays_o.view(-1, 3) + + return rays_o, rays_d + + +def positional_encoding(positions, freqs): + + freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,) + pts = (positions[..., None] * freq_bands).reshape( + positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF) + pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) + return pts + +def raw2alpha(sigma, dist): + # sigma, dist [N_rays, N_samples] + alpha = 1. - torch.exp(-sigma*dist) + + T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0],alpha.shape[1], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1) + + weights = alpha * T[:,:, :-1] # [N_rays, N_samples] + return alpha, weights, T[:,:,-1:] + + +def SHRender(xyz_sampled, viewdirs, features): + sh_mult = eval_sh_bases(2, viewdirs)[:, None] + rgb_sh = features.view(-1, 3, sh_mult.shape[-1]) + rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5) + return rgb + + +def RGBRender(xyz_sampled, viewdirs, features): + + rgb = features + return rgb + +class AlphaGridMask(torch.nn.Module): + def __init__(self, device, aabb, alpha_volume): + super(AlphaGridMask, self).__init__() + self.device = device + + self.aabb=aabb.to(self.device) + self.aabbSize = self.aabb[1] - self.aabb[0] + self.invgridSize = 1.0/self.aabbSize * 2 + self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:]) + self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device) + + def sample_alpha(self, xyz_sampled): + xyz_sampled = self.normalize_coord(xyz_sampled) + alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1) + + return alpha_vals + + def normalize_coord(self, xyz_sampled): + return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1 + + +class MLPRender_Fea(torch.nn.Module): + def __init__(self,inChanel, viewpe=6, feape=6, featureC=128): + super(MLPRender_Fea, self).__init__() + + self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel + self.viewpe = viewpe + self.feape = feape + layer1 = torch.nn.Linear(self.in_mlpC, featureC) + layer2 = torch.nn.Linear(featureC, featureC) + layer3 = torch.nn.Linear(featureC,3) + + self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) + torch.nn.init.constant_(self.mlp[-1].bias, 0) + + def forward(self, pts, viewdirs, features): + indata = [features, viewdirs] + if self.feape > 0: + indata += [positional_encoding(features, self.feape)] + if self.viewpe > 0: + indata += [positional_encoding(viewdirs, self.viewpe)] + mlp_in = torch.cat(indata, dim=-1) + rgb = self.mlp(mlp_in) + rgb = torch.sigmoid(rgb) + + return rgb + +class MLPRender_PE(torch.nn.Module): + def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128): + super(MLPRender_PE, self).__init__() + + self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel # + self.viewpe = viewpe + self.pospe = pospe + layer1 = torch.nn.Linear(self.in_mlpC, featureC) + layer2 = torch.nn.Linear(featureC, featureC) + layer3 = torch.nn.Linear(featureC,3) + + self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) + torch.nn.init.constant_(self.mlp[-1].bias, 0) + + def forward(self, pts, viewdirs, features): + indata = [features, viewdirs] + if self.pospe > 0: + indata += [positional_encoding(pts, self.pospe)] + if self.viewpe > 0: + indata += [positional_encoding(viewdirs, self.viewpe)] + mlp_in = torch.cat(indata, dim=-1) + rgb = self.mlp(mlp_in) + rgb = torch.sigmoid(rgb) + + return rgb + +class MLPRender(torch.nn.Module): + def __init__(self,inChanel, viewpe=6, featureC=128): + super(MLPRender, self).__init__() + + self.in_mlpC = (3+2*viewpe*3) + inChanel + self.viewpe = viewpe + + layer1 = torch.nn.Linear(self.in_mlpC, featureC) + layer2 = torch.nn.Linear(featureC, featureC) + layer3 = torch.nn.Linear(featureC,3) + + self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) + torch.nn.init.constant_(self.mlp[-1].bias, 0) + + def forward(self, pts, viewdirs, features): + indata = [features, viewdirs] + if self.viewpe > 0: + indata += [positional_encoding(viewdirs, self.viewpe)] + mlp_in = torch.cat(indata, dim=-1) + rgb = self.mlp(mlp_in) + rgb = torch.sigmoid(rgb) + + return rgb + + + +class TensorBase(torch.nn.Module): + def __init__(self, aabb, gridSize, density_n_comp = 16, appearance_n_comp = 48, app_dim = 27, density_dim = 8, + shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0], + density_shift = -10, alphaMask_thres=0.0001, distance_scale=25, rayMarch_weight_thres=0.0001, + pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=0.5, + fea2denseAct = 'softplus'): + super(TensorBase, self).__init__() + + self.density_n_comp = density_n_comp + self.app_n_comp = appearance_n_comp + self.app_dim = app_dim + self.density_dim=density_dim + self.aabb = aabb + self.alphaMask = alphaMask + #self.device=device + + self.density_shift = density_shift + self.alphaMask_thres = alphaMask_thres + self.distance_scale = distance_scale + self.rayMarch_weight_thres = rayMarch_weight_thres + self.fea2denseAct = fea2denseAct + + self.near_far = near_far + self.step_ratio = 0.9 #step_ratio原作0.5 + + + self.update_stepSize(gridSize) + + self.matMode = [[0,1], [0,2], [1,2]] + self.vecMode = [2, 1, 0] + self.comp_w = [1,1,1] + + + #self.init_svd_volume(gridSize[0], device) + + self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC + self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC) + + def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC): + if shadingMode == 'MLP_PE': + self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC) + elif shadingMode == 'MLP_Fea': + self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC) + elif shadingMode == 'MLP': + self.renderModule = MLPRender(self.app_dim, view_pe, featureC) + elif shadingMode == 'SH': + self.renderModule = SHRender + elif shadingMode == 'RGB': + assert self.app_dim == 3 + self.renderModule = RGBRender + else: + print("Unrecognized shading module") + exit() + print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe) + print(self.renderModule) + + def update_stepSize(self, gridSize): + print("aabb", self.aabb.view(-1)) + print("grid size", gridSize) + self.aabbSize = self.aabb[1] - self.aabb[0] + self.invaabbSize = 2.0/self.aabbSize + self.gridSize= gridSize.float() + self.units=self.aabbSize / (self.gridSize-1) + self.stepSize=torch.mean(self.units)*self.step_ratio # TBD step_ratio? why so small 0.5 + self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize))) + self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1 + print("sampling step size: ", self.stepSize) + print("sampling number: ", self.nSamples) + + def init_svd_volume(self, res, device): + pass + + def compute_features(self, xyz_sampled): + pass + + def compute_densityfeature(self, xyz_sampled): + pass + + def compute_appfeature(self, xyz_sampled): + pass + + def normalize_coord(self, xyz_sampled): + return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1 + + def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001): + pass + + + def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1): + N_samples = N_samples if N_samples > 0 else self.nSamples + near, far = self.near_far + interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o) + if is_train: + interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples) + + rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None] + mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1) + return rays_pts, interpx, ~mask_outbbox + + def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1): + N_samples = N_samples if N_samples>0 else self.nSamples + stepsize = self.stepSize + near, far = self.near_far + vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d) + rate_a = (self.aabb[1] - rays_o) / vec + rate_b = (self.aabb[0] - rays_o) / vec + t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far) + + rng = torch.arange(N_samples)[None,None].float() + if is_train: + rng = rng.repeat(rays_d.shape[-3],rays_d.shape[-2],1) + rng += torch.rand_like(rng[...,[0]]) + step = stepsize * rng.to(rays_o.device) + interpx = (t_min[...,None] + step) + + rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None] + mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1) + + return rays_pts, interpx, ~mask_outbbox + + + def shrink(self, new_aabb, voxel_size): + pass + + @torch.no_grad() + def getDenseAlpha(self,gridSize=None): + gridSize = self.gridSize if gridSize is None else gridSize + + samples = torch.stack(torch.meshgrid( + torch.linspace(0, 1, gridSize[0]), + torch.linspace(0, 1, gridSize[1]), + torch.linspace(0, 1, gridSize[2]), + ), -1).to(self.device) + dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples + + # dense_xyz = dense_xyz + # print(self.stepSize, self.distance_scale*self.aabbDiag) + alpha = torch.zeros_like(dense_xyz[...,0]) + for i in range(gridSize[0]): + alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2])) + return alpha, dense_xyz + + + def feature2density(self, density_features): + if self.fea2denseAct == "softplus": + return F.softplus(density_features+self.density_shift) + elif self.fea2denseAct == "relu": + return F.relu(density_features) + + + def compute_alpha(self, xyz_locs, length=1): + + if self.alphaMask is not None: + alphas = self.alphaMask.sample_alpha(xyz_locs) + alpha_mask = alphas > 0 + else: + alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool) + + + sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device) + + if alpha_mask.any(): + xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask]) + sigma_feature = self.compute_densityfeature(xyz_sampled) + validsigma = self.feature2density(sigma_feature) + sigma[alpha_mask] = validsigma + + + alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1]) + + return alpha + + + def forward(self, svd_volume, rays_o, rays_d, bg_color, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): + + self.svd_volume=svd_volume + self.app_plane=svd_volume['app_planes'] + self.app_line=svd_volume['app_lines'] + self.basis_mat=svd_volume['basis_mat'] + self.density_plane=svd_volume['density_planes'] + self.density_line=svd_volume['density_lines'] + + B,V,H,W,_=rays_o.shape + rays_o=rays_o.reshape(B,-1, 3) + rays_d=rays_d.reshape(B,-1, 3) + if ndc_ray: + pass + else: + #B,H*W*V,sample_num,3 + xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_o, rays_d, is_train=is_train,N_samples=N_samples) + dists = torch.cat((z_vals[..., 1:] - z_vals[..., :-1], torch.zeros_like(z_vals[..., :1])), dim=-1) + rays_d = rays_d.unsqueeze(-2).expand(xyz_sampled.shape) + + + xyz_sampled = self.normalize_coord(xyz_sampled) + sigma_feature = self.compute_densityfeature(xyz_sampled) + + sigma = self.feature2density(sigma_feature) + alpha, weight, bg_weight = raw2alpha(sigma, dists) + + + app_features = self.compute_appfeature(xyz_sampled) + rgbs = self.renderModule(xyz_sampled, rays_d, app_features) + #rgb[app_mask] = valid_rgbs + + acc_map = torch.sum(weight, -1) + rgb_map = torch.sum(weight[..., None] * rgbs, -2) + + if white_bg or (is_train and torch.rand((1,))<0.5): + rgb_map = rgb_map + (1. - acc_map[..., None]) + + + rgb_map = rgb_map.clamp(0,1) + rgb_map=rgb_map.view(B,V,H,W,3).permute(0,1,4,2,3) + + with torch.no_grad(): + depth_map = torch.sum(weight * z_vals, -1) + depth_map=depth_map.view(B,V,H,W,1).permute(0,1,4,2,3) + acc_map=acc_map.view(B,V,H,W,1).permute(0,1,4,2,3) + + results = { + 'image':rgb_map, + 'alpha':acc_map, + 'depth_map':depth_map + } + + return results # rgb, sigma, alpha, weight, bg_weight + diff --git a/core/transformer.py b/core/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9e85882f167546ccae0efbe333729659daf25a --- /dev/null +++ b/core/transformer.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial +import torch +import torch.nn as nn +#from accelerate.logging import get_logger + + +#logger = get_logger(__name__) + + +class TransformerDecoder(nn.Module): + + """ + Transformer blocks that process the input and optionally use condition and modulation. + """ + + def __init__(self, block_type: str, + num_layers: int, num_heads: int, + inner_dim: int, cond_dim: int = None, mod_dim: int = None, + eps: float = 1e-6): + super().__init__() + self.block_type = block_type + self.layers = nn.ModuleList([ + self._block_fn(inner_dim, cond_dim, mod_dim)( + num_heads=num_heads, + eps=eps, + ) + for _ in range(num_layers) + ]) + self.norm = nn.LayerNorm(inner_dim, eps=eps) + + @property + def block_type(self): + return self._block_type + + @block_type.setter + def block_type(self, block_type): + assert block_type in ['basic', 'cond', 'mod', 'cond_mod'], \ + f"Unsupported block type: {block_type}" + self._block_type = block_type + + def _block_fn(self, inner_dim, cond_dim, mod_dim): + assert inner_dim is not None, f"inner_dim must always be specified" + if self.block_type == 'basic': + assert cond_dim is None and mod_dim is None, \ + f"Condition and modulation are not supported for BasicBlock" + from .block import BasicBlock + #logger.debug(f"Using BasicBlock") + return partial(BasicBlock, inner_dim=inner_dim) + elif self.block_type == 'cond': + assert cond_dim is not None, f"Condition dimension must be specified for ConditionBlock" + assert mod_dim is None, f"Modulation dimension is not supported for ConditionBlock" + from .block import ConditionBlock + #logger.debug(f"Using ConditionBlock") + return partial(ConditionBlock, inner_dim=inner_dim, cond_dim=cond_dim) + elif self.block_type == 'mod': + #logger.error(f"modulation without condition is not implemented") + raise NotImplementedError(f"modulation without condition is not implemented") + elif self.block_type == 'cond_mod': + assert cond_dim is not None and mod_dim is not None, \ + f"Condition and modulation dimensions must be specified for ConditionModulationBlock" + from .block import ConditionModulationBlock + #logger.debug(f"Using ConditionModulationBlock") + return partial(ConditionModulationBlock, inner_dim=inner_dim, cond_dim=cond_dim, mod_dim=mod_dim) + else: + raise ValueError(f"Unsupported block type during runtime: {self.block_type}") + + def assert_runtime_integrity(self, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor): + assert x is not None, f"Input tensor must be specified" + if self.block_type == 'basic': + assert cond is None and mod is None, \ + f"Condition and modulation are not supported for BasicBlock" + elif self.block_type == 'cond': + assert cond is not None and mod is None, \ + f"Condition must be specified and modulation is not supported for ConditionBlock" + elif self.block_type == 'mod': + raise NotImplementedError(f"modulation without condition is not implemented") + else: + assert cond is not None and mod is not None, \ + f"Condition and modulation must be specified for ConditionModulationBlock" + + def forward_layer(self, layer: nn.Module, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor): + if self.block_type == 'basic': + return layer(x) + elif self.block_type == 'cond': + return layer(x, cond) + elif self.block_type == 'mod': + return layer(x, mod) + else: + return layer(x, cond, mod) + + def forward(self, x: torch.Tensor, cond: torch.Tensor = None, mod: torch.Tensor = None): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] or None + # mod: [N, D_mod] or None + self.assert_runtime_integrity(x, cond, mod) + for layer in self.layers: + x = self.forward_layer(layer, x, cond, mod) + x = self.norm(x) + return x diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d372dd3e96f3c87349903341a21f69699044a5a --- /dev/null +++ b/core/utils.py @@ -0,0 +1,109 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import roma +from kiui.op import safe_normalize + +def get_rays(pose, h, w, fovy, opengl=True): + + x, y = torch.meshgrid( + torch.arange(w, device=pose.device), + torch.arange(h, device=pose.device), + indexing="xy", + ) + x = x.flatten() + y = y.flatten() + + cx = w * 0.5 + cy = h * 0.5 + + focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) + + camera_dirs = F.pad( + torch.stack( + [ + (x - cx + 0.5) / focal, + (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), + ], + dim=-1, + ), + (0, 1), + value=(-1.0 if opengl else 1.0), + ) # [hw, 3] + + rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] + rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] + + rays_o = rays_o.view(h, w, 3) + rays_d = safe_normalize(rays_d).view(h, w, 3) + + return rays_o, rays_d + +def orbit_camera_jitter(poses, strength=0.1): + # poses: [B, 4, 4], assume orbit camera in opengl format + # random orbital rotate + + B = poses.shape[0] + rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1) + rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1) + + rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y) + R = rot @ poses[:, :3, :3] + T = rot @ poses[:, :3, 3:] + + new_poses = poses.clone() + new_poses[:, :3, :3] = R + new_poses[:, :3, 3:] = T + + return new_poses + +def grid_distortion(images, strength=0.5): + # images: [B, C, H, W] + # num_steps: int, grid resolution for distortion + # strength: float in [0, 1], strength of distortion + + B, C, H, W = images.shape + + num_steps = np.random.randint(8, 17) + grid_steps = torch.linspace(-1, 1, num_steps) + + # have to loop batch... + grids = [] + for b in range(B): + # construct displacement + x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive + x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb + x_steps = (x_steps * W).long() # [num_steps] + x_steps[0] = 0 + x_steps[-1] = W + xs = [] + for i in range(num_steps - 1): + xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i])) + xs = torch.cat(xs, dim=0) # [W] + + y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive + y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb + y_steps = (y_steps * H).long() # [num_steps] + y_steps[0] = 0 + y_steps[-1] = H + ys = [] + for i in range(num_steps - 1): + ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i])) + ys = torch.cat(ys, dim=0) # [H] + + # construct grid + grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W] + grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2] + + grids.append(grid) + + grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2] + + # grid sample + images = F.grid_sample(images, grids, align_corners=False) + + return images + diff --git a/example/arrangement.png b/example/arrangement.png new file mode 100644 index 0000000000000000000000000000000000000000..9fc766dc27147f913c30a79e30ca06e0f8b29c25 Binary files /dev/null and b/example/arrangement.png differ diff --git a/example/bear.png b/example/bear.png new file mode 100644 index 0000000000000000000000000000000000000000..ca1fd40dfd17c9b88e1972a92981a0f5099d72b7 Binary files /dev/null and b/example/bear.png differ diff --git a/example/blue_cat.png b/example/blue_cat.png new file mode 100644 index 0000000000000000000000000000000000000000..aa933e07cc2d6b058900730649637d75be19ef5f Binary files /dev/null and b/example/blue_cat.png differ diff --git a/example/bubble_mart_blue.png b/example/bubble_mart_blue.png new file mode 100644 index 0000000000000000000000000000000000000000..af870322d4a8a2f237546fbea9560bb8e5f50364 Binary files /dev/null and b/example/bubble_mart_blue.png differ diff --git a/example/catstatue_rgba.png b/example/catstatue_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..3b44eb51645f4ecf9e288c53a46000c5a795af69 Binary files /dev/null and b/example/catstatue_rgba.png differ diff --git a/example/chair_armed.png b/example/chair_armed.png new file mode 100644 index 0000000000000000000000000000000000000000..2ab67e95ed57fbc5ebcd7d934827fd7fb03ab3ff Binary files /dev/null and b/example/chair_armed.png differ diff --git a/example/chair_comfort.jpg b/example/chair_comfort.jpg new file mode 100644 index 0000000000000000000000000000000000000000..918347fe51773d7ecaa7fb929274db8d7d5d3e19 Binary files /dev/null and b/example/chair_comfort.jpg differ diff --git a/example/chair_watermelon.png b/example/chair_watermelon.png new file mode 100644 index 0000000000000000000000000000000000000000..52b39917abcbd2f1eef9b7c8cf9aa602bddde1bf Binary files /dev/null and b/example/chair_watermelon.png differ diff --git a/example/chair_wood.jpg b/example/chair_wood.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc60569896fb02a46185aabb85086890f0f400d7 Binary files /dev/null and b/example/chair_wood.jpg differ diff --git a/example/chest.jpg b/example/chest.jpg new file mode 100644 index 0000000000000000000000000000000000000000..26ae0b145887e43b850d298b94fe54828e909492 Binary files /dev/null and b/example/chest.jpg differ diff --git a/example/colored_mushroom.png b/example/colored_mushroom.png new file mode 100644 index 0000000000000000000000000000000000000000..8dea486d49d887f8132a85a3dfb7d99492aba475 Binary files /dev/null and b/example/colored_mushroom.png differ diff --git a/example/colored_mushroom2.png b/example/colored_mushroom2.png new file mode 100644 index 0000000000000000000000000000000000000000..76aa411b51a9fe04f159dd1913b1f5f94c4e4d5d Binary files /dev/null and b/example/colored_mushroom2.png differ diff --git a/example/dinosaur.png b/example/dinosaur.png new file mode 100644 index 0000000000000000000000000000000000000000..24f4212e0bbd8124e7e568fb2ac627fa72577b68 Binary files /dev/null and b/example/dinosaur.png differ diff --git a/example/dinosaur_set.png b/example/dinosaur_set.png new file mode 100644 index 0000000000000000000000000000000000000000..c719cc1b7bed6fbf5346581925fc75e50902b02c Binary files /dev/null and b/example/dinosaur_set.png differ diff --git a/example/dog.png b/example/dog.png new file mode 100644 index 0000000000000000000000000000000000000000..e71afb763449e3a141a79582de0ac05f86edec5d Binary files /dev/null and b/example/dog.png differ diff --git a/example/extinguisher.png b/example/extinguisher.png new file mode 100644 index 0000000000000000000000000000000000000000..98f43a66718d49980e61d466b652bfca0a7ae9df Binary files /dev/null and b/example/extinguisher.png differ diff --git a/example/flower.png b/example/flower.png new file mode 100644 index 0000000000000000000000000000000000000000..d3eb7e210bb116c673f9d8722f68b4960f2bae1f Binary files /dev/null and b/example/flower.png differ diff --git a/example/fox.jpg b/example/fox.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1f2efc1c3a9c4ad8f36ad93082c124c91a6e9ef7 Binary files /dev/null and b/example/fox.jpg differ diff --git a/example/genshin_building.png b/example/genshin_building.png new file mode 100644 index 0000000000000000000000000000000000000000..00b6a949d01283e1ae30fac4bd6040e13f18a055 Binary files /dev/null and b/example/genshin_building.png differ diff --git a/example/gun.png b/example/gun.png new file mode 100644 index 0000000000000000000000000000000000000000..648b74ac783dbbd389dc86d9ed9d45a0ce940af3 Binary files /dev/null and b/example/gun.png differ diff --git a/example/hatsune_miku.png b/example/hatsune_miku.png new file mode 100644 index 0000000000000000000000000000000000000000..2fecf005fdd56a396c4894256fbb98fcc1c4dd8f Binary files /dev/null and b/example/hatsune_miku.png differ diff --git a/example/house2.jpg b/example/house2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2eb8d63a6b91d5b16e729710c8b703aa5c11f9e5 Binary files /dev/null and b/example/house2.jpg differ diff --git a/example/kunkun.png b/example/kunkun.png new file mode 100644 index 0000000000000000000000000000000000000000..5306a2fec4eb879a4ed86e3ed10be8e5d96be677 Binary files /dev/null and b/example/kunkun.png differ diff --git a/example/mushroom_set.png b/example/mushroom_set.png new file mode 100644 index 0000000000000000000000000000000000000000..5ee997e30fc8f192fef9e87c2287e0d63e50186e Binary files /dev/null and b/example/mushroom_set.png differ diff --git a/example/mushroom_teapot.jpg b/example/mushroom_teapot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6c767354305f5467a4c0d5f199eee2a120f4501 Binary files /dev/null and b/example/mushroom_teapot.jpg differ diff --git a/example/owl.png b/example/owl.png new file mode 100644 index 0000000000000000000000000000000000000000..b5653a268dccaf8a6aefb4b67936996f00752d5b Binary files /dev/null and b/example/owl.png differ diff --git a/example/plant.jpg b/example/plant.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3519c1639c3f837d9f1147cba1172e6aaab25a23 Binary files /dev/null and b/example/plant.jpg differ diff --git a/example/pumpkin.png b/example/pumpkin.png new file mode 100644 index 0000000000000000000000000000000000000000..ab8abcd15d6793921c1ee077d144b28be4b8a7e0 Binary files /dev/null and b/example/pumpkin.png differ diff --git a/example/rink.png b/example/rink.png new file mode 100644 index 0000000000000000000000000000000000000000..aea4ceb7fa6c0951926135dab6f61c69534d3a7e Binary files /dev/null and b/example/rink.png differ diff --git a/example/robot.jpg b/example/robot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..929450fba69a20389f39d46cb51d27facc1bba6d Binary files /dev/null and b/example/robot.jpg differ diff --git a/example/sea_turtle.png b/example/sea_turtle.png new file mode 100644 index 0000000000000000000000000000000000000000..27c3e2a9c7d44cb33914422b410ef41cf6591433 Binary files /dev/null and b/example/sea_turtle.png differ diff --git a/example/shelf.png b/example/shelf.png new file mode 100644 index 0000000000000000000000000000000000000000..1d956ccceb62cc0d16d26eca343f558a5378e722 Binary files /dev/null and b/example/shelf.png differ diff --git a/example/shoe.png b/example/shoe.png new file mode 100644 index 0000000000000000000000000000000000000000..5d7a4a688320bf166b4b73158b69417cf8f72657 Binary files /dev/null and b/example/shoe.png differ diff --git a/example/skating_shoe.jpg b/example/skating_shoe.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f21cb1d43e9d42d2836118963fc1d2874523748 Binary files /dev/null and b/example/skating_shoe.jpg differ diff --git a/example/stone.png b/example/stone.png new file mode 100644 index 0000000000000000000000000000000000000000..ec1907cb0b061b730911ba545a2df92378bbe54b Binary files /dev/null and b/example/stone.png differ diff --git a/example/sword.png b/example/sword.png new file mode 100644 index 0000000000000000000000000000000000000000..3068cb9bdbbd9ed3c0a143fd5c741abbc58508e3 Binary files /dev/null and b/example/sword.png differ diff --git a/example/teapot.png b/example/teapot.png new file mode 100644 index 0000000000000000000000000000000000000000..9172f47071ec9b8a95ec1edc3d29c2b8bd94ac43 Binary files /dev/null and b/example/teapot.png differ diff --git a/example/turtle.png b/example/turtle.png new file mode 100644 index 0000000000000000000000000000000000000000..c7797847ec016bb893c7a7c3af4b92b3a77d13bc Binary files /dev/null and b/example/turtle.png differ diff --git a/example/wing.png b/example/wing.png new file mode 100644 index 0000000000000000000000000000000000000000..5ae1de91fac96d53dae60695f0b22ee5d8d38686 Binary files /dev/null and b/example/wing.png differ diff --git a/mvdream/mv_unet.py b/mvdream/mv_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..d18b460caa92e3ac034f91c14792358bf2360538 --- /dev/null +++ b/mvdream/mv_unet.py @@ -0,0 +1,1005 @@ +import math +import numpy as np +from inspect import isfunction +from typing import Optional, Any, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from diffusers.configuration_utils import ConfigMixin +from diffusers.models.modeling_utils import ModelMixin + +# require xformers! +import xformers +import xformers.ops + +from kiui.cam import orbit_camera + +def get_camera( + num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False, +): + angle_gap = azimuth_span / num_frames + cameras = [] + for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): + + pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4] + + # opengl to blender + if blender_coord: + pose[2] *= -1 + pose[[1, 2]] = pose[[2, 1]] + + cameras.append(pose.flatten()) + + if extra_view: + cameras.append(np.zeros_like(cameras[0])) + + return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16] + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + # import pdb; pdb.set_trace() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def default(val, d): + if val is not None: + return val + return d() if isfunction(d) else d + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.ip_dim = ip_dim + self.ip_weight = ip_weight + + if self.ip_dim > 0: + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None): + q = self.to_q(x) + context = default(context, x) + + if self.ip_dim > 0: + # context: [B, 77 + 16(ip), 1024] + token_len = context.shape[1] + context_ip = context[:, -self.ip_dim :, :] + k_ip = self.to_k_ip(context_ip) + v_ip = self.to_v_ip(context_ip) + context = context[:, : (token_len - self.ip_dim), :] + + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + if self.ip_dim > 0: + k_ip, v_ip = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (k_ip, v_ip), + ) + # actually compute the attention, what we cannot get enough of + out_ip = xformers.ops.memory_efficient_attention( + q, k_ip, v_ip, attn_bias=None, op=self.attention_op + ) + out = out + self.ip_weight * out_ip + + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock3D(nn.Module): + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim, + dropout=0.0, + gated_ff=True, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + self.attn1 = MemoryEfficientCrossAttention( + query_dim=dim, + context_dim=None, # self-attention + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = MemoryEfficientCrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + # ip only applies to cross-attention + ip_dim=ip_dim, + ip_weight=ip_weight, + ) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + + def forward(self, x, context=None, num_frames=1): + x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() + x = self.attn1(self.norm1(x), context=None) + x + x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer3D(nn.Module): + + def __init__( + self, + in_channels, + n_heads, + d_head, + context_dim, # cross attention input dim + depth=1, + dropout=0.0, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + if not isinstance(context_dim, list): + context_dim = [context_dim] + + self.in_channels = in_channels + + inner_dim = n_heads * d_head + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock3D( + inner_dim, + n_heads, + d_head, + context_dim=context_dim[d], + dropout=dropout, + ip_dim=ip_dim, + ip_weight=ip_weight, + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + + + def forward(self, x, context=None, num_frames=1): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i], num_frames=num_frames) + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + + return x + x_in + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q, k, v = map( + lambda t: t.reshape(b, t.shape[1], self.heads, -1) + .transpose(1, 2) + .reshape(b, self.heads, t.shape[1], -1) + .contiguous(), + (q, k, v), + ) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * ff_mult, bias=False), + nn.GELU(), + nn.Linear(dim * ff_mult, dim, bias=False), + ) + ] + ) + ) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class CondSequential(nn.Sequential): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, num_frames=1): + for layer in self: + if isinstance(layer, ResBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer3D): + x = layer(x, context, num_frames=num_frames) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class MultiViewUNetModel(ModelMixin, ConfigMixin): + """ + The full multi-view UNet model with attention, timestep embedding and camera embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + :param camera_dim: dimensionality of camera input. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + transformer_depth=1, + context_dim=None, + n_embed=None, + num_attention_blocks=None, + adm_in_channels=None, + camera_dim=None, + ip_dim=0, # imagedream uses ip_dim > 0 + ip_weight=1.0, + **kwargs, + ): + super().__init__() + assert context_dim is not None + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + self.ip_dim = ip_dim + self.ip_weight = ip_weight + + if self.ip_dim > 0: + self.image_embed = Resampler( + dim=context_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=ip_dim, # num token + embedding_dim=1280, + output_dim=context_dim, + ff_mult=4, + ) + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + if camera_dim is not None: + time_embed_dim = model_channels * 4 + self.camera_embed = nn.Sequential( + nn.Linear(camera_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) + elif self.num_classes == "continuous": + # print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + nn.Linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + CondSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers: List[Any] = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if num_attention_blocks is None or nr < num_attention_blocks[level]: + layers.append( + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ) + ) + self.input_blocks.append(CondSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + CondSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + self.middle_block = CondSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if num_attention_blocks is None or i < num_attention_blocks[level]: + layers.append( + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(CondSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + nn.GroupNorm(32, ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + nn.GroupNorm(32, ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward( + self, + x, + timesteps=None, + context=None, + y=None, + camera=None, + num_frames=1, + ip=None, + ip_img=None, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :param num_frames: a integer indicating number of frames for tensor reshaping. + :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). + """ + assert ( + x.shape[0] % num_frames == 0 + ), "input batch size must be dividable by num_frames!" + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y is not None + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # Add camera embeddings + if camera is not None: + emb = emb + self.camera_embed(camera) + + # imagedream variant + if self.ip_dim > 0: + x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9] + ip_emb = self.image_embed(ip) + context = torch.cat((context, ip_emb), 1) + + h = x + for module in self.input_blocks: + h = module(h, emb, context, num_frames=num_frames) + hs.append(h) + h = self.middle_block(h, emb, context, num_frames=num_frames) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, num_frames=num_frames) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) \ No newline at end of file diff --git a/mvdream/pipeline_mvdream.py b/mvdream/pipeline_mvdream.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd8a9fc6a7649951c813c5ab865a4fccd3edb3e --- /dev/null +++ b/mvdream/pipeline_mvdream.py @@ -0,0 +1,559 @@ +import torch +import torch.nn.functional as F +import inspect +import numpy as np +from typing import Callable, List, Optional, Union +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor +from diffusers import AutoencoderKL, DiffusionPipeline +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, +) +from diffusers.configuration_utils import FrozenDict +from diffusers.schedulers import DDIMScheduler +from diffusers.utils.torch_utils import randn_tensor + +from mvdream.mv_unet import MultiViewUNetModel, get_camera + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MVDreamPipeline(DiffusionPipeline): + + _optional_components = ["feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + unet: MultiViewUNetModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + scheduler: DDIMScheduler, + # imagedream variant + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + requires_safety_checker: bool = False, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + unet=unet, + scheduler=scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError( + "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher" + ) + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError( + "`enable_model_offload` requires `accelerate v0.17.0` or higher." + ) + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook( + cpu_offloaded_model, device, prev_module_hook=hook + ) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance: bool, + negative_prompt=None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError( + f"`prompt` should be either a string or a list of strings, but got {type(prompt)}." + ) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if image.dtype == np.float32: + image = (image * 255).astype(np.uint8) + + image = self.feature_extractor(image, return_tensors="pt").pixel_values + image = image.to(device=device, dtype=dtype) + + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + return torch.zeros_like(image_embeds), image_embeds + + def encode_image_latents(self, image, device, num_images_per_prompt): + + dtype = next(self.image_encoder.parameters()).dtype + + image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W] + image = 2 * image - 1 + image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) + image = image.to(dtype=dtype) + + posterior = self.vae.encode(image).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + + return torch.zeros_like(latents), latents + + @torch.no_grad() + def __call__( + self, + prompt: str = "", + image: Optional[np.ndarray] = None, + height: int = 256, + width: int = 256, + elevation: float = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.0, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "numpy", # pil, numpy, latents + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + num_frames: int = 4, + device=torch.device("cuda:0"), + ): + self.unet = self.unet.to(device=device) + self.vae = self.vae.to(device=device) + self.text_encoder = self.text_encoder.to(device=device) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # imagedream variant + if image is not None: + assert isinstance(image, np.ndarray) and image.dtype == np.float32 + self.image_encoder = self.image_encoder.to(device=device) + image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt) + image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt) + + _prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + ) # type: ignore + prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2) + + # Prepare latent variables + actual_num_frames = num_frames if image is None else num_frames + 1 + latents: torch.Tensor = self.prepare_latents( + actual_num_frames * num_images_per_prompt, + 4, + height, + width, + prompt_embeds_pos.dtype, + device, + generator, + None, + ) + + if image is not None: + camera = get_camera(num_frames, elevation=elevation, extra_view=True).to(dtype=latents.dtype, device=device) + else: + camera = get_camera(num_frames, elevation=elevation, extra_view=False).to(dtype=latents.dtype, device=device) + camera = camera.repeat_interleave(num_images_per_prompt, dim=0) + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + multiplier = 2 if do_classifier_free_guidance else 1 + latent_model_input = torch.cat([latents] * multiplier) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + unet_inputs = { + 'x': latent_model_input, + 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device), + 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames), + 'num_frames': actual_num_frames, + 'camera': torch.cat([camera] * multiplier), + } + + if image is not None: + unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames) + unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat + + # predict the noise residual + noise_pred = self.unet.forward(**unet_inputs) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents: torch.Tensor = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # type: ignore + + # Post-processing + if output_type == "latent": + image = latents + elif output_type == "pil": + image = self.decode_latents(latents) + image = self.numpy_to_pil(image) + else: # numpy + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return image \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6123a4646e44628c73552b08bd00757a8630b03a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +torch +numpy +tyro +diffusers +dearpygui +einops +accelerate +gradio +imageio +imageio-ffmpeg +lpips +matplotlib +packaging +Pillow +pygltflib +rembg[gpu,cli] +rich +safetensors +scikit-image +scikit-learn +scipy +tqdm +transformers +trimesh +kiui >= 0.2.3 +xatlas +roma +plyfile diff --git a/zero123plus/model.py b/zero123plus/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0371bd7c071f7272c9fe1ac68c66781e6c4137f3 --- /dev/null +++ b/zero123plus/model.py @@ -0,0 +1,272 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from tqdm import tqdm +from torchvision.transforms import v2 +from torchvision.utils import make_grid, save_image +from einops import rearrange + +from src.utils.train_util import instantiate_from_config +from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDPMScheduler, UNet2DConditionModel +from .pipeline import RefOnlyNoisedUNet + + +def scale_latents(latents): + latents = (latents - 0.22) * 0.75 + return latents + + +def unscale_latents(latents): + latents = latents / 0.75 + 0.22 + return latents + + +def scale_image(image): + image = image * 0.5 / 0.8 + return image + + +def unscale_image(image): + image = image / 0.5 * 0.8 + return image + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +class MVDiffusion(pl.LightningModule): + def __init__( + self, + stable_diffusion_config, + drop_cond_prob=0.1, + ): + super(MVDiffusion, self).__init__() + + self.drop_cond_prob = drop_cond_prob + + self.register_schedule() + + # init modules + pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config) + pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipeline.scheduler.config, timestep_spacing='trailing' + ) + self.pipeline = pipeline + + train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config) + if isinstance(self.pipeline.unet, UNet2DConditionModel): + self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler) + + self.train_scheduler = train_sched # use ddpm scheduler during training + + self.unet = pipeline.unet + + # validation output buffer + self.validation_step_outputs = [] + + def register_schedule(self): + self.num_timesteps = 1000 + + # replace scaled_linear schedule with linear schedule as Zero123++ + beta_start = 0.00085 + beta_end = 0.0120 + betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32) + + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) + + self.register_buffer('betas', betas.float()) + self.register_buffer('alphas_cumprod', alphas_cumprod.float()) + self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float()) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float()) + self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float()) + + self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float()) + self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1).float()) + + def on_fit_start(self): + device = torch.device(f'cuda:{self.global_rank}') + self.pipeline.to(device) + if self.global_rank == 0: + os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) + os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) + + def prepare_batch_data(self, batch): + # prepare stable diffusion input + cond_imgs = batch['cond_imgs'] # (B, C, H, W) + cond_imgs = cond_imgs.to(self.device) + + # random resize the condition image + cond_size = np.random.randint(128, 513) + cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1) + + target_imgs = batch['target_imgs'] # (B, 6, C, H, W) + target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1) + target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2) # (B, C, 3H, 2W) + target_imgs = target_imgs.to(self.device) + + return cond_imgs, target_imgs + + @torch.no_grad() + def forward_vision_encoder(self, images): + dtype = next(self.pipeline.vision_encoder.parameters()).dtype + image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] + image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values + image_pt = image_pt.to(device=self.device, dtype=dtype) + global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds + global_embeds = global_embeds.unsqueeze(-2) + + encoder_hidden_states = self.pipeline._encode_prompt("", self.device, 1, False)[0] + ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1) + encoder_hidden_states = encoder_hidden_states + global_embeds * ramp + + return encoder_hidden_states + + @torch.no_grad() + def encode_condition_image(self, images): + dtype = next(self.pipeline.vae.parameters()).dtype + image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] + image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values + image_pt = image_pt.to(device=self.device, dtype=dtype) + latents = self.pipeline.vae.encode(image_pt).latent_dist.sample() + return latents + + @torch.no_grad() + def encode_target_images(self, images): + dtype = next(self.pipeline.vae.parameters()).dtype + # equals to scaling images to [-1, 1] first and then call scale_image + images = (images - 0.5) / 0.8 # [-0.625, 0.625] + posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist + latents = posterior.sample() * self.pipeline.vae.config.scaling_factor + latents = scale_latents(latents) + return latents + + def forward_unet(self, latents, t, prompt_embeds, cond_latents): + dtype = next(self.pipeline.unet.parameters()).dtype + latents = latents.to(dtype) + prompt_embeds = prompt_embeds.to(dtype) + cond_latents = cond_latents.to(dtype) + cross_attention_kwargs = dict(cond_lat=cond_latents) + pred_noise = self.pipeline.unet( + latents, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + return pred_noise + + def predict_start_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def training_step(self, batch, batch_idx): + # get input + cond_imgs, target_imgs = self.prepare_batch_data(batch) + + # sample random timestep + B = cond_imgs.shape[0] + + t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device) + + # classifier-free guidance + if np.random.rand() < self.drop_cond_prob: + prompt_embeds = self.pipeline._encode_prompt([""]*B, self.device, 1, False) + cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs)) + else: + prompt_embeds = self.forward_vision_encoder(cond_imgs) + cond_latents = self.encode_condition_image(cond_imgs) + + latents = self.encode_target_images(target_imgs) + noise = torch.randn_like(latents) + latents_noisy = self.train_scheduler.add_noise(latents, noise, t) + + v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents) + v_target = self.get_v(latents, noise, t) + + loss, loss_dict = self.compute_loss(v_pred, v_target) + + # logging + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.global_step % 500 == 0 and self.global_rank == 0: + with torch.no_grad(): + latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred) + + latents = unscale_latents(latents_pred) + images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] + images = (images * 0.5 + 0.5).clamp(0, 1) + images = torch.cat([target_imgs, images], dim=-2) + + grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1)) + save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')) + + return loss + + def compute_loss(self, noise_pred, noise_gt): + loss = F.mse_loss(noise_pred, noise_gt) + + prefix = 'train' + loss_dict = {} + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # get input + cond_imgs, target_imgs = self.prepare_batch_data(batch) + + images_pil = [v2.functional.to_pil_image(cond_imgs[i]) for i in range(cond_imgs.shape[0])] + + outputs = [] + for cond_img in images_pil: + latent = self.pipeline(cond_img, num_inference_steps=75, output_type='latent').images + image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] + image = (image * 0.5 + 0.5).clamp(0, 1) + outputs.append(image) + outputs = torch.cat(outputs, dim=0).to(self.device) + images = torch.cat([target_imgs, outputs], dim=-2) + + self.validation_step_outputs.append(images) + + @torch.no_grad() + def on_validation_epoch_end(self): + images = torch.cat(self.validation_step_outputs, dim=0) + + all_images = self.all_gather(images) + all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') + + if self.global_rank == 0: + grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1)) + save_image(grid, os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')) + + self.validation_step_outputs.clear() # free memory + + def configure_optimizers(self): + lr = self.learning_rate + + optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4) + + return {'optimizer': optimizer, 'lr_scheduler': scheduler} diff --git a/zero123plus/pipeline.py b/zero123plus/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..efd52fd12c784a42861123b69d49c313df7bd9a7 --- /dev/null +++ b/zero123plus/pipeline.py @@ -0,0 +1,406 @@ +from typing import Any, Dict, Optional +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers + +import numpy +import torch +import torch.nn as nn +import torch.utils.checkpoint +import torch.distributed +import transformers +from collections import OrderedDict +from PIL import Image +from torchvision import transforms +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + EulerAncestralDiscreteScheduler, + UNet2DConditionModel, + ImagePipelineOutput +) +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0 +from diffusers.utils.import_utils import is_xformers_available + + +def to_rgb_image(maybe_rgba: Image.Image): + if maybe_rgba.mode == 'RGB': + return maybe_rgba + elif maybe_rgba.mode == 'RGBA': + rgba = maybe_rgba + img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8) + img = Image.fromarray(img, 'RGB') + img.paste(rgba, mask=rgba.getchannel('A')) + return img + else: + raise ValueError("Unsupported image type.", maybe_rgba.mode) + + +class ReferenceOnlyAttnProc(torch.nn.Module): + def __init__( + self, + chained_proc, + enabled=False, + name=None + ) -> None: + super().__init__() + self.enabled = enabled + self.chained_proc = chained_proc + self.name = name + + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, + mode="w", ref_dict: dict = None, is_cfg_guidance = False + ) -> Any: + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + if self.enabled and is_cfg_guidance: + res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask) + hidden_states = hidden_states[1:] + encoder_hidden_states = encoder_hidden_states[1:] + if self.enabled: + if mode == 'w': + ref_dict[self.name] = encoder_hidden_states + elif mode == 'r': + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1) + elif mode == 'm': + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1) + else: + assert False, mode + res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask) + if self.enabled and is_cfg_guidance: + res = torch.cat([res0, res]) + return res + + +class RefOnlyNoisedUNet(torch.nn.Module): + def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None: + super().__init__() + self.unet = unet + self.train_sched = train_sched + self.val_sched = val_sched + + unet_lora_attn_procs = dict() + for name, _ in unet.attn_processors.items(): + if torch.__version__ >= '2.0': + default_attn_proc = AttnProcessor2_0() + elif is_xformers_available(): + default_attn_proc = XFormersAttnProcessor() + else: + default_attn_proc = AttnProcessor() + unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( + default_attn_proc, enabled=name.endswith("attn1.processor"), name=name + ) + unet.set_attn_processor(unet_lora_attn_procs) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs): + if is_cfg_guidance: + encoder_hidden_states = encoder_hidden_states[1:] + class_labels = class_labels[1:] + self.unet( + noisy_cond_lat, timestep, + encoder_hidden_states=encoder_hidden_states, + class_labels=class_labels, + cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict), + **kwargs + ) + + def forward( + self, sample, timestep, encoder_hidden_states, class_labels=None, + *args, cross_attention_kwargs, + down_block_res_samples=None, mid_block_res_sample=None, + **kwargs + ): + cond_lat = cross_attention_kwargs['cond_lat'] + is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False) + noise = torch.randn_like(cond_lat) + if self.training: + noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep) + noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep) + else: + noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1)) + noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) + ref_dict = {} + self.forward_cond( + noisy_cond_lat, timestep, + encoder_hidden_states, class_labels, + ref_dict, is_cfg_guidance, **kwargs + ) + weight_dtype = self.unet.dtype + return self.unet( + sample, timestep, + encoder_hidden_states, *args, + class_labels=class_labels, + cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance), + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ] if down_block_res_samples is not None else None, + mid_block_additional_residual=( + mid_block_res_sample.to(dtype=weight_dtype) + if mid_block_res_sample is not None else None + ), + **kwargs + ) + + +def scale_latents(latents): + latents = (latents - 0.22) * 0.75 + return latents + + +def unscale_latents(latents): + latents = latents / 0.75 + 0.22 + return latents + + +def scale_image(image): + image = image * 0.5 / 0.8 + return image + + +def unscale_image(image): + image = image / 0.5 * 0.8 + return image + + +class DepthControlUNet(torch.nn.Module): + def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None: + super().__init__() + self.unet = unet + if controlnet is None: + self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet) + else: + self.controlnet = controlnet + DefaultAttnProc = AttnProcessor2_0 + if is_xformers_available(): + DefaultAttnProc = XFormersAttnProcessor + self.controlnet.set_attn_processor(DefaultAttnProc()) + self.conditioning_scale = conditioning_scale + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs): + cross_attention_kwargs = dict(cross_attention_kwargs) + control_depth = cross_attention_kwargs.pop('control_depth') + down_block_res_samples, mid_block_res_sample = self.controlnet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=control_depth, + conditioning_scale=self.conditioning_scale, + return_dict=False, + ) + return self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_res_samples=down_block_res_samples, + mid_block_res_sample=mid_block_res_sample, + cross_attention_kwargs=cross_attention_kwargs + ) + + +class ModuleListDict(torch.nn.Module): + def __init__(self, procs: dict) -> None: + super().__init__() + self.keys = sorted(procs.keys()) + self.values = torch.nn.ModuleList(procs[k] for k in self.keys) + + def __getitem__(self, key): + return self.values[self.keys.index(key)] + + +class SuperNet(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys())) + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = dict(enumerate(state_dict.keys())) + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # .processor for unet, .self_attn for text encoder + self.split_keys = [".processor", ".self_attn"] + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def remap_key(key, state_dict): + for k in self.split_keys: + if k in key: + return key.split(k)[0] + k + return key.split('.')[0] + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = remap_key(key, state_dict) + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +class Zero123PlusPipeline(diffusers.StableDiffusionPipeline): + tokenizer: transformers.CLIPTokenizer + text_encoder: transformers.CLIPTextModel + vision_encoder: transformers.CLIPVisionModelWithProjection + + feature_extractor_clip: transformers.CLIPImageProcessor + unet: UNet2DConditionModel + scheduler: diffusers.schedulers.KarrasDiffusionSchedulers + + vae: AutoencoderKL + ramping: nn.Linear + + feature_extractor_vae: transformers.CLIPImageProcessor + + depth_transforms_multi = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]) + ]) + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vision_encoder: transformers.CLIPVisionModelWithProjection, + feature_extractor_clip: CLIPImageProcessor, + feature_extractor_vae: CLIPImageProcessor, + ramping_coefficients: Optional[list] = None, + safety_checker=None, + ): + DiffusionPipeline.__init__(self) + + self.register_modules( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + unet=unet, scheduler=scheduler, safety_checker=None, + vision_encoder=vision_encoder, + feature_extractor_clip=feature_extractor_clip, + feature_extractor_vae=feature_extractor_vae + ) + self.register_to_config(ramping_coefficients=ramping_coefficients) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def prepare(self): + train_sched = DDPMScheduler.from_config(self.scheduler.config) + if isinstance(self.unet, UNet2DConditionModel): + self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval() + + def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0): + self.prepare() + self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale) + return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)])) + + def encode_condition_image(self, image: torch.Tensor): + image = self.vae.encode(image).latent_dist.sample() + return image + + @torch.no_grad() + def __call__( + self, + image: Image.Image = None, + prompt = "", + *args, + num_images_per_prompt: Optional[int] = 1, + guidance_scale=4.0, + depth_image: Image.Image = None, + output_type: Optional[str] = "pil", + width=640, + height=960, + num_inference_steps=28, + return_dict=True, + **kwargs + ): + self.prepare() + if image is None: + raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.") + assert not isinstance(image, torch.Tensor) + image = to_rgb_image(image) + image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values + image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values + if depth_image is not None and hasattr(self.unet, "controlnet"): + depth_image = to_rgb_image(depth_image) + depth_image = self.depth_transforms_multi(depth_image).to( + device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype + ) + image = image_1.to(device=self.vae.device, dtype=self.vae.dtype) + image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype) + cond_lat = self.encode_condition_image(image) + if guidance_scale > 1: + negative_lat = self.encode_condition_image(torch.zeros_like(image)) + cond_lat = torch.cat([negative_lat, cond_lat]) + encoded = self.vision_encoder(image_2, output_hidden_states=False) + global_embeds = encoded.image_embeds + global_embeds = global_embeds.unsqueeze(-2) + + if hasattr(self, "encode_prompt"): + encoder_hidden_states = self.encode_prompt( + prompt, + self.device, + num_images_per_prompt, + False + )[0] + else: + encoder_hidden_states = self._encode_prompt( + prompt, + self.device, + num_images_per_prompt, + False + ) + ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1) + encoder_hidden_states = encoder_hidden_states + global_embeds * ramp + cak = dict(cond_lat=cond_lat) + if hasattr(self.unet, "controlnet"): + cak['control_depth'] = depth_image + latents: torch.Tensor = super().__call__( + None, + *args, + cross_attention_kwargs=cak, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=encoder_hidden_states, + num_inference_steps=num_inference_steps, + output_type='latent', + width=width, + height=height, + **kwargs + ).images + latents = unscale_latents(latents) + if not output_type == "latent": + image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]) + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image)