JiantaoLin
commited on
Commit
•
2fe3da0
1
Parent(s):
e2cc5f8
new
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- README copy.md +111 -0
- app.py +499 -0
- configs/PRM.yaml +71 -0
- configs/PRM_inference.yaml +22 -0
- light2map.py +95 -0
- obj2mesh.py +121 -0
- requirements.txt +21 -0
- run.py +355 -0
- run.sh +7 -0
- run_hpc.sh +16 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__pycache__/objaverse.cpython-310.pyc +0 -0
- src/data/bsdf_256_256.bin +0 -0
- src/data/objaverse.py +509 -0
- src/model_mesh.py +642 -0
- src/models/__init__.py +0 -0
- src/models/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/__pycache__/lrm_mesh.cpython-310.pyc +0 -0
- src/models/decoder/__init__.py +0 -0
- src/models/decoder/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/decoder/__pycache__/transformer.cpython-310.pyc +0 -0
- src/models/decoder/transformer.py +123 -0
- src/models/encoder/__init__.py +0 -0
- src/models/encoder/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/encoder/__pycache__/dino.cpython-310.pyc +0 -0
- src/models/encoder/__pycache__/dino_wrapper.cpython-310.pyc +0 -0
- src/models/encoder/dino.py +550 -0
- src/models/encoder/dino_wrapper.py +80 -0
- src/models/geometry/__init__.py +7 -0
- src/models/geometry/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/geometry/camera/__init__.py +16 -0
- src/models/geometry/camera/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/geometry/camera/__pycache__/perspective_camera.cpython-310.pyc +0 -0
- src/models/geometry/camera/perspective_camera.py +35 -0
- src/models/geometry/render/__init__.py +8 -0
- src/models/geometry/render/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/geometry/render/__pycache__/neural_render.cpython-310.pyc +0 -0
- src/models/geometry/render/__pycache__/util.cpython-310.pyc +0 -0
- src/models/geometry/render/neural_render.py +293 -0
- src/models/geometry/render/renderutils/__init__.py +11 -0
- src/models/geometry/render/renderutils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/geometry/render/renderutils/__pycache__/bsdf.cpython-310.pyc +0 -0
- src/models/geometry/render/renderutils/__pycache__/loss.cpython-310.pyc +0 -0
- src/models/geometry/render/renderutils/__pycache__/ops.cpython-310.pyc +0 -0
- src/models/geometry/render/renderutils/bsdf.py +151 -0
- src/models/geometry/render/renderutils/c_src/bsdf.cu +710 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
README copy.md
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
|
5 |
+
# PRM: Photometric Stereo based Large Reconstruction Model
|
6 |
+
|
7 |
+
<a href="https://tau-yihouxiang.github.io/projects/X-Ray/X-Ray.html"><img src="https://img.shields.io/badge/Project_Page-Online-EA3A97"></a>
|
8 |
+
<a href="https://arxiv.org/abs/2404.07191"><img src="https://img.shields.io/badge/ArXiv-2404.07191-brightgreen"></a>
|
9 |
+
<a href="https://huggingface.co/LTT/PRM"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a> <br>
|
10 |
+
<a href="https://huggingface.co/spaces/TencentARC/InstantMesh"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Gradio%20Demo-Huggingface-orange"></a>
|
11 |
+
<a href="https://github.com/jtydhr88/ComfyUI-InstantMesh"><img src="https://img.shields.io/badge/Demo-ComfyUI-8A2BE2"></a>
|
12 |
+
|
13 |
+
</div>
|
14 |
+
|
15 |
+
---
|
16 |
+
|
17 |
+
An official implementation of PRM, a feed-forward framework for high-quality 3D mesh generation with photometric stereo images.
|
18 |
+
|
19 |
+
|
20 |
+
![image](https://github.com/g3956/PRM/blob/main/assets/teaser.png)
|
21 |
+
|
22 |
+
# 🚩 Features
|
23 |
+
- [x] Release inference and training code.
|
24 |
+
- [x] Release model weights.
|
25 |
+
- [x] Release huggingface gradio demo. Please try it at [demo](https://huggingface.co/spaces/TencentARC/InstantMesh) link.
|
26 |
+
- [x] Release ComfyUI demo.
|
27 |
+
|
28 |
+
# ⚙️ Dependencies and Installation
|
29 |
+
|
30 |
+
We recommend using `Python>=3.10`, `PyTorch>=2.1.0`, and `CUDA>=12.1`.
|
31 |
+
```bash
|
32 |
+
conda create --name PRM python=3.10
|
33 |
+
conda activate PRM
|
34 |
+
pip install -U pip
|
35 |
+
|
36 |
+
# Ensure Ninja is installed
|
37 |
+
conda install Ninja
|
38 |
+
|
39 |
+
# Install the correct version of CUDA
|
40 |
+
conda install cuda -c nvidia/label/cuda-12.1.0
|
41 |
+
|
42 |
+
# Install PyTorch and xformers
|
43 |
+
# You may need to install another xformers version if you use a different PyTorch version
|
44 |
+
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
|
45 |
+
pip install xformers==0.0.22.post7
|
46 |
+
|
47 |
+
# Install Triton
|
48 |
+
pip install triton
|
49 |
+
|
50 |
+
# Install other requirements
|
51 |
+
pip install -r requirements.txt
|
52 |
+
```
|
53 |
+
|
54 |
+
# 💫 Inference
|
55 |
+
|
56 |
+
## Download the pretrained model
|
57 |
+
|
58 |
+
The pretrained model can be found [model card](https://huggingface.co/LTT/PRM).
|
59 |
+
|
60 |
+
Our inference script will download the models automatically. Alternatively, you can manually download the models and put them under the `ckpts/` directory.
|
61 |
+
|
62 |
+
# 💻 Training
|
63 |
+
|
64 |
+
We provide our training code to facilitate future research.
|
65 |
+
For training data, we used filtered Objaverse for training. Before training, you need to pre-processe the environment maps and GLB files into formats that fit our dataloader.
|
66 |
+
For preprocessing GLB files, please run
|
67 |
+
```bash
|
68 |
+
# GLB files to OBJ files
|
69 |
+
python train.py --base configs/instant-mesh-large-train.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
|
70 |
+
```
|
71 |
+
then
|
72 |
+
```bash
|
73 |
+
# OBJ files to mesh files that can be readed
|
74 |
+
python obj2mesh.py path_to_obj save_path
|
75 |
+
```
|
76 |
+
For preprocessing environment maps, please run
|
77 |
+
```bash
|
78 |
+
# Pre-process environment maps
|
79 |
+
python light2map.py path_to_env save_path
|
80 |
+
```
|
81 |
+
|
82 |
+
|
83 |
+
To train the sparse-view reconstruction models, please run:
|
84 |
+
```bash
|
85 |
+
# Training on Mesh representation
|
86 |
+
python train.py --base configs/PRM.yaml --gpus 0,1,2,3,4,5,6,7 --num_nodes 1
|
87 |
+
```
|
88 |
+
Note that you need to change to root_dir and light_dir to pathes that you save the preprocessed GLB files and environment maps.
|
89 |
+
|
90 |
+
# :books: Citation
|
91 |
+
|
92 |
+
If you find our work useful for your research or applications, please cite using this BibTeX:
|
93 |
+
|
94 |
+
```BibTeX
|
95 |
+
@article{xu2024instantmesh,
|
96 |
+
title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
|
97 |
+
author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
|
98 |
+
journal={arXiv preprint arXiv:2404.07191},
|
99 |
+
year={2024}
|
100 |
+
}
|
101 |
+
```
|
102 |
+
|
103 |
+
# 🤗 Acknowledgements
|
104 |
+
|
105 |
+
We thank the authors of the following projects for their excellent contributions to 3D generative AI!
|
106 |
+
|
107 |
+
- [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes)
|
108 |
+
- [InstantMesh]([https://instant-3d.github.io/](https://github.com/TencentARC/InstantMesh))
|
109 |
+
|
110 |
+
|
111 |
+
|
app.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imageio
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import rembg
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision.transforms import v2
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from omegaconf import OmegaConf
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from tqdm import tqdm
|
12 |
+
import glm
|
13 |
+
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
14 |
+
|
15 |
+
from src.data.objaverse import load_mipmap
|
16 |
+
from src.utils import render_utils
|
17 |
+
from src.utils.train_util import instantiate_from_config
|
18 |
+
from src.utils.camera_util import (
|
19 |
+
FOV_to_intrinsics,
|
20 |
+
get_zero123plus_input_cameras,
|
21 |
+
get_circular_camera_poses,
|
22 |
+
)
|
23 |
+
from src.utils.mesh_util import save_obj, save_glb
|
24 |
+
from src.utils.infer_util import remove_background, resize_foreground, images_to_video
|
25 |
+
|
26 |
+
import tempfile
|
27 |
+
from huggingface_hub import hf_hub_download
|
28 |
+
|
29 |
+
|
30 |
+
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
|
31 |
+
device0 = torch.device('cuda:0')
|
32 |
+
device1 = torch.device('cuda:0')
|
33 |
+
else:
|
34 |
+
device0 = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
35 |
+
device1 = device0
|
36 |
+
|
37 |
+
# Define the cache directory for model files
|
38 |
+
model_cache_dir = './ckpts/'
|
39 |
+
os.makedirs(model_cache_dir, exist_ok=True)
|
40 |
+
|
41 |
+
def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False, fov=50):
|
42 |
+
"""
|
43 |
+
Get the rendering camera parameters.
|
44 |
+
"""
|
45 |
+
train_res = [512, 512]
|
46 |
+
cam_near_far = [0.1, 1000.0]
|
47 |
+
fovy = np.deg2rad(fov)
|
48 |
+
proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])
|
49 |
+
all_mv = []
|
50 |
+
all_mvp = []
|
51 |
+
all_campos = []
|
52 |
+
if isinstance(elevation, tuple):
|
53 |
+
elevation_0 = np.deg2rad(elevation[0])
|
54 |
+
elevation_1 = np.deg2rad(elevation[1])
|
55 |
+
for i in range(M//2):
|
56 |
+
azimuth = 2 * np.pi * i / (M // 2)
|
57 |
+
z = radius * np.cos(azimuth) * np.sin(elevation_0)
|
58 |
+
x = radius * np.sin(azimuth) * np.sin(elevation_0)
|
59 |
+
y = radius * np.cos(elevation_0)
|
60 |
+
|
61 |
+
eye = glm.vec3(x, y, z)
|
62 |
+
at = glm.vec3(0.0, 0.0, 0.0)
|
63 |
+
up = glm.vec3(0.0, 1.0, 0.0)
|
64 |
+
view_matrix = glm.lookAt(eye, at, up)
|
65 |
+
mv = torch.from_numpy(np.array(view_matrix))
|
66 |
+
mvp = proj_mtx @ (mv) #w2c
|
67 |
+
campos = torch.linalg.inv(mv)[:3, 3]
|
68 |
+
all_mv.append(mv[None, ...].cuda())
|
69 |
+
all_mvp.append(mvp[None, ...].cuda())
|
70 |
+
all_campos.append(campos[None, ...].cuda())
|
71 |
+
for i in range(M//2):
|
72 |
+
azimuth = 2 * np.pi * i / (M // 2)
|
73 |
+
z = radius * np.cos(azimuth) * np.sin(elevation_1)
|
74 |
+
x = radius * np.sin(azimuth) * np.sin(elevation_1)
|
75 |
+
y = radius * np.cos(elevation_1)
|
76 |
+
|
77 |
+
eye = glm.vec3(x, y, z)
|
78 |
+
at = glm.vec3(0.0, 0.0, 0.0)
|
79 |
+
up = glm.vec3(0.0, 1.0, 0.0)
|
80 |
+
view_matrix = glm.lookAt(eye, at, up)
|
81 |
+
mv = torch.from_numpy(np.array(view_matrix))
|
82 |
+
mvp = proj_mtx @ (mv) #w2c
|
83 |
+
campos = torch.linalg.inv(mv)[:3, 3]
|
84 |
+
all_mv.append(mv[None, ...].cuda())
|
85 |
+
all_mvp.append(mvp[None, ...].cuda())
|
86 |
+
all_campos.append(campos[None, ...].cuda())
|
87 |
+
else:
|
88 |
+
# elevation = 90 - elevation
|
89 |
+
for i in range(M):
|
90 |
+
azimuth = 2 * np.pi * i / M
|
91 |
+
z = radius * np.cos(azimuth) * np.sin(elevation)
|
92 |
+
x = radius * np.sin(azimuth) * np.sin(elevation)
|
93 |
+
y = radius * np.cos(elevation)
|
94 |
+
|
95 |
+
eye = glm.vec3(x, y, z)
|
96 |
+
at = glm.vec3(0.0, 0.0, 0.0)
|
97 |
+
up = glm.vec3(0.0, 1.0, 0.0)
|
98 |
+
view_matrix = glm.lookAt(eye, at, up)
|
99 |
+
mv = torch.from_numpy(np.array(view_matrix))
|
100 |
+
mvp = proj_mtx @ (mv) #w2c
|
101 |
+
campos = torch.linalg.inv(mv)[:3, 3]
|
102 |
+
all_mv.append(mv[None, ...].cuda())
|
103 |
+
all_mvp.append(mvp[None, ...].cuda())
|
104 |
+
all_campos.append(campos[None, ...].cuda())
|
105 |
+
all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2)
|
106 |
+
all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2)
|
107 |
+
all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2)
|
108 |
+
return all_mv, all_mvp, all_campos
|
109 |
+
|
110 |
+
|
111 |
+
def render_frames(model, planes, render_cameras, camera_pos, env, materials, render_size=512, chunk_size=1, is_flexicubes=False):
|
112 |
+
"""
|
113 |
+
Render frames from triplanes.
|
114 |
+
"""
|
115 |
+
frames = []
|
116 |
+
albedos = []
|
117 |
+
pbr_spec_lights = []
|
118 |
+
pbr_diffuse_lights = []
|
119 |
+
normals = []
|
120 |
+
alphas = []
|
121 |
+
for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
|
122 |
+
if is_flexicubes:
|
123 |
+
out = model.forward_geometry(
|
124 |
+
planes,
|
125 |
+
render_cameras[:, i:i+chunk_size],
|
126 |
+
camera_pos[:, i:i+chunk_size],
|
127 |
+
[[env]*chunk_size],
|
128 |
+
[[materials]*chunk_size],
|
129 |
+
render_size=render_size,
|
130 |
+
)
|
131 |
+
frame = out['pbr_img']
|
132 |
+
albedo = out['albedo']
|
133 |
+
pbr_spec_light = out['pbr_spec_light']
|
134 |
+
pbr_diffuse_light = out['pbr_diffuse_light']
|
135 |
+
normal = out['normal']
|
136 |
+
alpha = out['mask']
|
137 |
+
else:
|
138 |
+
frame = model.forward_synthesizer(
|
139 |
+
planes,
|
140 |
+
render_cameras[i],
|
141 |
+
render_size=render_size,
|
142 |
+
)['images_rgb']
|
143 |
+
frames.append(frame)
|
144 |
+
albedos.append(albedo)
|
145 |
+
pbr_spec_lights.append(pbr_spec_light)
|
146 |
+
pbr_diffuse_lights.append(pbr_diffuse_light)
|
147 |
+
normals.append(normal)
|
148 |
+
alphas.append(alpha)
|
149 |
+
|
150 |
+
frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1
|
151 |
+
alphas = torch.cat(alphas, dim=1)[0]
|
152 |
+
albedos = torch.cat(albedos, dim=1)[0]
|
153 |
+
pbr_spec_lights = torch.cat(pbr_spec_lights, dim=1)[0]
|
154 |
+
pbr_diffuse_lights = torch.cat(pbr_diffuse_lights, dim=1)[0]
|
155 |
+
normals = torch.cat(normals, dim=0).permute(0,3,1,2)[:,:3]
|
156 |
+
return frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
def images_to_video(images, output_path, fps=30):
|
161 |
+
# images: (N, C, H, W)
|
162 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
163 |
+
frames = []
|
164 |
+
for i in range(images.shape[0]):
|
165 |
+
frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
|
166 |
+
assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
|
167 |
+
f"Frame shape mismatch: {frame.shape} vs {images.shape}"
|
168 |
+
assert frame.min() >= 0 and frame.max() <= 255, \
|
169 |
+
f"Frame value out of range: {frame.min()} ~ {frame.max()}"
|
170 |
+
frames.append(frame)
|
171 |
+
imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
|
172 |
+
|
173 |
+
|
174 |
+
###############################################################################
|
175 |
+
# Configuration.
|
176 |
+
###############################################################################
|
177 |
+
|
178 |
+
seed_everything(0)
|
179 |
+
|
180 |
+
config_path = 'configs/PRM_inference.yaml'
|
181 |
+
config = OmegaConf.load(config_path)
|
182 |
+
config_name = os.path.basename(config_path).replace('.yaml', '')
|
183 |
+
model_config = config.model_config
|
184 |
+
infer_config = config.infer_config
|
185 |
+
|
186 |
+
IS_FLEXICUBES = True
|
187 |
+
|
188 |
+
device = torch.device('cuda')
|
189 |
+
|
190 |
+
# load diffusion model
|
191 |
+
print('Loading diffusion model ...')
|
192 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
193 |
+
"sudo-ai/zero123plus-v1.2",
|
194 |
+
custom_pipeline="zero123plus",
|
195 |
+
torch_dtype=torch.float16,
|
196 |
+
cache_dir=model_cache_dir
|
197 |
+
)
|
198 |
+
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
199 |
+
pipeline.scheduler.config, timestep_spacing='trailing'
|
200 |
+
)
|
201 |
+
|
202 |
+
# load custom white-background UNet
|
203 |
+
print('Loading custom white-background unet ...')
|
204 |
+
if os.path.exists(infer_config.unet_path):
|
205 |
+
unet_ckpt_path = infer_config.unet_path
|
206 |
+
else:
|
207 |
+
unet_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="diffusion_pytorch_model.bin", repo_type="model")
|
208 |
+
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
209 |
+
pipeline.unet.load_state_dict(state_dict, strict=True)
|
210 |
+
|
211 |
+
pipeline = pipeline.to(device)
|
212 |
+
|
213 |
+
# load reconstruction model
|
214 |
+
print('Loading reconstruction model ...')
|
215 |
+
model = instantiate_from_config(model_config)
|
216 |
+
if os.path.exists(infer_config.model_path):
|
217 |
+
model_ckpt_path = infer_config.model_path
|
218 |
+
else:
|
219 |
+
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
220 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
221 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
222 |
+
model.load_state_dict(state_dict, strict=True)
|
223 |
+
|
224 |
+
model = model.to(device1)
|
225 |
+
if IS_FLEXICUBES:
|
226 |
+
model.init_flexicubes_geometry(device1, fovy=30.0)
|
227 |
+
model = model.eval()
|
228 |
+
|
229 |
+
print('Loading Finished!')
|
230 |
+
|
231 |
+
|
232 |
+
def check_input_image(input_image):
|
233 |
+
if input_image is None:
|
234 |
+
raise gr.Error("No image uploaded!")
|
235 |
+
|
236 |
+
|
237 |
+
def preprocess(input_image, do_remove_background):
|
238 |
+
|
239 |
+
rembg_session = rembg.new_session() if do_remove_background else None
|
240 |
+
if do_remove_background:
|
241 |
+
input_image = remove_background(input_image, rembg_session)
|
242 |
+
input_image = resize_foreground(input_image, 0.85)
|
243 |
+
|
244 |
+
return input_image
|
245 |
+
|
246 |
+
|
247 |
+
def generate_mvs(input_image, sample_steps, sample_seed):
|
248 |
+
|
249 |
+
seed_everything(sample_seed)
|
250 |
+
|
251 |
+
# sampling
|
252 |
+
generator = torch.Generator(device=device0)
|
253 |
+
z123_image = pipeline(
|
254 |
+
input_image,
|
255 |
+
num_inference_steps=sample_steps,
|
256 |
+
generator=generator,
|
257 |
+
).images[0]
|
258 |
+
|
259 |
+
show_image = np.asarray(z123_image, dtype=np.uint8)
|
260 |
+
show_image = torch.from_numpy(show_image) # (960, 640, 3)
|
261 |
+
show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
|
262 |
+
show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
|
263 |
+
show_image = Image.fromarray(show_image.numpy())
|
264 |
+
|
265 |
+
return z123_image, show_image
|
266 |
+
|
267 |
+
|
268 |
+
def make_mesh(mesh_fpath, planes):
|
269 |
+
|
270 |
+
mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
|
271 |
+
mesh_dirname = os.path.dirname(mesh_fpath)
|
272 |
+
mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
|
273 |
+
|
274 |
+
with torch.no_grad():
|
275 |
+
# get mesh
|
276 |
+
|
277 |
+
mesh_out = model.extract_mesh(
|
278 |
+
planes,
|
279 |
+
use_texture_map=False,
|
280 |
+
**infer_config,
|
281 |
+
)
|
282 |
+
|
283 |
+
vertices, faces, vertex_colors = mesh_out
|
284 |
+
vertices = vertices[:, [1, 2, 0]]
|
285 |
+
|
286 |
+
save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
|
287 |
+
save_obj(vertices, faces, vertex_colors, mesh_fpath)
|
288 |
+
|
289 |
+
print(f"Mesh saved to {mesh_fpath}")
|
290 |
+
|
291 |
+
return mesh_fpath, mesh_glb_fpath
|
292 |
+
|
293 |
+
|
294 |
+
def make3d(images):
|
295 |
+
|
296 |
+
images = np.asarray(images, dtype=np.float32) / 255.0
|
297 |
+
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
|
298 |
+
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
|
299 |
+
|
300 |
+
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=3.2, fov=30).to(device).to(device1)
|
301 |
+
all_mv, all_mvp, all_campos = get_render_cameras(
|
302 |
+
batch_size=1,
|
303 |
+
M=240,
|
304 |
+
radius=4.5,
|
305 |
+
elevation=(90, 60.0),
|
306 |
+
is_flexicubes=IS_FLEXICUBES,
|
307 |
+
fov=30
|
308 |
+
)
|
309 |
+
|
310 |
+
images = images.unsqueeze(0).to(device1)
|
311 |
+
images = v2.functional.resize(images, (512, 512), interpolation=3, antialias=True).clamp(0, 1)
|
312 |
+
|
313 |
+
mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
|
314 |
+
print(mesh_fpath)
|
315 |
+
mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
|
316 |
+
mesh_dirname = os.path.dirname(mesh_fpath)
|
317 |
+
video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
|
318 |
+
ENV = load_mipmap("env_mipmap/6")
|
319 |
+
materials = (0.0,0.9)
|
320 |
+
with torch.no_grad():
|
321 |
+
# get triplane
|
322 |
+
planes = model.forward_planes(images, input_cameras)
|
323 |
+
|
324 |
+
# get video
|
325 |
+
chunk_size = 20 if IS_FLEXICUBES else 1
|
326 |
+
render_size = 512
|
327 |
+
|
328 |
+
frames = []
|
329 |
+
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
330 |
+
model,
|
331 |
+
planes,
|
332 |
+
render_cameras=all_mvp,
|
333 |
+
camera_pos=all_campos,
|
334 |
+
env=ENV,
|
335 |
+
materials=materials,
|
336 |
+
render_size=render_size,
|
337 |
+
chunk_size=chunk_size,
|
338 |
+
is_flexicubes=IS_FLEXICUBES,
|
339 |
+
)
|
340 |
+
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
341 |
+
normals = normals * alphas + (1-alphas)
|
342 |
+
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
343 |
+
|
344 |
+
images_to_video(
|
345 |
+
all_frames,
|
346 |
+
video_fpath,
|
347 |
+
fps=30,
|
348 |
+
)
|
349 |
+
|
350 |
+
print(f"Video saved to {video_fpath}")
|
351 |
+
|
352 |
+
mesh_fpath, mesh_glb_fpath = make_mesh(mesh_fpath, planes)
|
353 |
+
|
354 |
+
return video_fpath, mesh_fpath, mesh_glb_fpath
|
355 |
+
|
356 |
+
|
357 |
+
import gradio as gr
|
358 |
+
|
359 |
+
_HEADER_ = '''
|
360 |
+
<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/g3956/PRM' target='_blank'><b>PRM: Photometric Stereo based Large Reconstruction Model</b></a></h2>
|
361 |
+
|
362 |
+
**PRM** is a feed-forward framework for high-quality 3D mesh generation with fine-grained local details from a single image.
|
363 |
+
|
364 |
+
Code: <a href='https://github.com/g3956/PRM' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a>.
|
365 |
+
'''
|
366 |
+
|
367 |
+
_CITE_ = r"""
|
368 |
+
If PRM is helpful, please help to ⭐ the <a href='https://github.com/g3956/PRM' target='_blank'>Github Repo</a>. Thanks!
|
369 |
+
---
|
370 |
+
📝 **Citation**
|
371 |
+
|
372 |
+
If you find our work useful for your research or applications, please cite using this bibtex:
|
373 |
+
```bibtex
|
374 |
+
@article{xu2024instantmesh,
|
375 |
+
title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
|
376 |
+
author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
|
377 |
+
journal={arXiv preprint arXiv:2404.07191},
|
378 |
+
year={2024}
|
379 |
+
}
|
380 |
+
```
|
381 |
+
|
382 |
+
📋 **License**
|
383 |
+
|
384 |
+
Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
|
385 |
+
|
386 |
+
📧 **Contact**
|
387 |
+
|
388 |
+
If you have any questions, feel free to open a discussion or contact us at <b>jlin695@connect.hkust-gz.edu.cn</b>.
|
389 |
+
"""
|
390 |
+
|
391 |
+
with gr.Blocks() as demo:
|
392 |
+
gr.Markdown(_HEADER_)
|
393 |
+
with gr.Row(variant="panel"):
|
394 |
+
with gr.Column():
|
395 |
+
with gr.Row():
|
396 |
+
input_image = gr.Image(
|
397 |
+
label="Input Image",
|
398 |
+
image_mode="RGBA",
|
399 |
+
sources="upload",
|
400 |
+
width=256,
|
401 |
+
height=256,
|
402 |
+
type="pil",
|
403 |
+
elem_id="content_image",
|
404 |
+
)
|
405 |
+
processed_image = gr.Image(
|
406 |
+
label="Processed Image",
|
407 |
+
image_mode="RGBA",
|
408 |
+
width=256,
|
409 |
+
height=256,
|
410 |
+
type="pil",
|
411 |
+
interactive=False
|
412 |
+
)
|
413 |
+
with gr.Row():
|
414 |
+
with gr.Group():
|
415 |
+
do_remove_background = gr.Checkbox(
|
416 |
+
label="Remove Background", value=True
|
417 |
+
)
|
418 |
+
sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
|
419 |
+
|
420 |
+
sample_steps = gr.Slider(
|
421 |
+
label="Sample Steps",
|
422 |
+
minimum=30,
|
423 |
+
maximum=100,
|
424 |
+
value=75,
|
425 |
+
step=5
|
426 |
+
)
|
427 |
+
|
428 |
+
with gr.Row():
|
429 |
+
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
430 |
+
|
431 |
+
with gr.Row(variant="panel"):
|
432 |
+
gr.Examples(
|
433 |
+
examples=[
|
434 |
+
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
|
435 |
+
],
|
436 |
+
inputs=[input_image],
|
437 |
+
label="Examples",
|
438 |
+
examples_per_page=20
|
439 |
+
)
|
440 |
+
|
441 |
+
with gr.Column():
|
442 |
+
|
443 |
+
with gr.Row():
|
444 |
+
|
445 |
+
with gr.Column():
|
446 |
+
mv_show_images = gr.Image(
|
447 |
+
label="Generated Multi-views",
|
448 |
+
type="pil",
|
449 |
+
width=379,
|
450 |
+
interactive=False
|
451 |
+
)
|
452 |
+
|
453 |
+
with gr.Column():
|
454 |
+
with gr.Column():
|
455 |
+
output_video = gr.Video(
|
456 |
+
label="video", format="mp4",
|
457 |
+
width=768,
|
458 |
+
autoplay=True,
|
459 |
+
interactive=False
|
460 |
+
)
|
461 |
+
|
462 |
+
with gr.Row():
|
463 |
+
with gr.Tab("OBJ"):
|
464 |
+
output_model_obj = gr.Model3D(
|
465 |
+
label="Output Model (OBJ Format)",
|
466 |
+
#width=768,
|
467 |
+
interactive=False,
|
468 |
+
)
|
469 |
+
gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
|
470 |
+
with gr.Tab("GLB"):
|
471 |
+
output_model_glb = gr.Model3D(
|
472 |
+
label="Output Model (GLB Format)",
|
473 |
+
#width=768,
|
474 |
+
interactive=False,
|
475 |
+
)
|
476 |
+
gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
|
477 |
+
|
478 |
+
with gr.Row():
|
479 |
+
gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
|
480 |
+
|
481 |
+
gr.Markdown(_CITE_)
|
482 |
+
mv_images = gr.State()
|
483 |
+
|
484 |
+
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
485 |
+
fn=preprocess,
|
486 |
+
inputs=[input_image, do_remove_background],
|
487 |
+
outputs=[processed_image],
|
488 |
+
).success(
|
489 |
+
fn=generate_mvs,
|
490 |
+
inputs=[processed_image, sample_steps, sample_seed],
|
491 |
+
outputs=[mv_images, mv_show_images],
|
492 |
+
).success(
|
493 |
+
fn=make3d,
|
494 |
+
inputs=[mv_images],
|
495 |
+
outputs=[output_video, output_model_obj, output_model_glb]
|
496 |
+
)
|
497 |
+
|
498 |
+
demo.queue(max_size=10)
|
499 |
+
demo.launch(server_port=1211)
|
configs/PRM.yaml
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.0e-06
|
3 |
+
target: src.model_mesh.MVRecon
|
4 |
+
params:
|
5 |
+
mesh_save_root: Objaverse
|
6 |
+
init_ckpt: nerf_base.ckpt
|
7 |
+
input_size: 512
|
8 |
+
render_size: 512
|
9 |
+
use_tv_loss: true
|
10 |
+
sample_points: null
|
11 |
+
use_gt_albedo: false
|
12 |
+
|
13 |
+
lrm_generator_config:
|
14 |
+
target: src.models.lrm_mesh.PRM
|
15 |
+
params:
|
16 |
+
encoder_feat_dim: 768
|
17 |
+
encoder_freeze: false
|
18 |
+
encoder_model_name: facebook/dino-vitb16
|
19 |
+
transformer_dim: 1024
|
20 |
+
transformer_layers: 16
|
21 |
+
transformer_heads: 16
|
22 |
+
triplane_low_res: 32
|
23 |
+
triplane_high_res: 64
|
24 |
+
triplane_dim: 80
|
25 |
+
rendering_samples_per_ray: 128
|
26 |
+
grid_res: 128
|
27 |
+
grid_scale: 2.1
|
28 |
+
|
29 |
+
|
30 |
+
data:
|
31 |
+
target: src.data.objaverse.DataModuleFromConfig
|
32 |
+
params:
|
33 |
+
batch_size: 1
|
34 |
+
num_workers: 8
|
35 |
+
train:
|
36 |
+
target: src.data.objaverse.ObjaverseData
|
37 |
+
params:
|
38 |
+
root_dir: Objaverse
|
39 |
+
light_dir: env_mipmap
|
40 |
+
input_view_num: [6]
|
41 |
+
target_view_num: 6
|
42 |
+
total_view_n: 18
|
43 |
+
distance: 5.0
|
44 |
+
fov: 30
|
45 |
+
camera_random: true
|
46 |
+
validation: false
|
47 |
+
validation:
|
48 |
+
target: src.data.objaverse.ValidationData
|
49 |
+
params:
|
50 |
+
root_dir: Objaverse
|
51 |
+
input_view_num: 6
|
52 |
+
input_image_size: 320
|
53 |
+
fov: 30
|
54 |
+
|
55 |
+
|
56 |
+
lightning:
|
57 |
+
modelcheckpoint:
|
58 |
+
params:
|
59 |
+
every_n_train_steps: 100
|
60 |
+
save_top_k: -1
|
61 |
+
save_last: true
|
62 |
+
callbacks: {}
|
63 |
+
|
64 |
+
trainer:
|
65 |
+
benchmark: true
|
66 |
+
max_epochs: -1
|
67 |
+
val_check_interval: 2000000000
|
68 |
+
num_sanity_val_steps: 0
|
69 |
+
accumulate_grad_batches: 8
|
70 |
+
log_every_n_steps: 1
|
71 |
+
check_val_every_n_epoch: null # if not set this, validation does not run
|
configs/PRM_inference.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm_mesh.PRM
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 16
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 80
|
13 |
+
rendering_samples_per_ray: 128
|
14 |
+
grid_res: 128
|
15 |
+
grid_scale: 2.1
|
16 |
+
|
17 |
+
|
18 |
+
infer_config:
|
19 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
+
model_path: ckpts/final_ckpt.ckpt
|
21 |
+
texture_resolution: 2048
|
22 |
+
render_resolution: 512
|
light2map.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from src.models.geometry.render import renderutils as ru
|
3 |
+
import torch
|
4 |
+
from src.models.geometry.render import util
|
5 |
+
import nvdiffrast.torch as dr
|
6 |
+
import os
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
import torchvision.transforms.functional as TF
|
10 |
+
import torchvision.utils as vutils
|
11 |
+
import imageio
|
12 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
|
13 |
+
LIGHT_MIN_RES = 16
|
14 |
+
|
15 |
+
MIN_ROUGHNESS = 0.04
|
16 |
+
MAX_ROUGHNESS = 1.00
|
17 |
+
|
18 |
+
class cubemap_mip(torch.autograd.Function):
|
19 |
+
@staticmethod
|
20 |
+
def forward(ctx, cubemap):
|
21 |
+
return util.avg_pool_nhwc(cubemap, (2,2))
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def backward(ctx, dout):
|
25 |
+
res = dout.shape[1] * 2
|
26 |
+
out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda")
|
27 |
+
for s in range(6):
|
28 |
+
gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
|
29 |
+
torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
|
30 |
+
indexing='ij')
|
31 |
+
v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
|
32 |
+
out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
|
33 |
+
return out
|
34 |
+
|
35 |
+
def build_mips(base, cutoff=0.99):
|
36 |
+
specular = [base]
|
37 |
+
while specular[-1].shape[1] > LIGHT_MIN_RES:
|
38 |
+
specular.append(cubemap_mip.apply(specular[-1]))
|
39 |
+
#specular.append(util.avg_pool_nhwc(specular[-1], (2,2)))
|
40 |
+
|
41 |
+
diffuse = ru.diffuse_cubemap(specular[-1])
|
42 |
+
|
43 |
+
for idx in range(len(specular) - 1):
|
44 |
+
roughness = (idx / (len(specular) - 2)) * (MAX_ROUGHNESS - MIN_ROUGHNESS) + MIN_ROUGHNESS
|
45 |
+
specular[idx] = ru.specular_cubemap(specular[idx], roughness, cutoff)
|
46 |
+
specular[-1] = ru.specular_cubemap(specular[-1], 1.0, cutoff)
|
47 |
+
|
48 |
+
return specular, diffuse
|
49 |
+
|
50 |
+
|
51 |
+
# Load from latlong .HDR file
|
52 |
+
def _load_env_hdr(fn, scale=1.0):
|
53 |
+
latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
|
54 |
+
cubemap = util.latlong_to_cubemap(latlong_img, [512, 512])
|
55 |
+
|
56 |
+
specular, diffuse = build_mips(cubemap)
|
57 |
+
|
58 |
+
return specular, diffuse
|
59 |
+
|
60 |
+
def main(path_hdr, save_path_map):
|
61 |
+
all_envs = os.listdir(path_hdr)
|
62 |
+
|
63 |
+
for env in all_envs:
|
64 |
+
env_path = os.path.join(path_hdr, env)
|
65 |
+
base_n = os.path.basename(env_path).split('.')[0]
|
66 |
+
|
67 |
+
try:
|
68 |
+
if not os.path.exists(os.path.join(save_path_map, base_n)):
|
69 |
+
os.makedirs(os.path.join(save_path_map, base_n))
|
70 |
+
specular, diffuse = _load_env_hdr(env_path)
|
71 |
+
for i in range(len(specular)):
|
72 |
+
tensor = specular[i]
|
73 |
+
torch.save(tensor, os.path.join(save_path_map, base_n, f'specular_{i}.pth'))
|
74 |
+
|
75 |
+
torch.save(diffuse, os.path.join(save_path_map, base_n, 'diffuse.pth'))
|
76 |
+
except Exception as e:
|
77 |
+
print(f"Error processing {env}: {e}")
|
78 |
+
continue
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
if len(sys.argv) != 3:
|
82 |
+
print("Usage: python script.py <path_hdr> <save_path_map>")
|
83 |
+
sys.exit(1)
|
84 |
+
|
85 |
+
path_hdr = sys.argv[1]
|
86 |
+
save_path_map = sys.argv[2]
|
87 |
+
|
88 |
+
if not os.path.exists(path_hdr):
|
89 |
+
print(f"Error: path_hdr '{path_hdr}' does not exist.")
|
90 |
+
sys.exit(1)
|
91 |
+
|
92 |
+
if not os.path.exists(save_path_map):
|
93 |
+
os.makedirs(save_path_map)
|
94 |
+
|
95 |
+
main(path_hdr, save_path_map)
|
obj2mesh.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import psutil
|
5 |
+
import gc
|
6 |
+
from tqdm import tqdm
|
7 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
8 |
+
from src.data.objaverse import load_obj
|
9 |
+
from src.utils import mesh
|
10 |
+
from src.utils.material import Material
|
11 |
+
import argparse
|
12 |
+
|
13 |
+
|
14 |
+
def bytes_to_megabytes(bytes):
|
15 |
+
return bytes / (1024 * 1024)
|
16 |
+
|
17 |
+
|
18 |
+
def bytes_to_gigabytes(bytes):
|
19 |
+
return bytes / (1024 * 1024 * 1024)
|
20 |
+
|
21 |
+
|
22 |
+
def print_memory_usage(stage):
|
23 |
+
process = psutil.Process(os.getpid())
|
24 |
+
memory_info = process.memory_info()
|
25 |
+
allocated = torch.cuda.memory_allocated() / 1024**2
|
26 |
+
cached = torch.cuda.memory_reserved() / 1024**2
|
27 |
+
print(
|
28 |
+
f"[{stage}] Process memory: {memory_info.rss / 1024**2:.2f} MB, "
|
29 |
+
f"Allocated CUDA memory: {allocated:.2f} MB, Cached CUDA memory: {cached:.2f} MB"
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
def process_obj(index, root_dir, final_save_dir, paths):
|
34 |
+
obj_path = os.path.join(root_dir, paths[index], paths[index] + '.obj')
|
35 |
+
mtl_path = os.path.join(root_dir, paths[index], paths[index] + '.mtl')
|
36 |
+
|
37 |
+
if os.path.exists(os.path.join(final_save_dir, f"{paths[index]}.pth")):
|
38 |
+
return None
|
39 |
+
|
40 |
+
try:
|
41 |
+
with torch.no_grad():
|
42 |
+
ref_mesh, vertices, faces, normals, nfaces, texcoords, tfaces, uber_material = load_obj(
|
43 |
+
obj_path, return_attributes=True
|
44 |
+
)
|
45 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
46 |
+
ref_mesh = mesh.compute_tangents(ref_mesh)
|
47 |
+
|
48 |
+
with open(mtl_path, 'r') as file:
|
49 |
+
lines = file.readlines()
|
50 |
+
|
51 |
+
if len(lines) >= 250:
|
52 |
+
return None
|
53 |
+
|
54 |
+
final_mesh_attributes = {
|
55 |
+
"v_pos": ref_mesh.v_pos.detach().cpu(),
|
56 |
+
"v_nrm": ref_mesh.v_nrm.detach().cpu(),
|
57 |
+
"v_tex": ref_mesh.v_tex.detach().cpu(),
|
58 |
+
"v_tng": ref_mesh.v_tng.detach().cpu(),
|
59 |
+
"t_pos_idx": ref_mesh.t_pos_idx.detach().cpu(),
|
60 |
+
"t_nrm_idx": ref_mesh.t_nrm_idx.detach().cpu(),
|
61 |
+
"t_tex_idx": ref_mesh.t_tex_idx.detach().cpu(),
|
62 |
+
"t_tng_idx": ref_mesh.t_tng_idx.detach().cpu(),
|
63 |
+
"mat_dict": {key: ref_mesh.material[key] for key in ref_mesh.material.mat_keys},
|
64 |
+
}
|
65 |
+
|
66 |
+
torch.save(final_mesh_attributes, f"{final_save_dir}/{paths[index]}.pth")
|
67 |
+
print(f"==> Saved to {final_save_dir}/{paths[index]}.pth")
|
68 |
+
|
69 |
+
del ref_mesh
|
70 |
+
torch.cuda.empty_cache()
|
71 |
+
return paths[index]
|
72 |
+
|
73 |
+
except Exception as e:
|
74 |
+
print(f"Failed to process {paths[index]}: {e}")
|
75 |
+
return None
|
76 |
+
|
77 |
+
finally:
|
78 |
+
gc.collect()
|
79 |
+
torch.cuda.empty_cache()
|
80 |
+
|
81 |
+
|
82 |
+
def main(root_dir, save_dir):
|
83 |
+
os.makedirs(save_dir, exist_ok=True)
|
84 |
+
finish_lists = os.listdir(save_dir)
|
85 |
+
paths = os.listdir(root_dir)
|
86 |
+
|
87 |
+
valid_uid = []
|
88 |
+
|
89 |
+
print_memory_usage("Start")
|
90 |
+
|
91 |
+
batch_size = 100
|
92 |
+
num_batches = (len(paths) + batch_size - 1) // batch_size
|
93 |
+
|
94 |
+
for batch in tqdm(range(num_batches)):
|
95 |
+
start_index = batch * batch_size
|
96 |
+
end_index = min(start_index + batch_size, len(paths))
|
97 |
+
|
98 |
+
with ThreadPoolExecutor(max_workers=8) as executor:
|
99 |
+
futures = [
|
100 |
+
executor.submit(process_obj, index, root_dir, save_dir, paths)
|
101 |
+
for index in range(start_index, end_index)
|
102 |
+
]
|
103 |
+
for future in as_completed(futures):
|
104 |
+
result = future.result()
|
105 |
+
if result is not None:
|
106 |
+
valid_uid.append(result)
|
107 |
+
|
108 |
+
print_memory_usage(f"=====> After processing batch {batch + 1}")
|
109 |
+
torch.cuda.empty_cache()
|
110 |
+
gc.collect()
|
111 |
+
|
112 |
+
print_memory_usage("End")
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
parser = argparse.ArgumentParser(description="Process OBJ files and save final results.")
|
117 |
+
parser.add_argument("root_dir", type=str, help="Directory containing the root OBJ files.")
|
118 |
+
parser.add_argument("save_dir", type=str, help="Directory to save the processed results.")
|
119 |
+
args = parser.parse_args()
|
120 |
+
|
121 |
+
main(args.root_dir, args.save_dir)
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pytorch-lightning==2.1.2
|
2 |
+
gradio==3.41.2
|
3 |
+
huggingface-hub
|
4 |
+
einops
|
5 |
+
omegaconf
|
6 |
+
torchmetrics
|
7 |
+
webdataset
|
8 |
+
accelerate
|
9 |
+
tensorboard
|
10 |
+
PyMCubes
|
11 |
+
trimesh
|
12 |
+
rembg
|
13 |
+
transformers==4.34.1
|
14 |
+
diffusers==0.20.2
|
15 |
+
bitsandbytes
|
16 |
+
imageio[ffmpeg]
|
17 |
+
xatlas
|
18 |
+
plyfile
|
19 |
+
git+https://github.com/NVlabs/nvdiffrast/
|
20 |
+
PyGLM==2.7.0
|
21 |
+
open3d
|
run.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import glm
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import rembg
|
7 |
+
from PIL import Image
|
8 |
+
from torchvision.transforms import v2
|
9 |
+
import torchvision
|
10 |
+
from pytorch_lightning import seed_everything
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
from tqdm import tqdm
|
14 |
+
from huggingface_hub import hf_hub_download
|
15 |
+
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
16 |
+
|
17 |
+
from src.data.objaverse import load_mipmap
|
18 |
+
from src.utils import render_utils
|
19 |
+
from src.utils.train_util import instantiate_from_config
|
20 |
+
from src.utils.camera_util import (
|
21 |
+
FOV_to_intrinsics,
|
22 |
+
center_looking_at_camera_pose,
|
23 |
+
get_zero123plus_input_cameras,
|
24 |
+
get_circular_camera_poses,
|
25 |
+
)
|
26 |
+
from src.utils.mesh_util import save_obj, save_obj_with_mtl
|
27 |
+
from src.utils.infer_util import remove_background, resize_foreground, save_video
|
28 |
+
|
29 |
+
def str_to_tuple(arg_str):
|
30 |
+
try:
|
31 |
+
return eval(arg_str)
|
32 |
+
except:
|
33 |
+
raise argparse.ArgumentTypeError("Tuple argument must be in the format (x, y)")
|
34 |
+
|
35 |
+
|
36 |
+
def get_render_cameras(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False, fov=50):
|
37 |
+
"""
|
38 |
+
Get the rendering camera parameters.
|
39 |
+
"""
|
40 |
+
train_res = [512, 512]
|
41 |
+
cam_near_far = [0.1, 1000.0]
|
42 |
+
fovy = np.deg2rad(fov)
|
43 |
+
proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])
|
44 |
+
all_mv = []
|
45 |
+
all_mvp = []
|
46 |
+
all_campos = []
|
47 |
+
if isinstance(elevation, tuple):
|
48 |
+
elevation_0 = np.deg2rad(elevation[0])
|
49 |
+
elevation_1 = np.deg2rad(elevation[1])
|
50 |
+
for i in range(M//2):
|
51 |
+
azimuth = 2 * np.pi * i / (M // 2)
|
52 |
+
z = radius * np.cos(azimuth) * np.sin(elevation_0)
|
53 |
+
x = radius * np.sin(azimuth) * np.sin(elevation_0)
|
54 |
+
y = radius * np.cos(elevation_0)
|
55 |
+
|
56 |
+
eye = glm.vec3(x, y, z)
|
57 |
+
at = glm.vec3(0.0, 0.0, 0.0)
|
58 |
+
up = glm.vec3(0.0, 1.0, 0.0)
|
59 |
+
view_matrix = glm.lookAt(eye, at, up)
|
60 |
+
mv = torch.from_numpy(np.array(view_matrix))
|
61 |
+
mvp = proj_mtx @ (mv) #w2c
|
62 |
+
campos = torch.linalg.inv(mv)[:3, 3]
|
63 |
+
all_mv.append(mv[None, ...].cuda())
|
64 |
+
all_mvp.append(mvp[None, ...].cuda())
|
65 |
+
all_campos.append(campos[None, ...].cuda())
|
66 |
+
for i in range(M//2):
|
67 |
+
azimuth = 2 * np.pi * i / (M // 2)
|
68 |
+
z = radius * np.cos(azimuth) * np.sin(elevation_1)
|
69 |
+
x = radius * np.sin(azimuth) * np.sin(elevation_1)
|
70 |
+
y = radius * np.cos(elevation_1)
|
71 |
+
|
72 |
+
eye = glm.vec3(x, y, z)
|
73 |
+
at = glm.vec3(0.0, 0.0, 0.0)
|
74 |
+
up = glm.vec3(0.0, 1.0, 0.0)
|
75 |
+
view_matrix = glm.lookAt(eye, at, up)
|
76 |
+
mv = torch.from_numpy(np.array(view_matrix))
|
77 |
+
mvp = proj_mtx @ (mv) #w2c
|
78 |
+
campos = torch.linalg.inv(mv)[:3, 3]
|
79 |
+
all_mv.append(mv[None, ...].cuda())
|
80 |
+
all_mvp.append(mvp[None, ...].cuda())
|
81 |
+
all_campos.append(campos[None, ...].cuda())
|
82 |
+
else:
|
83 |
+
# elevation = 90 - elevation
|
84 |
+
for i in range(M):
|
85 |
+
azimuth = 2 * np.pi * i / M
|
86 |
+
z = radius * np.cos(azimuth) * np.sin(elevation)
|
87 |
+
x = radius * np.sin(azimuth) * np.sin(elevation)
|
88 |
+
y = radius * np.cos(elevation)
|
89 |
+
|
90 |
+
eye = glm.vec3(x, y, z)
|
91 |
+
at = glm.vec3(0.0, 0.0, 0.0)
|
92 |
+
up = glm.vec3(0.0, 1.0, 0.0)
|
93 |
+
view_matrix = glm.lookAt(eye, at, up)
|
94 |
+
mv = torch.from_numpy(np.array(view_matrix))
|
95 |
+
mvp = proj_mtx @ (mv) #w2c
|
96 |
+
campos = torch.linalg.inv(mv)[:3, 3]
|
97 |
+
all_mv.append(mv[None, ...].cuda())
|
98 |
+
all_mvp.append(mvp[None, ...].cuda())
|
99 |
+
all_campos.append(campos[None, ...].cuda())
|
100 |
+
all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2)
|
101 |
+
all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2)
|
102 |
+
all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2)
|
103 |
+
return all_mv, all_mvp, all_campos
|
104 |
+
|
105 |
+
def render_frames(model, planes, render_cameras, camera_pos, env, materials, render_size=512, chunk_size=1, is_flexicubes=False):
|
106 |
+
"""
|
107 |
+
Render frames from triplanes.
|
108 |
+
"""
|
109 |
+
frames = []
|
110 |
+
albedos = []
|
111 |
+
pbr_spec_lights = []
|
112 |
+
pbr_diffuse_lights = []
|
113 |
+
normals = []
|
114 |
+
alphas = []
|
115 |
+
for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
|
116 |
+
if is_flexicubes:
|
117 |
+
out = model.forward_geometry(
|
118 |
+
planes,
|
119 |
+
render_cameras[:, i:i+chunk_size],
|
120 |
+
camera_pos[:, i:i+chunk_size],
|
121 |
+
[[env]*chunk_size],
|
122 |
+
[[materials]*chunk_size],
|
123 |
+
render_size=render_size,
|
124 |
+
)
|
125 |
+
frame = out['pbr_img']
|
126 |
+
albedo = out['albedo']
|
127 |
+
pbr_spec_light = out['pbr_spec_light']
|
128 |
+
pbr_diffuse_light = out['pbr_diffuse_light']
|
129 |
+
normal = out['normal']
|
130 |
+
alpha = out['mask']
|
131 |
+
else:
|
132 |
+
frame = model.forward_synthesizer(
|
133 |
+
planes,
|
134 |
+
render_cameras[i],
|
135 |
+
render_size=render_size,
|
136 |
+
)['images_rgb']
|
137 |
+
frames.append(frame)
|
138 |
+
albedos.append(albedo)
|
139 |
+
pbr_spec_lights.append(pbr_spec_light)
|
140 |
+
pbr_diffuse_lights.append(pbr_diffuse_light)
|
141 |
+
normals.append(normal)
|
142 |
+
alphas.append(alpha)
|
143 |
+
|
144 |
+
frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1
|
145 |
+
alphas = torch.cat(alphas, dim=1)[0]
|
146 |
+
albedos = torch.cat(albedos, dim=1)[0]
|
147 |
+
pbr_spec_lights = torch.cat(pbr_spec_lights, dim=1)[0]
|
148 |
+
pbr_diffuse_lights = torch.cat(pbr_diffuse_lights, dim=1)[0]
|
149 |
+
normals = torch.cat(normals, dim=0).permute(0,3,1,2)[:,:3]
|
150 |
+
return frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas
|
151 |
+
|
152 |
+
|
153 |
+
###############################################################################
|
154 |
+
# Arguments.
|
155 |
+
###############################################################################
|
156 |
+
|
157 |
+
parser = argparse.ArgumentParser()
|
158 |
+
parser.add_argument('config', type=str, help='Path to config file.')
|
159 |
+
parser.add_argument('input_path', type=str, help='Path to input image or directory.')
|
160 |
+
parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')
|
161 |
+
parser.add_argument('--model_ckpt_path', type=str, default="", help='Output directory.')
|
162 |
+
parser.add_argument('--diffusion_steps', type=int, default=100, help='Denoising Sampling steps.')
|
163 |
+
parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')
|
164 |
+
parser.add_argument('--scale', type=float, default=1.0, help='Scale of generated object.')
|
165 |
+
parser.add_argument('--materials', type=str_to_tuple, default=(1.0, 0.1), help=' metallic and roughness')
|
166 |
+
parser.add_argument('--distance', type=float, default=4.5, help='Render distance.')
|
167 |
+
parser.add_argument('--fov', type=float, default=30, help='Render distance.')
|
168 |
+
parser.add_argument('--env_path', type=str, default='data/env_mipmap/2', help='environment map')
|
169 |
+
parser.add_argument('--view', type=int, default=6, choices=[4, 6], help='Number of input views.')
|
170 |
+
parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')
|
171 |
+
parser.add_argument('--export_texmap', action='store_true', help='Export a mesh with texture map.')
|
172 |
+
parser.add_argument('--save_video', action='store_true', help='Save a circular-view video.')
|
173 |
+
args = parser.parse_args()
|
174 |
+
seed_everything(args.seed)
|
175 |
+
|
176 |
+
###############################################################################
|
177 |
+
# Stage 0: Configuration.
|
178 |
+
###############################################################################
|
179 |
+
|
180 |
+
config = OmegaConf.load(args.config)
|
181 |
+
config_name = os.path.basename(args.config).replace('.yaml', '')
|
182 |
+
model_config = config.model_config
|
183 |
+
infer_config = config.infer_config
|
184 |
+
|
185 |
+
IS_FLEXICUBES = True
|
186 |
+
|
187 |
+
device = torch.device('cuda')
|
188 |
+
|
189 |
+
# load diffusion model
|
190 |
+
print('Loading diffusion model ...')
|
191 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
192 |
+
"sudo-ai/zero123plus-v1.2",
|
193 |
+
custom_pipeline="zero123plus",
|
194 |
+
torch_dtype=torch.float16,
|
195 |
+
)
|
196 |
+
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
197 |
+
pipeline.scheduler.config, timestep_spacing='trailing'
|
198 |
+
)
|
199 |
+
|
200 |
+
# load custom white-background UNet
|
201 |
+
print('Loading custom white-background unet ...')
|
202 |
+
if os.path.exists(infer_config.unet_path):
|
203 |
+
unet_ckpt_path = infer_config.unet_path
|
204 |
+
else:
|
205 |
+
unet_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="diffusion_pytorch_model.bin", repo_type="model")
|
206 |
+
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
207 |
+
pipeline.unet.load_state_dict(state_dict, strict=True)
|
208 |
+
|
209 |
+
pipeline = pipeline.to(device)
|
210 |
+
|
211 |
+
# load reconstruction model
|
212 |
+
print('Loading reconstruction model ...')
|
213 |
+
model = instantiate_from_config(model_config)
|
214 |
+
if os.path.exists(infer_config.model_path):
|
215 |
+
model_ckpt_path = infer_config.model_path
|
216 |
+
else:
|
217 |
+
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
218 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
219 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
220 |
+
model.load_state_dict(state_dict, strict=True)
|
221 |
+
|
222 |
+
model = model.to(device)
|
223 |
+
if IS_FLEXICUBES:
|
224 |
+
model.init_flexicubes_geometry(device, fovy=50.0)
|
225 |
+
model = model.eval()
|
226 |
+
|
227 |
+
# make output directories
|
228 |
+
image_path = os.path.join(args.output_path, config_name, 'images')
|
229 |
+
mesh_path = os.path.join(args.output_path, config_name, 'meshes')
|
230 |
+
video_path = os.path.join(args.output_path, config_name, 'videos')
|
231 |
+
os.makedirs(image_path, exist_ok=True)
|
232 |
+
os.makedirs(mesh_path, exist_ok=True)
|
233 |
+
os.makedirs(video_path, exist_ok=True)
|
234 |
+
|
235 |
+
# process input files
|
236 |
+
if os.path.isdir(args.input_path):
|
237 |
+
input_files = [
|
238 |
+
os.path.join(args.input_path, file)
|
239 |
+
for file in os.listdir(args.input_path)
|
240 |
+
if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.webp')
|
241 |
+
]
|
242 |
+
else:
|
243 |
+
input_files = [args.input_path]
|
244 |
+
print(f'Total number of input images: {len(input_files)}')
|
245 |
+
|
246 |
+
###############################################################################
|
247 |
+
# Stage 1: Multiview generation.
|
248 |
+
###############################################################################
|
249 |
+
|
250 |
+
rembg_session = None if args.no_rembg else rembg.new_session()
|
251 |
+
|
252 |
+
outputs = []
|
253 |
+
for idx, image_file in enumerate(input_files):
|
254 |
+
name = os.path.basename(image_file).split('.')[0]
|
255 |
+
print(f'[{idx+1}/{len(input_files)}] Imagining {name} ...')
|
256 |
+
|
257 |
+
# remove background optionally
|
258 |
+
input_image = Image.open(image_file)
|
259 |
+
if not args.no_rembg:
|
260 |
+
input_image = remove_background(input_image, rembg_session)
|
261 |
+
input_image = resize_foreground(input_image, 0.85)
|
262 |
+
# sampling
|
263 |
+
output_image = pipeline(
|
264 |
+
input_image,
|
265 |
+
num_inference_steps=args.diffusion_steps,
|
266 |
+
).images[0]
|
267 |
+
print(f"Image saved to {os.path.join(image_path, f'{name}.png')}")
|
268 |
+
|
269 |
+
images = np.asarray(output_image, dtype=np.float32) / 255.0
|
270 |
+
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
|
271 |
+
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
|
272 |
+
torchvision.utils.save_image(images, os.path.join(image_path, f'{name}.png'))
|
273 |
+
sample = {'name': name, 'images': images}
|
274 |
+
|
275 |
+
# delete pipeline to save memory
|
276 |
+
# del pipeline
|
277 |
+
|
278 |
+
###############################################################################
|
279 |
+
# Stage 2: Reconstruction.
|
280 |
+
###############################################################################
|
281 |
+
|
282 |
+
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=3.2*args.scale, fov=30).to(device)
|
283 |
+
chunk_size = 20 if IS_FLEXICUBES else 1
|
284 |
+
|
285 |
+
# for idx, sample in enumerate(outputs):
|
286 |
+
name = sample['name']
|
287 |
+
print(f'[{idx+1}/{len(outputs)}] Creating {name} ...')
|
288 |
+
|
289 |
+
images = sample['images'].unsqueeze(0).to(device)
|
290 |
+
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
291 |
+
|
292 |
+
with torch.no_grad():
|
293 |
+
# get triplane
|
294 |
+
planes = model.forward_planes(images, input_cameras)
|
295 |
+
|
296 |
+
mesh_path_idx = os.path.join(mesh_path, f'{name}.obj')
|
297 |
+
|
298 |
+
mesh_out = model.extract_mesh(
|
299 |
+
planes,
|
300 |
+
use_texture_map=args.export_texmap,
|
301 |
+
**infer_config,
|
302 |
+
)
|
303 |
+
if args.export_texmap:
|
304 |
+
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
305 |
+
save_obj_with_mtl(
|
306 |
+
vertices.data.cpu().numpy(),
|
307 |
+
uvs.data.cpu().numpy(),
|
308 |
+
faces.data.cpu().numpy(),
|
309 |
+
mesh_tex_idx.data.cpu().numpy(),
|
310 |
+
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
311 |
+
mesh_path_idx,
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
vertices, faces, vertex_colors = mesh_out
|
315 |
+
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
316 |
+
print(f"Mesh saved to {mesh_path_idx}")
|
317 |
+
|
318 |
+
render_size = 512
|
319 |
+
if args.save_video:
|
320 |
+
video_path_idx = os.path.join(video_path, f'{name}.mp4')
|
321 |
+
render_size = infer_config.render_resolution
|
322 |
+
ENV = load_mipmap(args.env_path)
|
323 |
+
materials = args.materials
|
324 |
+
|
325 |
+
all_mv, all_mvp, all_campos = get_render_cameras(
|
326 |
+
batch_size=1,
|
327 |
+
M=240,
|
328 |
+
radius=args.distance,
|
329 |
+
elevation=(90, 60.0),
|
330 |
+
is_flexicubes=IS_FLEXICUBES,
|
331 |
+
fov=args.fov
|
332 |
+
)
|
333 |
+
|
334 |
+
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
335 |
+
model,
|
336 |
+
planes,
|
337 |
+
render_cameras=all_mvp,
|
338 |
+
camera_pos=all_campos,
|
339 |
+
env=ENV,
|
340 |
+
materials=materials,
|
341 |
+
render_size=render_size,
|
342 |
+
chunk_size=chunk_size,
|
343 |
+
is_flexicubes=IS_FLEXICUBES,
|
344 |
+
)
|
345 |
+
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
346 |
+
normals = normals * alphas + (1-alphas)
|
347 |
+
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
348 |
+
|
349 |
+
# breakpoint()
|
350 |
+
save_video(
|
351 |
+
all_frames,
|
352 |
+
video_path_idx,
|
353 |
+
fps=30,
|
354 |
+
)
|
355 |
+
print(f"Video saved to {video_path_idx}")
|
run.sh
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python run.py configs/PRM_inference.yaml examples/ \
|
2 |
+
--seed 10 \
|
3 |
+
--materials "(0.0, 0.9)" \
|
4 |
+
--env_path "./env_mipmap/6" \
|
5 |
+
--output_path "output/" \
|
6 |
+
--save_video \
|
7 |
+
--export_texmap \
|
run_hpc.sh
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
source /hpc2ssd/softwares/anaconda3/bin/activate instantmesh
|
2 |
+
module load cuda/12.1 compilers/gcc-11.1.0 compilers/icc-2023.1.0 cmake/3.27.0
|
3 |
+
export CXX=$(which g++)
|
4 |
+
export CC=$(which gcc)
|
5 |
+
export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
|
6 |
+
export CUDA_LAUNCH_BLOCKING=1
|
7 |
+
export NCCL_TIMEOUT=3600
|
8 |
+
export CUDA_VISIBLE_DEVICES="0"
|
9 |
+
# python app.py
|
10 |
+
python run.py configs/PRM_inference.yaml examples/恐龙套装.webp \
|
11 |
+
--seed 10 \
|
12 |
+
--materials "(0.0, 0.9)" \
|
13 |
+
--env_path "./env_mipmap/6" \
|
14 |
+
--output_path "output/" \
|
15 |
+
--save_video \
|
16 |
+
--export_texmap \
|
src/__init__.py
ADDED
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (139 Bytes). View file
|
|
src/data/__init__.py
ADDED
File without changes
|
src/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (144 Bytes). View file
|
|
src/data/__pycache__/objaverse.cpython-310.pyc
ADDED
Binary file (14.9 kB). View file
|
|
src/data/bsdf_256_256.bin
ADDED
Binary file (524 kB). View file
|
|
src/data/objaverse.py
ADDED
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
import math
|
3 |
+
import json
|
4 |
+
import glm
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import random
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image
|
10 |
+
import webdataset as wds
|
11 |
+
import pytorch_lightning as pl
|
12 |
+
import sys
|
13 |
+
from src.utils import obj, render_utils
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.utils.data import Dataset
|
17 |
+
from torch.utils.data.distributed import DistributedSampler
|
18 |
+
import random
|
19 |
+
import itertools
|
20 |
+
from src.utils.train_util import instantiate_from_config
|
21 |
+
from src.utils.camera_util import (
|
22 |
+
FOV_to_intrinsics,
|
23 |
+
center_looking_at_camera_pose,
|
24 |
+
get_circular_camera_poses,
|
25 |
+
)
|
26 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
|
27 |
+
import re
|
28 |
+
|
29 |
+
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
|
30 |
+
azimuths = np.deg2rad(azimuths)
|
31 |
+
elevations = np.deg2rad(elevations)
|
32 |
+
|
33 |
+
xs = radius * np.cos(elevations) * np.cos(azimuths)
|
34 |
+
ys = radius * np.cos(elevations) * np.sin(azimuths)
|
35 |
+
zs = radius * np.sin(elevations)
|
36 |
+
|
37 |
+
cam_locations = np.stack([xs, ys, zs], axis=-1)
|
38 |
+
cam_locations = torch.from_numpy(cam_locations).float()
|
39 |
+
|
40 |
+
c2ws = center_looking_at_camera_pose(cam_locations)
|
41 |
+
return c2ws
|
42 |
+
|
43 |
+
def find_matching_files(base_path, idx):
|
44 |
+
formatted_idx = '%03d' % idx
|
45 |
+
pattern = re.compile(r'^%s_\d+\.png$' % formatted_idx)
|
46 |
+
matching_files = []
|
47 |
+
|
48 |
+
if os.path.exists(base_path):
|
49 |
+
for filename in os.listdir(base_path):
|
50 |
+
if pattern.match(filename):
|
51 |
+
matching_files.append(filename)
|
52 |
+
|
53 |
+
return os.path.join(base_path, matching_files[0])
|
54 |
+
|
55 |
+
def load_mipmap(env_path):
|
56 |
+
diffuse_path = os.path.join(env_path, "diffuse.pth")
|
57 |
+
diffuse = torch.load(diffuse_path, map_location=torch.device('cpu'))
|
58 |
+
|
59 |
+
specular = []
|
60 |
+
for i in range(6):
|
61 |
+
specular_path = os.path.join(env_path, f"specular_{i}.pth")
|
62 |
+
specular_tensor = torch.load(specular_path, map_location=torch.device('cpu'))
|
63 |
+
specular.append(specular_tensor)
|
64 |
+
return [specular, diffuse]
|
65 |
+
|
66 |
+
def convert_to_white_bg(image, write_bg=True):
|
67 |
+
alpha = image[:, :, 3:]
|
68 |
+
if write_bg:
|
69 |
+
return image[:, :, :3] * alpha + 1. * (1 - alpha)
|
70 |
+
else:
|
71 |
+
return image[:, :, :3] * alpha
|
72 |
+
|
73 |
+
def load_obj(path, return_attributes=False, scale_factor=1.0):
|
74 |
+
return obj.load_obj(path, clear_ks=True, mtl_override=None, return_attributes=return_attributes, scale_factor=scale_factor)
|
75 |
+
|
76 |
+
def custom_collate_fn(batch):
|
77 |
+
return batch
|
78 |
+
|
79 |
+
|
80 |
+
def collate_fn_wrapper(batch):
|
81 |
+
return custom_collate_fn(batch)
|
82 |
+
|
83 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
batch_size=8,
|
87 |
+
num_workers=4,
|
88 |
+
train=None,
|
89 |
+
validation=None,
|
90 |
+
test=None,
|
91 |
+
**kwargs,
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
|
95 |
+
self.batch_size = batch_size
|
96 |
+
self.num_workers = num_workers
|
97 |
+
|
98 |
+
self.dataset_configs = dict()
|
99 |
+
if train is not None:
|
100 |
+
self.dataset_configs['train'] = train
|
101 |
+
if validation is not None:
|
102 |
+
self.dataset_configs['validation'] = validation
|
103 |
+
if test is not None:
|
104 |
+
self.dataset_configs['test'] = test
|
105 |
+
|
106 |
+
def setup(self, stage):
|
107 |
+
|
108 |
+
if stage in ['fit']:
|
109 |
+
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
|
110 |
+
else:
|
111 |
+
raise NotImplementedError
|
112 |
+
|
113 |
+
def custom_collate_fn(self, batch):
|
114 |
+
collated_batch = {}
|
115 |
+
for key in batch[0].keys():
|
116 |
+
if key == 'input_env' or key == 'target_env':
|
117 |
+
collated_batch[key] = [d[key] for d in batch]
|
118 |
+
else:
|
119 |
+
collated_batch[key] = torch.stack([d[key] for d in batch], dim=0)
|
120 |
+
return collated_batch
|
121 |
+
|
122 |
+
def convert_to_white_bg(self, image):
|
123 |
+
alpha = image[:, :, 3:]
|
124 |
+
return image[:, :, :3] * alpha + 1. * (1 - alpha)
|
125 |
+
|
126 |
+
def load_obj(self, path):
|
127 |
+
return obj.load_obj(path, clear_ks=True, mtl_override=None)
|
128 |
+
|
129 |
+
def train_dataloader(self):
|
130 |
+
|
131 |
+
sampler = DistributedSampler(self.datasets['train'])
|
132 |
+
return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler, collate_fn=collate_fn_wrapper)
|
133 |
+
|
134 |
+
def val_dataloader(self):
|
135 |
+
|
136 |
+
sampler = DistributedSampler(self.datasets['validation'])
|
137 |
+
return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler, collate_fn=collate_fn_wrapper)
|
138 |
+
|
139 |
+
def test_dataloader(self):
|
140 |
+
|
141 |
+
return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
142 |
+
|
143 |
+
|
144 |
+
class ObjaverseData(Dataset):
|
145 |
+
def __init__(self,
|
146 |
+
root_dir='Objaverse_highQuality',
|
147 |
+
light_dir= 'env_mipmap',
|
148 |
+
input_view_num=6,
|
149 |
+
target_view_num=4,
|
150 |
+
total_view_n=18,
|
151 |
+
distance=3.5,
|
152 |
+
fov=50,
|
153 |
+
camera_random=False,
|
154 |
+
validation=False,
|
155 |
+
):
|
156 |
+
self.root_dir = Path(root_dir)
|
157 |
+
self.light_dir = light_dir
|
158 |
+
self.all_env_name = []
|
159 |
+
for temp_dir in os.listdir(light_dir):
|
160 |
+
if os.listdir(os.path.join(self.light_dir, temp_dir)):
|
161 |
+
self.all_env_name.append(temp_dir)
|
162 |
+
|
163 |
+
self.input_view_num = input_view_num
|
164 |
+
self.target_view_num = target_view_num
|
165 |
+
self.total_view_n = total_view_n
|
166 |
+
self.fov = fov
|
167 |
+
self.camera_random = camera_random
|
168 |
+
|
169 |
+
self.train_res = [512, 512]
|
170 |
+
self.cam_near_far = [0.1, 1000.0]
|
171 |
+
self.fov_rad = np.deg2rad(fov)
|
172 |
+
self.fov_deg = fov
|
173 |
+
self.spp = 1
|
174 |
+
self.cam_radius = distance
|
175 |
+
self.layers = 1
|
176 |
+
|
177 |
+
numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
178 |
+
self.combinations = list(itertools.product(numbers, repeat=2))
|
179 |
+
|
180 |
+
self.paths = os.listdir(self.root_dir)
|
181 |
+
|
182 |
+
# with open("BJ_Mesh_list.json", 'r') as file:
|
183 |
+
# self.paths = json.load(file)
|
184 |
+
|
185 |
+
print('total training object num:', len(self.paths))
|
186 |
+
|
187 |
+
self.depth_scale = 6.0
|
188 |
+
|
189 |
+
total_objects = len(self.paths)
|
190 |
+
print('============= length of dataset %d =============' % total_objects)
|
191 |
+
|
192 |
+
def __len__(self):
|
193 |
+
return len(self.paths)
|
194 |
+
|
195 |
+
def load_obj(self, path):
|
196 |
+
return obj.load_obj(path, clear_ks=True, mtl_override=None)
|
197 |
+
|
198 |
+
def sample_spherical(self, phi, theta, cam_radius):
|
199 |
+
theta = np.deg2rad(theta)
|
200 |
+
phi = np.deg2rad(phi)
|
201 |
+
|
202 |
+
z = cam_radius * np.cos(phi) * np.sin(theta)
|
203 |
+
x = cam_radius * np.sin(phi) * np.sin(theta)
|
204 |
+
y = cam_radius * np.cos(theta)
|
205 |
+
|
206 |
+
return x, y, z
|
207 |
+
|
208 |
+
def _random_scene(self, cam_radius, fov_rad):
|
209 |
+
iter_res = self.train_res
|
210 |
+
proj_mtx = render_utils.perspective(fov_rad, iter_res[1] / iter_res[0], self.cam_near_far[0], self.cam_near_far[1])
|
211 |
+
|
212 |
+
azimuths = random.uniform(0, 360)
|
213 |
+
elevations = random.uniform(30, 150)
|
214 |
+
mv_embedding = spherical_camera_pose(azimuths, 90-elevations, cam_radius)
|
215 |
+
x, y, z = self.sample_spherical(azimuths, elevations, cam_radius)
|
216 |
+
eye = glm.vec3(x, y, z)
|
217 |
+
at = glm.vec3(0.0, 0.0, 0.0)
|
218 |
+
up = glm.vec3(0.0, 1.0, 0.0)
|
219 |
+
view_matrix = glm.lookAt(eye, at, up)
|
220 |
+
mv = torch.from_numpy(np.array(view_matrix))
|
221 |
+
mvp = proj_mtx @ (mv) #w2c
|
222 |
+
campos = torch.linalg.inv(mv)[:3, 3]
|
223 |
+
return mv[None, ...], mvp[None, ...], campos[None, ...], mv_embedding[None, ...], iter_res, self.spp # Add batch dimension
|
224 |
+
|
225 |
+
def load_im(self, path, color):
|
226 |
+
'''
|
227 |
+
replace background pixel with random color in rendering
|
228 |
+
'''
|
229 |
+
pil_img = Image.open(path)
|
230 |
+
|
231 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
232 |
+
alpha = image[:, :, 3:]
|
233 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
234 |
+
|
235 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
236 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
237 |
+
return image, alpha
|
238 |
+
|
239 |
+
def load_albedo(self, path, color, mask):
|
240 |
+
'''
|
241 |
+
replace background pixel with random color in rendering
|
242 |
+
'''
|
243 |
+
pil_img = Image.open(path)
|
244 |
+
|
245 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
246 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
247 |
+
|
248 |
+
color = torch.ones_like(image)
|
249 |
+
image = image * mask + color * (1 - mask)
|
250 |
+
return image
|
251 |
+
|
252 |
+
def convert_to_white_bg(self, image):
|
253 |
+
alpha = image[:, :, 3:]
|
254 |
+
return image[:, :, :3] * alpha + 1. * (1 - alpha)
|
255 |
+
|
256 |
+
def calculate_fov(self, initial_distance, initial_fov, new_distance):
|
257 |
+
initial_fov_rad = math.radians(initial_fov)
|
258 |
+
|
259 |
+
height = 2 * initial_distance * math.tan(initial_fov_rad / 2)
|
260 |
+
|
261 |
+
new_fov_rad = 2 * math.atan(height / (2 * new_distance))
|
262 |
+
|
263 |
+
new_fov = math.degrees(new_fov_rad)
|
264 |
+
|
265 |
+
return new_fov
|
266 |
+
|
267 |
+
def __getitem__(self, index):
|
268 |
+
obj_path = os.path.join(self.root_dir, self.paths[index])
|
269 |
+
mesh_attributes = torch.load(obj_path, map_location=torch.device('cpu'))
|
270 |
+
pose_list = []
|
271 |
+
env_list = []
|
272 |
+
material_list = []
|
273 |
+
camera_pos = []
|
274 |
+
c2w_list = []
|
275 |
+
camera_embedding_list = []
|
276 |
+
random_env = False
|
277 |
+
random_mr = False
|
278 |
+
if random.random() > 0.5:
|
279 |
+
random_env = True
|
280 |
+
if random.random() > 0.5:
|
281 |
+
random_mr = True
|
282 |
+
selected_env = random.randint(0, len(self.all_env_name)-1)
|
283 |
+
materials = random.choice(self.combinations)
|
284 |
+
if self.camera_random:
|
285 |
+
random_perturbation = random.uniform(-1.5, 1.5)
|
286 |
+
cam_radius = self.cam_radius + random_perturbation
|
287 |
+
fov_deg = self.calculate_fov(initial_distance=self.cam_radius, initial_fov=self.fov_deg, new_distance=cam_radius)
|
288 |
+
fov_rad = np.deg2rad(fov_deg)
|
289 |
+
else:
|
290 |
+
cam_radius = self.cam_radius
|
291 |
+
fov_rad = self.fov_rad
|
292 |
+
fov_deg = self.fov_deg
|
293 |
+
|
294 |
+
if len(self.input_view_num) >= 1:
|
295 |
+
input_view_num = random.choice(self.input_view_num)
|
296 |
+
else:
|
297 |
+
input_view_num = self.input_view_num
|
298 |
+
for _ in range(input_view_num + self.target_view_num):
|
299 |
+
mv, mvp, campos, mv_mebedding, iter_res, iter_spp = self._random_scene(cam_radius, fov_rad)
|
300 |
+
if random_env:
|
301 |
+
selected_env = random.randint(0, len(self.all_env_name)-1)
|
302 |
+
env_path = os.path.join(self.light_dir, self.all_env_name[selected_env])
|
303 |
+
env = load_mipmap(env_path)
|
304 |
+
if random_mr:
|
305 |
+
materials = random.choice(self.combinations)
|
306 |
+
pose_list.append(mvp)
|
307 |
+
camera_pos.append(campos)
|
308 |
+
c2w_list.append(mv)
|
309 |
+
env_list.append(env)
|
310 |
+
material_list.append(materials)
|
311 |
+
camera_embedding_list.append(mv_mebedding)
|
312 |
+
data = {
|
313 |
+
'mesh_attributes': mesh_attributes,
|
314 |
+
'input_view_num': input_view_num,
|
315 |
+
'target_view_num': self.target_view_num,
|
316 |
+
'obj_path': obj_path,
|
317 |
+
'pose_list': pose_list,
|
318 |
+
'camera_pos': camera_pos,
|
319 |
+
'c2w_list': c2w_list,
|
320 |
+
'env_list': env_list,
|
321 |
+
'material_list': material_list,
|
322 |
+
'camera_embedding_list': camera_embedding_list,
|
323 |
+
'fov_deg':fov_deg,
|
324 |
+
'raduis': cam_radius
|
325 |
+
}
|
326 |
+
|
327 |
+
return data
|
328 |
+
|
329 |
+
class ValidationData(Dataset):
|
330 |
+
def __init__(self,
|
331 |
+
root_dir='objaverse/',
|
332 |
+
input_view_num=6,
|
333 |
+
input_image_size=320,
|
334 |
+
fov=30,
|
335 |
+
):
|
336 |
+
self.root_dir = Path(root_dir)
|
337 |
+
self.input_view_num = input_view_num
|
338 |
+
self.input_image_size = input_image_size
|
339 |
+
self.fov = fov
|
340 |
+
self.light_dir = 'env_mipmap'
|
341 |
+
|
342 |
+
# with open('Mesh_list.json') as f:
|
343 |
+
# filtered_dict = json.load(f)
|
344 |
+
|
345 |
+
self.paths = os.listdir(self.root_dir)
|
346 |
+
|
347 |
+
# self.paths = filtered_dict
|
348 |
+
print('============= length of dataset %d =============' % len(self.paths))
|
349 |
+
|
350 |
+
cam_distance = 4.0
|
351 |
+
azimuths = np.array([30, 90, 150, 210, 270, 330])
|
352 |
+
elevations = np.array([20, -10, 20, -10, 20, -10])
|
353 |
+
azimuths = np.deg2rad(azimuths)
|
354 |
+
elevations = np.deg2rad(elevations)
|
355 |
+
|
356 |
+
x = cam_distance * np.cos(elevations) * np.cos(azimuths)
|
357 |
+
y = cam_distance * np.cos(elevations) * np.sin(azimuths)
|
358 |
+
z = cam_distance * np.sin(elevations)
|
359 |
+
|
360 |
+
cam_locations = np.stack([x, y, z], axis=-1)
|
361 |
+
cam_locations = torch.from_numpy(cam_locations).float()
|
362 |
+
c2ws = center_looking_at_camera_pose(cam_locations)
|
363 |
+
self.c2ws = c2ws.float()
|
364 |
+
self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
|
365 |
+
|
366 |
+
render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
|
367 |
+
render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
|
368 |
+
self.render_c2ws = render_c2ws.float()
|
369 |
+
self.render_Ks = render_Ks.float()
|
370 |
+
|
371 |
+
def __len__(self):
|
372 |
+
return len(self.paths)
|
373 |
+
|
374 |
+
def load_im(self, path, color):
|
375 |
+
'''
|
376 |
+
replace background pixel with random color in rendering
|
377 |
+
'''
|
378 |
+
pil_img = Image.open(path)
|
379 |
+
pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
|
380 |
+
|
381 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
382 |
+
if image.shape[-1] == 4:
|
383 |
+
alpha = image[:, :, 3:]
|
384 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
385 |
+
else:
|
386 |
+
alpha = np.ones_like(image[:, :, :1])
|
387 |
+
|
388 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
389 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
390 |
+
return image, alpha
|
391 |
+
|
392 |
+
def load_mat(self, path, color):
|
393 |
+
'''
|
394 |
+
replace background pixel with random color in rendering
|
395 |
+
'''
|
396 |
+
pil_img = Image.open(path)
|
397 |
+
pil_img = pil_img.resize((384,384), resample=Image.BICUBIC)
|
398 |
+
|
399 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
400 |
+
if image.shape[-1] == 4:
|
401 |
+
alpha = image[:, :, 3:]
|
402 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
403 |
+
else:
|
404 |
+
alpha = np.ones_like(image[:, :, :1])
|
405 |
+
|
406 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
407 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
408 |
+
return image, alpha
|
409 |
+
|
410 |
+
def load_albedo(self, path, color, mask):
|
411 |
+
'''
|
412 |
+
replace background pixel with random color in rendering
|
413 |
+
'''
|
414 |
+
pil_img = Image.open(path)
|
415 |
+
pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
|
416 |
+
|
417 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
418 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
419 |
+
|
420 |
+
color = torch.ones_like(image)
|
421 |
+
image = image * mask + color * (1 - mask)
|
422 |
+
return image
|
423 |
+
|
424 |
+
def __getitem__(self, index):
|
425 |
+
|
426 |
+
# load data
|
427 |
+
input_image_path = os.path.join(self.root_dir, self.paths[index])
|
428 |
+
|
429 |
+
'''background color, default: white'''
|
430 |
+
bkg_color = [1.0, 1.0, 1.0]
|
431 |
+
|
432 |
+
image_list = []
|
433 |
+
albedo_list = []
|
434 |
+
alpha_list = []
|
435 |
+
specular_list = []
|
436 |
+
diffuse_list = []
|
437 |
+
metallic_list = []
|
438 |
+
roughness_list = []
|
439 |
+
|
440 |
+
exist_comb_list = []
|
441 |
+
for subfolder in os.listdir(input_image_path):
|
442 |
+
found_numeric_subfolder=False
|
443 |
+
subfolder_path = os.path.join(input_image_path, subfolder)
|
444 |
+
if os.path.isdir(subfolder_path) and '_' in subfolder and 'specular' not in subfolder and 'diffuse' not in subfolder:
|
445 |
+
try:
|
446 |
+
parts = subfolder.split('_')
|
447 |
+
float(parts[0]) # 尝试将分隔符前后的字符串转换为浮点数
|
448 |
+
float(parts[1])
|
449 |
+
found_numeric_subfolder = True
|
450 |
+
except ValueError:
|
451 |
+
continue
|
452 |
+
if found_numeric_subfolder:
|
453 |
+
exist_comb_list.append(subfolder)
|
454 |
+
|
455 |
+
selected_one_comb = random.choice(exist_comb_list)
|
456 |
+
|
457 |
+
|
458 |
+
for idx in range(self.input_view_num):
|
459 |
+
img_path = find_matching_files(os.path.join(input_image_path, selected_one_comb, 'rgb'), idx)
|
460 |
+
albedo_path = img_path.replace('rgb', 'albedo')
|
461 |
+
metallic_path = img_path.replace('rgb', 'metallic')
|
462 |
+
roughness_path = img_path.replace('rgb', 'roughness')
|
463 |
+
|
464 |
+
image, alpha = self.load_im(img_path, bkg_color)
|
465 |
+
albedo = self.load_albedo(albedo_path, bkg_color, alpha)
|
466 |
+
metallic,_ = self.load_mat(metallic_path, bkg_color)
|
467 |
+
roughness,_ = self.load_mat(roughness_path, bkg_color)
|
468 |
+
|
469 |
+
light_num = os.path.basename(img_path).split('_')[1].split('.')[0]
|
470 |
+
light_path = os.path.join(self.light_dir, str(int(light_num)+1))
|
471 |
+
|
472 |
+
specular, diffuse = load_mipmap(light_path)
|
473 |
+
|
474 |
+
image_list.append(image)
|
475 |
+
alpha_list.append(alpha)
|
476 |
+
albedo_list.append(albedo)
|
477 |
+
metallic_list.append(metallic)
|
478 |
+
roughness_list.append(roughness)
|
479 |
+
specular_list.append(specular)
|
480 |
+
diffuse_list.append(diffuse)
|
481 |
+
|
482 |
+
images = torch.stack(image_list, dim=0).float()
|
483 |
+
alphas = torch.stack(alpha_list, dim=0).float()
|
484 |
+
albedo = torch.stack(albedo_list, dim=0).float()
|
485 |
+
metallic = torch.stack(metallic_list, dim=0).float()
|
486 |
+
roughness = torch.stack(roughness_list, dim=0).float()
|
487 |
+
|
488 |
+
data = {
|
489 |
+
'input_images': images,
|
490 |
+
'input_alphas': alphas,
|
491 |
+
'input_c2ws': self.c2ws,
|
492 |
+
'input_Ks': self.Ks,
|
493 |
+
|
494 |
+
'input_albedos': albedo[:self.input_view_num],
|
495 |
+
'input_metallics': metallic[:self.input_view_num],
|
496 |
+
'input_roughness': roughness[:self.input_view_num],
|
497 |
+
|
498 |
+
'specular': specular_list[:self.input_view_num],
|
499 |
+
'diffuse': diffuse_list[:self.input_view_num],
|
500 |
+
|
501 |
+
'render_c2ws': self.render_c2ws,
|
502 |
+
'render_Ks': self.render_Ks,
|
503 |
+
}
|
504 |
+
return data
|
505 |
+
|
506 |
+
|
507 |
+
if __name__ == '__main__':
|
508 |
+
dataset = ObjaverseData()
|
509 |
+
dataset.new(1)
|
src/model_mesh.py
ADDED
@@ -0,0 +1,642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import gc
|
7 |
+
from torchvision.transforms import v2
|
8 |
+
from torchvision.utils import make_grid, save_image
|
9 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
from src.utils.camera_util import FOV_to_intrinsics
|
13 |
+
from src.utils.material import Material
|
14 |
+
from src.utils.train_util import instantiate_from_config
|
15 |
+
import nvdiffrast.torch as dr
|
16 |
+
from src.utils import render
|
17 |
+
from src.utils.mesh import Mesh, compute_tangents
|
18 |
+
os.environ['PYOPENGL_PLATFORM'] = 'egl'
|
19 |
+
|
20 |
+
# from pytorch3d.transforms import quaternion_to_matrix, euler_angles_to_matrix
|
21 |
+
GLCTX = [None] * torch.cuda.device_count()
|
22 |
+
|
23 |
+
def initialize_extension(gpu_id):
|
24 |
+
global GLCTX
|
25 |
+
if GLCTX[gpu_id] is None:
|
26 |
+
print(f"Initializing extension module renderutils_plugin on GPU {gpu_id}...")
|
27 |
+
torch.cuda.set_device(gpu_id)
|
28 |
+
GLCTX[gpu_id] = dr.RasterizeCudaContext()
|
29 |
+
return GLCTX[gpu_id]
|
30 |
+
|
31 |
+
# Regulrarization loss for FlexiCubes
|
32 |
+
def sdf_reg_loss_batch(sdf, all_edges):
|
33 |
+
sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
|
34 |
+
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
|
35 |
+
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
36 |
+
sdf_diff = F.binary_cross_entropy_with_logits(
|
37 |
+
sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
|
38 |
+
F.binary_cross_entropy_with_logits(
|
39 |
+
sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
|
40 |
+
return sdf_diff
|
41 |
+
|
42 |
+
def rotate_x(a, device=None):
|
43 |
+
s, c = np.sin(a), np.cos(a)
|
44 |
+
return torch.tensor([[1, 0, 0, 0],
|
45 |
+
[0, c,-s, 0],
|
46 |
+
[0, s, c, 0],
|
47 |
+
[0, 0, 0, 1]], dtype=torch.float32, device=device)
|
48 |
+
|
49 |
+
|
50 |
+
def convert_to_white_bg(image, write_bg=True):
|
51 |
+
alpha = image[:, :, 3:]
|
52 |
+
if write_bg:
|
53 |
+
return image[:, :, :3] * alpha + 1. * (1 - alpha)
|
54 |
+
else:
|
55 |
+
return image[:, :, :3] * alpha
|
56 |
+
|
57 |
+
|
58 |
+
class MVRecon(pl.LightningModule):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
lrm_generator_config,
|
62 |
+
input_size=256,
|
63 |
+
render_size=512,
|
64 |
+
init_ckpt=None,
|
65 |
+
use_tv_loss=True,
|
66 |
+
mesh_save_root="Objaverse_highQuality",
|
67 |
+
sample_points=None,
|
68 |
+
use_gt_albedo=False,
|
69 |
+
):
|
70 |
+
super(MVRecon, self).__init__()
|
71 |
+
|
72 |
+
self.use_gt_albedo = use_gt_albedo
|
73 |
+
self.use_tv_loss = use_tv_loss
|
74 |
+
self.input_size = input_size
|
75 |
+
self.render_size = render_size
|
76 |
+
self.mesh_save_root = mesh_save_root
|
77 |
+
self.sample_points = sample_points
|
78 |
+
|
79 |
+
self.lrm_generator = instantiate_from_config(lrm_generator_config)
|
80 |
+
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
|
81 |
+
|
82 |
+
if init_ckpt is not None:
|
83 |
+
sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
|
84 |
+
sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
|
85 |
+
sd_fc = {}
|
86 |
+
for k, v in sd.items():
|
87 |
+
if k.startswith('lrm_generator.synthesizer.decoder.net.'):
|
88 |
+
if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
|
89 |
+
# Here we assume the density filed's isosurface threshold is t,
|
90 |
+
# we reverse the sign of density filed to initialize SDF field.
|
91 |
+
# -(w*x + b - t) = (-w)*x + (t - b)
|
92 |
+
if 'weight' in k:
|
93 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
|
94 |
+
else:
|
95 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = 10.0 - v[0:1]
|
96 |
+
sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
|
97 |
+
else:
|
98 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = v
|
99 |
+
sd_fc[k.replace('net.', 'net_rgb.')] = v
|
100 |
+
else:
|
101 |
+
sd_fc[k] = v
|
102 |
+
sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
|
103 |
+
# missing `net_deformation` and `net_weight` parameters
|
104 |
+
self.lrm_generator.load_state_dict(sd_fc, strict=False)
|
105 |
+
print(f'Loaded weights from {init_ckpt}')
|
106 |
+
|
107 |
+
self.validation_step_outputs = []
|
108 |
+
|
109 |
+
def on_fit_start(self):
|
110 |
+
device = torch.device(f'cuda:{self.local_rank}')
|
111 |
+
self.lrm_generator.init_flexicubes_geometry(device)
|
112 |
+
if self.global_rank == 0:
|
113 |
+
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
|
114 |
+
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
|
115 |
+
|
116 |
+
def collate_fn(self, batch):
|
117 |
+
gpu_id = torch.cuda.current_device() # 获取当前线程的 GPU ID
|
118 |
+
glctx = initialize_extension(gpu_id)
|
119 |
+
batch_size = len(batch)
|
120 |
+
input_view_num = batch[0]["input_view_num"]
|
121 |
+
target_view_num = batch[0]["target_view_num"]
|
122 |
+
iter_res = [512, 512]
|
123 |
+
iter_spp = 1
|
124 |
+
layers = 1
|
125 |
+
|
126 |
+
# Initialize lists for input and target data
|
127 |
+
input_images, input_alphas, input_depths, input_normals, input_albedos = [], [], [], [], []
|
128 |
+
input_spec_light, input_diff_light, input_spec_albedo,input_diff_albedo = [], [], [], []
|
129 |
+
input_w2cs, input_Ks, input_camera_pos, input_c2ws = [], [], [], []
|
130 |
+
input_env, input_materials = [], []
|
131 |
+
input_camera_embeddings = [] # camera_embedding_list
|
132 |
+
|
133 |
+
target_images, target_alphas, target_depths, target_normals, target_albedos = [], [], [], [], []
|
134 |
+
target_spec_light, target_diff_light, target_spec_albedo, target_diff_albedo = [], [], [], []
|
135 |
+
target_w2cs, target_Ks, target_camera_pos = [], [], []
|
136 |
+
target_env, target_materials = [], []
|
137 |
+
|
138 |
+
for sample in batch:
|
139 |
+
obj_path = sample['obj_path']
|
140 |
+
|
141 |
+
with torch.no_grad():
|
142 |
+
mesh_attributes = sample['mesh_attributes']
|
143 |
+
v_pos = mesh_attributes["v_pos"].to(self.device)
|
144 |
+
v_nrm = mesh_attributes["v_nrm"].to(self.device)
|
145 |
+
v_tex = mesh_attributes["v_tex"].to(self.device)
|
146 |
+
v_tng = mesh_attributes["v_tng"].to(self.device)
|
147 |
+
t_pos_idx = mesh_attributes["t_pos_idx"].to(self.device)
|
148 |
+
t_nrm_idx = mesh_attributes["t_nrm_idx"].to(self.device)
|
149 |
+
t_tex_idx = mesh_attributes["t_tex_idx"].to(self.device)
|
150 |
+
t_tng_idx = mesh_attributes["t_tng_idx"].to(self.device)
|
151 |
+
material = Material(mesh_attributes["mat_dict"])
|
152 |
+
material = material.to(self.device)
|
153 |
+
ref_mesh = Mesh(v_pos=v_pos, v_nrm=v_nrm, v_tex=v_tex, v_tng=v_tng,
|
154 |
+
t_pos_idx=t_pos_idx, t_nrm_idx=t_nrm_idx,
|
155 |
+
t_tex_idx=t_tex_idx, t_tng_idx=t_tng_idx, material=material)
|
156 |
+
|
157 |
+
pose_list_sample = sample['pose_list'] # mvp
|
158 |
+
camera_pos_sample = sample['camera_pos'] # campos, mv.inverse
|
159 |
+
c2w_list_sample = sample['c2w_list'] # mv
|
160 |
+
env_list_sample = sample['env_list']
|
161 |
+
material_list_sample = sample['material_list']
|
162 |
+
camera_embeddings = sample["camera_embedding_list"]
|
163 |
+
fov_deg = sample['fov_deg']
|
164 |
+
raduis = sample['raduis']
|
165 |
+
# print(f"fov_deg:{fov_deg}, raduis:{raduis}")
|
166 |
+
|
167 |
+
sample_input_images, sample_input_alphas, sample_input_depths, sample_input_normals, sample_input_albedos = [], [], [], [], []
|
168 |
+
sample_input_w2cs, sample_input_Ks, sample_input_camera_pos, sample_input_c2ws = [], [], [], []
|
169 |
+
sample_input_camera_embeddings = []
|
170 |
+
sample_input_spec_light, sample_input_diff_light = [], []
|
171 |
+
|
172 |
+
sample_target_images, sample_target_alphas, sample_target_depths, sample_target_normals, sample_target_albedos = [], [], [], [], []
|
173 |
+
sample_target_w2cs, sample_target_Ks, sample_target_camera_pos = [], [], []
|
174 |
+
sample_target_spec_light, sample_target_diff_light = [], []
|
175 |
+
|
176 |
+
sample_input_env = []
|
177 |
+
sample_input_materials = []
|
178 |
+
sample_target_env = []
|
179 |
+
sample_target_materials = []
|
180 |
+
|
181 |
+
for i in range(len(pose_list_sample)):
|
182 |
+
mvp = pose_list_sample[i]
|
183 |
+
campos = camera_pos_sample[i]
|
184 |
+
env = env_list_sample[i]
|
185 |
+
materials = material_list_sample[i]
|
186 |
+
camera_embedding = camera_embeddings[i]
|
187 |
+
|
188 |
+
with torch.no_grad():
|
189 |
+
buffer_dict = render.render_mesh(glctx, ref_mesh, mvp.to(self.device), campos.to(self.device), [env], None, None,
|
190 |
+
materials, iter_res, spp=iter_spp, num_layers=layers, msaa=True,
|
191 |
+
background=None, gt_render=True)
|
192 |
+
|
193 |
+
image = convert_to_white_bg(buffer_dict['shaded'][0])
|
194 |
+
albedo = convert_to_white_bg(buffer_dict['albedo'][0]).clamp(0., 1.)
|
195 |
+
alpha = buffer_dict['mask'][0][:, :, 3:]
|
196 |
+
depth = convert_to_white_bg(buffer_dict['depth'][0])
|
197 |
+
normal = convert_to_white_bg(buffer_dict['gb_normal'][0], write_bg=False)
|
198 |
+
spec_light = convert_to_white_bg(buffer_dict['spec_light'][0])
|
199 |
+
diff_light = convert_to_white_bg(buffer_dict['diff_light'][0])
|
200 |
+
if i < input_view_num:
|
201 |
+
sample_input_images.append(image)
|
202 |
+
sample_input_albedos.append(albedo)
|
203 |
+
sample_input_alphas.append(alpha)
|
204 |
+
sample_input_depths.append(depth)
|
205 |
+
sample_input_normals.append(normal)
|
206 |
+
sample_input_spec_light.append(spec_light)
|
207 |
+
sample_input_diff_light.append(diff_light)
|
208 |
+
sample_input_w2cs.append(mvp)
|
209 |
+
sample_input_camera_pos.append(campos)
|
210 |
+
sample_input_c2ws.append(c2w_list_sample[i])
|
211 |
+
sample_input_Ks.append(FOV_to_intrinsics(fov_deg))
|
212 |
+
sample_input_env.append(env)
|
213 |
+
sample_input_materials.append(materials)
|
214 |
+
sample_input_camera_embeddings.append(camera_embedding)
|
215 |
+
else:
|
216 |
+
sample_target_images.append(image)
|
217 |
+
sample_target_albedos.append(albedo)
|
218 |
+
sample_target_alphas.append(alpha)
|
219 |
+
sample_target_depths.append(depth)
|
220 |
+
sample_target_normals.append(normal)
|
221 |
+
sample_target_spec_light.append(spec_light)
|
222 |
+
sample_target_diff_light.append(diff_light)
|
223 |
+
sample_target_w2cs.append(mvp)
|
224 |
+
sample_target_camera_pos.append(campos)
|
225 |
+
sample_target_Ks.append(FOV_to_intrinsics(fov_deg))
|
226 |
+
sample_target_env.append(env)
|
227 |
+
sample_target_materials.append(materials)
|
228 |
+
|
229 |
+
input_images.append(torch.stack(sample_input_images, dim=0).permute(0, 3, 1, 2))
|
230 |
+
input_albedos.append(torch.stack(sample_input_albedos, dim=0).permute(0, 3, 1, 2))
|
231 |
+
input_alphas.append(torch.stack(sample_input_alphas, dim=0).permute(0, 3, 1, 2))
|
232 |
+
input_depths.append(torch.stack(sample_input_depths, dim=0).permute(0, 3, 1, 2))
|
233 |
+
input_normals.append(torch.stack(sample_input_normals, dim=0).permute(0, 3, 1, 2))
|
234 |
+
input_spec_light.append(torch.stack(sample_input_spec_light, dim=0).permute(0, 3, 1, 2))
|
235 |
+
input_diff_light.append(torch.stack(sample_input_diff_light, dim=0).permute(0, 3, 1, 2))
|
236 |
+
input_w2cs.append(torch.stack(sample_input_w2cs, dim=0))
|
237 |
+
input_camera_pos.append(torch.stack(sample_input_camera_pos, dim=0))
|
238 |
+
input_c2ws.append(torch.stack(sample_input_c2ws, dim=0))
|
239 |
+
input_camera_embeddings.append(torch.stack(sample_input_camera_embeddings, dim=0))
|
240 |
+
input_Ks.append(torch.stack(sample_input_Ks, dim=0))
|
241 |
+
input_env.append(sample_input_env)
|
242 |
+
input_materials.append(sample_input_materials)
|
243 |
+
|
244 |
+
target_images.append(torch.stack(sample_target_images, dim=0).permute(0, 3, 1, 2))
|
245 |
+
target_albedos.append(torch.stack(sample_target_albedos, dim=0).permute(0, 3, 1, 2))
|
246 |
+
target_alphas.append(torch.stack(sample_target_alphas, dim=0).permute(0, 3, 1, 2))
|
247 |
+
target_depths.append(torch.stack(sample_target_depths, dim=0).permute(0, 3, 1, 2))
|
248 |
+
target_normals.append(torch.stack(sample_target_normals, dim=0).permute(0, 3, 1, 2))
|
249 |
+
target_spec_light.append(torch.stack(sample_target_spec_light, dim=0).permute(0, 3, 1, 2))
|
250 |
+
target_diff_light.append(torch.stack(sample_target_diff_light, dim=0).permute(0, 3, 1, 2))
|
251 |
+
target_w2cs.append(torch.stack(sample_target_w2cs, dim=0))
|
252 |
+
target_camera_pos.append(torch.stack(sample_target_camera_pos, dim=0))
|
253 |
+
target_Ks.append(torch.stack(sample_target_Ks, dim=0))
|
254 |
+
target_env.append(sample_target_env)
|
255 |
+
target_materials.append(sample_target_materials)
|
256 |
+
|
257 |
+
del ref_mesh
|
258 |
+
del material
|
259 |
+
del mesh_attributes
|
260 |
+
torch.cuda.empty_cache()
|
261 |
+
gc.collect()
|
262 |
+
|
263 |
+
data = {
|
264 |
+
'input_images': torch.stack(input_images, dim=0).detach().cpu(), # (batch_size, input_view_num, 3, H, W)
|
265 |
+
'input_alphas': torch.stack(input_alphas, dim=0).detach().cpu(), # (batch_size, input_view_num, 1, H, W)
|
266 |
+
'input_depths': torch.stack(input_depths, dim=0).detach().cpu(),
|
267 |
+
'input_normals': torch.stack(input_normals, dim=0).detach().cpu(),
|
268 |
+
'input_albedos': torch.stack(input_albedos, dim=0).detach().cpu(),
|
269 |
+
'input_spec_light': torch.stack(input_spec_light, dim=0).detach().cpu(),
|
270 |
+
'input_diff_light': torch.stack(input_diff_light, dim=0).detach().cpu(),
|
271 |
+
'input_materials': input_materials,
|
272 |
+
'input_w2cs': torch.stack(input_w2cs, dim=0).squeeze(2), # (batch_size, input_view_num, 4, 4)
|
273 |
+
'input_Ks': torch.stack(input_Ks, dim=0).float(), # (batch_size, input_view_num, 3, 3)
|
274 |
+
'input_env': input_env,
|
275 |
+
'input_camera_pos': torch.stack(input_camera_pos, dim=0).squeeze(2), # (batch_size, input_view_num, 3)
|
276 |
+
'input_c2ws': torch.stack(input_c2ws, dim=0).squeeze(2), # (batch_size, input_view_num, 4, 4)
|
277 |
+
'input_camera_embedding': torch.stack(input_camera_embeddings, dim=0).squeeze(2),
|
278 |
+
|
279 |
+
'target_sample_points': None,
|
280 |
+
'target_images': torch.stack(target_images, dim=0).detach().cpu(), # (batch_size, target_view_num, 3, H, W)
|
281 |
+
'target_alphas': torch.stack(target_alphas, dim=0).detach().cpu(), # (batch_size, target_view_num, 1, H, W)
|
282 |
+
'target_depths': torch.stack(target_depths, dim=0).detach().cpu(),
|
283 |
+
'target_normals': torch.stack(target_normals, dim=0).detach().cpu(),
|
284 |
+
'target_albedos': torch.stack(target_albedos, dim=0).detach().cpu(),
|
285 |
+
'target_spec_light': torch.stack(target_spec_light, dim=0).detach().cpu(),
|
286 |
+
'target_diff_light': torch.stack(target_diff_light, dim=0).detach().cpu(),
|
287 |
+
'target_materials': target_materials,
|
288 |
+
'target_w2cs': torch.stack(target_w2cs, dim=0).squeeze(2), # (batch_size, target_view_num, 4, 4)
|
289 |
+
'target_Ks': torch.stack(target_Ks, dim=0).float(), # (batch_size, target_view_num, 3, 3)
|
290 |
+
'target_env': target_env,
|
291 |
+
'target_camera_pos': torch.stack(target_camera_pos, dim=0).squeeze(2) # (batch_size, target_view_num, 3)
|
292 |
+
}
|
293 |
+
|
294 |
+
return data
|
295 |
+
|
296 |
+
def prepare_batch_data(self, batch):
|
297 |
+
# breakpoint()
|
298 |
+
lrm_generator_input = {}
|
299 |
+
render_gt = {}
|
300 |
+
|
301 |
+
# input images
|
302 |
+
images = batch['input_images']
|
303 |
+
images = v2.functional.resize(images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
304 |
+
batch_size = images.shape[0]
|
305 |
+
# breakpoint()
|
306 |
+
lrm_generator_input['images'] = images.to(self.device)
|
307 |
+
|
308 |
+
# input cameras and render cameras
|
309 |
+
# input_c2ws = batch['input_c2ws']
|
310 |
+
input_Ks = batch['input_Ks']
|
311 |
+
# target_c2ws = batch['target_c2ws']
|
312 |
+
input_camera_embedding = batch["input_camera_embedding"].to(self.device)
|
313 |
+
|
314 |
+
input_w2cs = batch['input_w2cs']
|
315 |
+
target_w2cs = batch['target_w2cs']
|
316 |
+
render_w2cs = torch.cat([input_w2cs, target_w2cs], dim=1)
|
317 |
+
|
318 |
+
input_camera_pos = batch['input_camera_pos']
|
319 |
+
target_camera_pos = batch['target_camera_pos']
|
320 |
+
render_camera_pos = torch.cat([input_camera_pos, target_camera_pos], dim=1)
|
321 |
+
|
322 |
+
input_extrinsics = input_camera_embedding.flatten(-2)
|
323 |
+
input_extrinsics = input_extrinsics[:, :, :12]
|
324 |
+
input_intrinsics = input_Ks.flatten(-2).to(self.device)
|
325 |
+
input_intrinsics = torch.stack([
|
326 |
+
input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
|
327 |
+
input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
|
328 |
+
], dim=-1)
|
329 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
330 |
+
|
331 |
+
# add noise to input_cameras
|
332 |
+
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
|
333 |
+
|
334 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
335 |
+
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
|
336 |
+
lrm_generator_input['cameras_pos'] = render_camera_pos.to(self.device)
|
337 |
+
lrm_generator_input['env'] = []
|
338 |
+
lrm_generator_input['materials'] = []
|
339 |
+
for i in range(batch_size):
|
340 |
+
lrm_generator_input['env'].append( batch['input_env'][i] + batch['target_env'][i])
|
341 |
+
lrm_generator_input['materials'].append( batch['input_materials'][i] + batch['target_materials'][i])
|
342 |
+
lrm_generator_input['albedo'] = torch.cat([batch['input_albedos'],batch['target_albedos']],dim=1)
|
343 |
+
|
344 |
+
# target images
|
345 |
+
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
|
346 |
+
target_albedos = torch.cat([batch['input_albedos'], batch['target_albedos']], dim=1)
|
347 |
+
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
|
348 |
+
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
|
349 |
+
target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
|
350 |
+
target_spec_lights = torch.cat([batch['input_spec_light'], batch['target_spec_light']], dim=1)
|
351 |
+
target_diff_lights = torch.cat([batch['input_diff_light'], batch['target_diff_light']], dim=1)
|
352 |
+
|
353 |
+
render_size = self.render_size
|
354 |
+
target_images = v2.functional.resize(
|
355 |
+
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
|
356 |
+
target_depths = v2.functional.resize(
|
357 |
+
target_depths, render_size, interpolation=0, antialias=True)
|
358 |
+
target_alphas = v2.functional.resize(
|
359 |
+
target_alphas, render_size, interpolation=0, antialias=True)
|
360 |
+
target_normals = v2.functional.resize(
|
361 |
+
target_normals, render_size, interpolation=3, antialias=True)
|
362 |
+
|
363 |
+
lrm_generator_input['render_size'] = render_size
|
364 |
+
|
365 |
+
render_gt['target_sample_points'] = batch['target_sample_points']
|
366 |
+
render_gt['target_images'] = target_images.to(self.device)
|
367 |
+
render_gt['target_albedos'] = target_albedos.to(self.device)
|
368 |
+
render_gt['target_depths'] = target_depths.to(self.device)
|
369 |
+
render_gt['target_alphas'] = target_alphas.to(self.device)
|
370 |
+
render_gt['target_normals'] = target_normals.to(self.device)
|
371 |
+
render_gt['target_spec_lights'] = target_spec_lights.to(self.device)
|
372 |
+
render_gt['target_diff_lights'] = target_diff_lights.to(self.device)
|
373 |
+
# render_gt['target_spec_albedos'] = target_spec_albedos.to(self.device)
|
374 |
+
# render_gt['target_diff_albedos'] = target_diff_albedos.to(self.device)
|
375 |
+
return lrm_generator_input, render_gt
|
376 |
+
|
377 |
+
def prepare_validation_batch_data(self, batch):
|
378 |
+
lrm_generator_input = {}
|
379 |
+
|
380 |
+
# input images
|
381 |
+
images = batch['input_images']
|
382 |
+
images = v2.functional.resize(
|
383 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
384 |
+
|
385 |
+
lrm_generator_input['images'] = images.to(self.device)
|
386 |
+
lrm_generator_input['specular_light'] = batch['specular']
|
387 |
+
lrm_generator_input['diffuse_light'] = batch['diffuse']
|
388 |
+
|
389 |
+
lrm_generator_input['metallic'] = batch['input_metallics']
|
390 |
+
lrm_generator_input['roughness'] = batch['input_roughness']
|
391 |
+
|
392 |
+
proj = self.perspective(0.449, 1, 0.1, 1000., self.device)
|
393 |
+
|
394 |
+
# input cameras
|
395 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
396 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
397 |
+
|
398 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
399 |
+
input_intrinsics = torch.stack([
|
400 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
401 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
402 |
+
], dim=-1)
|
403 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
404 |
+
|
405 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
406 |
+
|
407 |
+
# render cameras
|
408 |
+
render_c2ws = batch['render_c2ws']
|
409 |
+
|
410 |
+
lrm_generator_input['camera_pos'] = torch.linalg.inv(render_w2cs.to(self.device) @ rotate_x(np.pi / 2, self.device))[..., :3, 3]
|
411 |
+
render_w2cs = ( render_w2cs @ rotate_x(np.pi / 2) )
|
412 |
+
|
413 |
+
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
|
414 |
+
lrm_generator_input['render_size'] = 384
|
415 |
+
|
416 |
+
return lrm_generator_input
|
417 |
+
|
418 |
+
def forward_lrm_generator(self, images, cameras, camera_pos,env, materials, albedo_map, render_cameras, render_size=512, sample_points=None, gt_albedo_map=None):
|
419 |
+
planes = torch.utils.checkpoint.checkpoint(
|
420 |
+
self.lrm_generator.forward_planes,
|
421 |
+
images,
|
422 |
+
cameras,
|
423 |
+
use_reentrant=False,
|
424 |
+
)
|
425 |
+
out = self.lrm_generator.forward_geometry(
|
426 |
+
planes,
|
427 |
+
render_cameras,
|
428 |
+
camera_pos,
|
429 |
+
env,
|
430 |
+
materials,
|
431 |
+
albedo_map,
|
432 |
+
render_size,
|
433 |
+
sample_points,
|
434 |
+
gt_albedo_map
|
435 |
+
)
|
436 |
+
return out
|
437 |
+
|
438 |
+
def forward(self, lrm_generator_input, gt_albedo_map=None):
|
439 |
+
images = lrm_generator_input['images']
|
440 |
+
cameras = lrm_generator_input['cameras']
|
441 |
+
render_cameras = lrm_generator_input['render_cameras']
|
442 |
+
render_size = lrm_generator_input['render_size']
|
443 |
+
env = lrm_generator_input['env']
|
444 |
+
materials = lrm_generator_input['materials']
|
445 |
+
albedo_map = lrm_generator_input['albedo']
|
446 |
+
camera_pos = lrm_generator_input['cameras_pos']
|
447 |
+
|
448 |
+
out = self.forward_lrm_generator(
|
449 |
+
images, cameras, camera_pos, env, materials, albedo_map, render_cameras, render_size=render_size, sample_points=self.sample_points, gt_albedo_map=gt_albedo_map)
|
450 |
+
|
451 |
+
return out
|
452 |
+
|
453 |
+
def training_step(self, batch, batch_idx):
|
454 |
+
batch = self.collate_fn(batch)
|
455 |
+
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
|
456 |
+
if self.use_gt_albedo:
|
457 |
+
gt_albedo_map = render_gt['target_albedos']
|
458 |
+
else:
|
459 |
+
gt_albedo_map = None
|
460 |
+
render_out = self.forward(lrm_generator_input, gt_albedo_map=gt_albedo_map)
|
461 |
+
|
462 |
+
loss, loss_dict = self.compute_loss(render_out, render_gt)
|
463 |
+
|
464 |
+
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, batch_size=len(batch['input_images']), sync_dist=True)
|
465 |
+
|
466 |
+
if self.global_step % 20 == 0 and self.global_rank == 0 :
|
467 |
+
B, N, C, H, W = render_gt['target_images'].shape
|
468 |
+
N_in = lrm_generator_input['images'].shape[1]
|
469 |
+
|
470 |
+
target_images = rearrange(render_gt['target_images'], 'b n c h w -> b c h (n w)')
|
471 |
+
render_images = rearrange(render_out['pbr_img'], 'b n c h w -> b c h (n w)')
|
472 |
+
target_alphas = rearrange(repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
473 |
+
target_spec_light = rearrange(render_gt['target_spec_lights'], 'b n c h w -> b c h (n w)')
|
474 |
+
target_diff_light = rearrange(render_gt['target_diff_lights'], 'b n c h w -> b c h (n w)')
|
475 |
+
|
476 |
+
render_alphas = rearrange(render_out['mask'], 'b n c h w -> b c h (n w)')
|
477 |
+
render_albodos = rearrange(render_out['albedo'], 'b n c h w -> b c h (n w)')
|
478 |
+
target_albedos = rearrange(render_gt['target_albedos'], 'b n c h w -> b c h (n w)')
|
479 |
+
|
480 |
+
render_spec_light = rearrange(render_out['pbr_spec_light'], 'b n c h w -> b c h (n w)')
|
481 |
+
render_diffuse_light = rearrange(render_out['pbr_diffuse_light'], 'b n c h w -> b c h (n w)')
|
482 |
+
render_normal = rearrange(render_out['normal_img'], 'b n c h w -> b c h (n w)')
|
483 |
+
target_depths = rearrange(render_gt['target_depths'], 'b n c h w -> b c h (n w)')
|
484 |
+
render_depths = rearrange(render_out['depth'], 'b n c h w -> b c h (n w)')
|
485 |
+
target_normals = rearrange(render_gt['target_normals'], 'b n c h w -> b c h (n w)')
|
486 |
+
|
487 |
+
MAX_DEPTH = torch.max(target_depths)
|
488 |
+
target_depths = target_depths / MAX_DEPTH * target_alphas
|
489 |
+
render_depths = render_depths / MAX_DEPTH * render_alphas
|
490 |
+
|
491 |
+
grid = torch.cat([
|
492 |
+
target_images, render_images,
|
493 |
+
target_alphas, render_alphas,
|
494 |
+
target_albedos, render_albodos,
|
495 |
+
target_spec_light, render_spec_light,
|
496 |
+
target_diff_light, render_diffuse_light,
|
497 |
+
(target_normals+1)/2, (render_normal+1)/2,
|
498 |
+
target_depths, render_depths
|
499 |
+
], dim=-2).detach().cpu()
|
500 |
+
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
|
501 |
+
|
502 |
+
image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
|
503 |
+
save_image(grid, image_path)
|
504 |
+
print(f"Saved image to {image_path}")
|
505 |
+
return loss
|
506 |
+
|
507 |
+
|
508 |
+
def total_variation_loss(self, img, beta=2.0):
|
509 |
+
bs_img, n_view, c_img, h_img, w_img = img.size()
|
510 |
+
tv_h = torch.pow(img[...,1:,:]-img[...,:-1,:], beta).sum()
|
511 |
+
tv_w = torch.pow(img[...,:,1:]-img[...,:,:-1], beta).sum()
|
512 |
+
return (tv_h+tv_w)/(bs_img*n_view*c_img*h_img*w_img)
|
513 |
+
|
514 |
+
|
515 |
+
def compute_loss(self, render_out, render_gt):
|
516 |
+
# NOTE: the rgb value range of OpenLRM is [0, 1]
|
517 |
+
render_albedo_image = render_out['albedo']
|
518 |
+
render_pbr_image = render_out['pbr_img']
|
519 |
+
render_spec_light = render_out['pbr_spec_light']
|
520 |
+
render_diff_light = render_out['pbr_diffuse_light']
|
521 |
+
|
522 |
+
target_images = render_gt['target_images'].to(render_albedo_image)
|
523 |
+
target_albedos = render_gt['target_albedos'].to(render_albedo_image)
|
524 |
+
target_spec_light = render_gt['target_spec_lights'].to(render_albedo_image)
|
525 |
+
target_diff_light = render_gt['target_diff_lights'].to(render_albedo_image)
|
526 |
+
|
527 |
+
render_images = rearrange(render_pbr_image, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
528 |
+
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
529 |
+
|
530 |
+
render_albedos = rearrange(render_albedo_image, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
531 |
+
target_albedos = rearrange(target_albedos, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
532 |
+
|
533 |
+
render_spec_light = rearrange(render_spec_light, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
534 |
+
target_spec_light = rearrange(target_spec_light, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
535 |
+
|
536 |
+
render_diff_light = rearrange(render_diff_light, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
537 |
+
target_diff_light = rearrange(target_diff_light, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
538 |
+
|
539 |
+
|
540 |
+
loss_mse = F.mse_loss(render_images, target_images)
|
541 |
+
loss_mse_albedo = F.mse_loss(render_albedos, target_albedos)
|
542 |
+
loss_rgb_lpips = 2.0 * self.lpips(render_images, target_images)
|
543 |
+
loss_albedo_lpips = 2.0 * self.lpips(render_albedos, target_albedos)
|
544 |
+
|
545 |
+
loss_spec_light = F.mse_loss(render_spec_light, target_spec_light)
|
546 |
+
loss_diff_light = F.mse_loss(render_diff_light, target_diff_light)
|
547 |
+
loss_spec_light_lpips = 2.0 * self.lpips(render_spec_light.clamp(-1., 1.), target_spec_light.clamp(-1., 1.))
|
548 |
+
loss_diff_light_lpips = 2.0 * self.lpips(render_diff_light.clamp(-1., 1.), target_diff_light.clamp(-1., 1.))
|
549 |
+
|
550 |
+
render_alphas = render_out['mask'][:,:,:1,:,:]
|
551 |
+
target_alphas = render_gt['target_alphas']
|
552 |
+
|
553 |
+
loss_mask = F.mse_loss(render_alphas, target_alphas)
|
554 |
+
render_depths = torch.mean(render_out['depth'], dim=2, keepdim=True)
|
555 |
+
target_depths = torch.mean(render_gt['target_depths'], dim=2, keepdim=True)
|
556 |
+
loss_depth = 0.5 * F.l1_loss(render_depths[(target_alphas>0)], target_depths[target_alphas>0])
|
557 |
+
|
558 |
+
render_normals = render_out['normal'][...,:3].permute(0,3,1,2).unsqueeze(0)
|
559 |
+
target_normals = render_gt['target_normals']
|
560 |
+
similarity = (render_normals * target_normals).sum(dim=-3).abs()
|
561 |
+
normal_mask = target_alphas.squeeze(-3)
|
562 |
+
loss_normal = 1 - similarity[normal_mask>0].mean()
|
563 |
+
loss_normal = 0.2 * loss_normal * 1.0
|
564 |
+
|
565 |
+
# tv loss
|
566 |
+
if self.use_tv_loss:
|
567 |
+
triplane = render_out['triplane']
|
568 |
+
tv_loss = self.total_variation_loss(triplane, beta=2.0)
|
569 |
+
|
570 |
+
# flexicubes regularization loss
|
571 |
+
sdf = render_out['sdf']
|
572 |
+
sdf_reg_loss = render_out['sdf_reg_loss']
|
573 |
+
sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
|
574 |
+
_, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
|
575 |
+
flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
|
576 |
+
flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
|
577 |
+
|
578 |
+
loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
|
579 |
+
loss_reg = loss_reg
|
580 |
+
loss = loss_mse + loss_rgb_lpips + loss_albedo_lpips + loss_mask + loss_reg + loss_mse_albedo + loss_depth + \
|
581 |
+
loss_normal + loss_spec_light + loss_diff_light + loss_spec_light_lpips + loss_diff_light_lpips
|
582 |
+
if self.use_tv_loss:
|
583 |
+
loss += tv_loss * 2e-4
|
584 |
+
|
585 |
+
prefix = 'train'
|
586 |
+
loss_dict = {}
|
587 |
+
|
588 |
+
loss_dict.update({f'{prefix}/loss_mse': loss_mse.item()})
|
589 |
+
loss_dict.update({f'{prefix}/loss_mse_albedo': loss_mse_albedo.item()})
|
590 |
+
loss_dict.update({f'{prefix}/loss_rgb_lpips': loss_rgb_lpips.item()})
|
591 |
+
loss_dict.update({f'{prefix}/loss_albedo_lpips': loss_albedo_lpips.item()})
|
592 |
+
loss_dict.update({f'{prefix}/loss_mask': loss_mask.item()})
|
593 |
+
loss_dict.update({f'{prefix}/loss_normal': loss_normal.item()})
|
594 |
+
loss_dict.update({f'{prefix}/loss_depth': loss_depth.item()})
|
595 |
+
loss_dict.update({f'{prefix}/loss_spec_light': loss_spec_light.item()})
|
596 |
+
loss_dict.update({f'{prefix}/loss_diff_light': loss_diff_light.item()})
|
597 |
+
loss_dict.update({f'{prefix}/loss_spec_light_lpips': loss_spec_light_lpips.item()})
|
598 |
+
loss_dict.update({f'{prefix}/loss_diff_light_lpips': loss_diff_light_lpips.item()})
|
599 |
+
loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy.item()})
|
600 |
+
loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg.item()})
|
601 |
+
loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg.item()})
|
602 |
+
if self.use_tv_loss:
|
603 |
+
loss_dict.update({f'{prefix}/loss_tv': tv_loss.item()})
|
604 |
+
loss_dict.update({f'{prefix}/loss': loss.item()})
|
605 |
+
|
606 |
+
return loss, loss_dict
|
607 |
+
|
608 |
+
@torch.no_grad()
|
609 |
+
def validation_step(self, batch, batch_idx):
|
610 |
+
lrm_generator_input = self.prepare_validation_batch_data(batch)
|
611 |
+
|
612 |
+
render_out = self.forward(lrm_generator_input)
|
613 |
+
render_images = rearrange(render_out['pbr_img'], 'b n c h w -> b c h (n w)')
|
614 |
+
render_albodos = rearrange(render_out['img'], 'b n c h w -> b c h (n w)')
|
615 |
+
|
616 |
+
self.validation_step_outputs.append(render_images)
|
617 |
+
self.validation_step_outputs.append(render_albodos)
|
618 |
+
|
619 |
+
def on_validation_epoch_end(self):
|
620 |
+
images = torch.cat(self.validation_step_outputs, dim=0)
|
621 |
+
|
622 |
+
all_images = self.all_gather(images)
|
623 |
+
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
|
624 |
+
|
625 |
+
if self.global_rank == 0:
|
626 |
+
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
|
627 |
+
|
628 |
+
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
|
629 |
+
|
630 |
+
save_image(grid, image_path)
|
631 |
+
print(f"Saved image to {image_path}")
|
632 |
+
|
633 |
+
self.validation_step_outputs.clear()
|
634 |
+
|
635 |
+
def configure_optimizers(self):
|
636 |
+
lr = self.learning_rate
|
637 |
+
|
638 |
+
optimizer = torch.optim.AdamW(
|
639 |
+
self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
|
640 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
|
641 |
+
|
642 |
+
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|
src/models/__init__.py
ADDED
File without changes
|
src/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (146 Bytes). View file
|
|
src/models/__pycache__/lrm_mesh.cpython-310.pyc
ADDED
Binary file (11.6 kB). View file
|
|
src/models/decoder/__init__.py
ADDED
File without changes
|
src/models/decoder/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (154 Bytes). View file
|
|
src/models/decoder/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (3.45 kB). View file
|
|
src/models/decoder/transformer.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class BasicTransformerBlock(nn.Module):
|
21 |
+
"""
|
22 |
+
Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
|
23 |
+
"""
|
24 |
+
# use attention from torch.nn.MultiHeadAttention
|
25 |
+
# Block contains a cross-attention layer, a self-attention layer, and a MLP
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
inner_dim: int,
|
29 |
+
cond_dim: int,
|
30 |
+
num_heads: int,
|
31 |
+
eps: float,
|
32 |
+
attn_drop: float = 0.,
|
33 |
+
attn_bias: bool = False,
|
34 |
+
mlp_ratio: float = 4.,
|
35 |
+
mlp_drop: float = 0.,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.norm1 = nn.LayerNorm(inner_dim)
|
40 |
+
self.cross_attn = nn.MultiheadAttention(
|
41 |
+
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
|
42 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
43 |
+
self.norm2 = nn.LayerNorm(inner_dim)
|
44 |
+
self.self_attn = nn.MultiheadAttention(
|
45 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
46 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
47 |
+
self.norm3 = nn.LayerNorm(inner_dim)
|
48 |
+
self.mlp = nn.Sequential(
|
49 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
50 |
+
nn.GELU(),
|
51 |
+
nn.Dropout(mlp_drop),
|
52 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
53 |
+
nn.Dropout(mlp_drop),
|
54 |
+
)
|
55 |
+
|
56 |
+
def forward(self, x, cond):
|
57 |
+
# x: [N, L, D]
|
58 |
+
# cond: [N, L_cond, D_cond]
|
59 |
+
x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
|
60 |
+
before_sa = self.norm2(x)
|
61 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
|
62 |
+
x = x + self.mlp(self.norm3(x))
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class TriplaneTransformer(nn.Module):
|
67 |
+
"""
|
68 |
+
Transformer with condition that generates a triplane representation.
|
69 |
+
|
70 |
+
Reference:
|
71 |
+
Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
|
72 |
+
"""
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
inner_dim: int,
|
76 |
+
image_feat_dim: int,
|
77 |
+
triplane_low_res: int,
|
78 |
+
triplane_high_res: int,
|
79 |
+
triplane_dim: int,
|
80 |
+
num_layers: int,
|
81 |
+
num_heads: int,
|
82 |
+
eps: float = 1e-6,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
# attributes
|
87 |
+
self.triplane_low_res = triplane_low_res
|
88 |
+
self.triplane_high_res = triplane_high_res
|
89 |
+
self.triplane_dim = triplane_dim
|
90 |
+
|
91 |
+
# modules
|
92 |
+
# initialize pos_embed with 1/sqrt(dim) * N(0, 1)
|
93 |
+
self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
|
94 |
+
self.layers = nn.ModuleList([
|
95 |
+
BasicTransformerBlock(
|
96 |
+
inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
|
97 |
+
for _ in range(num_layers)
|
98 |
+
])
|
99 |
+
self.norm = nn.LayerNorm(inner_dim, eps=eps)
|
100 |
+
self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
|
101 |
+
|
102 |
+
def forward(self, image_feats):
|
103 |
+
# image_feats: [N, L_cond, D_cond]
|
104 |
+
|
105 |
+
N = image_feats.shape[0]
|
106 |
+
H = W = self.triplane_low_res
|
107 |
+
L = 3 * H * W
|
108 |
+
|
109 |
+
x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
|
110 |
+
for layer in self.layers:
|
111 |
+
x = layer(x, image_feats)
|
112 |
+
x = self.norm(x)
|
113 |
+
|
114 |
+
# separate each plane and apply deconv
|
115 |
+
x = x.view(N, 3, H, W, -1)
|
116 |
+
x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
|
117 |
+
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
|
118 |
+
x = self.deconv(x) # [3*N, D', H', W']
|
119 |
+
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
|
120 |
+
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
|
121 |
+
x = x.contiguous()
|
122 |
+
|
123 |
+
return x
|
src/models/encoder/__init__.py
ADDED
File without changes
|
src/models/encoder/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (154 Bytes). View file
|
|
src/models/encoder/__pycache__/dino.cpython-310.pyc
ADDED
Binary file (17.2 kB). View file
|
|
src/models/encoder/__pycache__/dino_wrapper.cpython-310.pyc
ADDED
Binary file (2.54 kB). View file
|
|
src/models/encoder/dino.py
ADDED
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch ViT model."""
|
16 |
+
|
17 |
+
|
18 |
+
import collections.abc
|
19 |
+
import math
|
20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.modeling_outputs import (
|
27 |
+
BaseModelOutput,
|
28 |
+
BaseModelOutputWithPooling,
|
29 |
+
)
|
30 |
+
from transformers import PreTrainedModel, ViTConfig
|
31 |
+
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
32 |
+
|
33 |
+
|
34 |
+
class ViTEmbeddings(nn.Module):
|
35 |
+
"""
|
36 |
+
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
43 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
44 |
+
self.patch_embeddings = ViTPatchEmbeddings(config)
|
45 |
+
num_patches = self.patch_embeddings.num_patches
|
46 |
+
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
|
47 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
48 |
+
self.config = config
|
49 |
+
|
50 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
51 |
+
"""
|
52 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
53 |
+
resolution images.
|
54 |
+
|
55 |
+
Source:
|
56 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
57 |
+
"""
|
58 |
+
|
59 |
+
num_patches = embeddings.shape[1] - 1
|
60 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
61 |
+
if num_patches == num_positions and height == width:
|
62 |
+
return self.position_embeddings
|
63 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
64 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
65 |
+
dim = embeddings.shape[-1]
|
66 |
+
h0 = height // self.config.patch_size
|
67 |
+
w0 = width // self.config.patch_size
|
68 |
+
# we add a small number to avoid floating point error in the interpolation
|
69 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
70 |
+
h0, w0 = h0 + 0.1, w0 + 0.1
|
71 |
+
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
72 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
73 |
+
patch_pos_embed = nn.functional.interpolate(
|
74 |
+
patch_pos_embed,
|
75 |
+
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
76 |
+
mode="bicubic",
|
77 |
+
align_corners=False,
|
78 |
+
)
|
79 |
+
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
|
80 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
81 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
82 |
+
|
83 |
+
def forward(
|
84 |
+
self,
|
85 |
+
pixel_values: torch.Tensor,
|
86 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
87 |
+
interpolate_pos_encoding: bool = False,
|
88 |
+
) -> torch.Tensor:
|
89 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
90 |
+
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
91 |
+
|
92 |
+
if bool_masked_pos is not None:
|
93 |
+
seq_length = embeddings.shape[1]
|
94 |
+
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
|
95 |
+
# replace the masked visual tokens by mask_tokens
|
96 |
+
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
97 |
+
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
98 |
+
|
99 |
+
# add the [CLS] token to the embedded patch tokens
|
100 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
101 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
102 |
+
|
103 |
+
# add positional encoding to each token
|
104 |
+
if interpolate_pos_encoding:
|
105 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
106 |
+
else:
|
107 |
+
embeddings = embeddings + self.position_embeddings
|
108 |
+
|
109 |
+
embeddings = self.dropout(embeddings)
|
110 |
+
|
111 |
+
return embeddings
|
112 |
+
|
113 |
+
|
114 |
+
class ViTPatchEmbeddings(nn.Module):
|
115 |
+
"""
|
116 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
117 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
118 |
+
Transformer.
|
119 |
+
"""
|
120 |
+
|
121 |
+
def __init__(self, config):
|
122 |
+
super().__init__()
|
123 |
+
image_size, patch_size = config.image_size, config.patch_size
|
124 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
125 |
+
|
126 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
127 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
128 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
129 |
+
self.image_size = image_size
|
130 |
+
self.patch_size = patch_size
|
131 |
+
self.num_channels = num_channels
|
132 |
+
self.num_patches = num_patches
|
133 |
+
|
134 |
+
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
135 |
+
|
136 |
+
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
137 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
138 |
+
if num_channels != self.num_channels:
|
139 |
+
raise ValueError(
|
140 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
141 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
142 |
+
)
|
143 |
+
if not interpolate_pos_encoding:
|
144 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
145 |
+
raise ValueError(
|
146 |
+
f"Input image size ({height}*{width}) doesn't match model"
|
147 |
+
f" ({self.image_size[0]}*{self.image_size[1]})."
|
148 |
+
)
|
149 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
150 |
+
return embeddings
|
151 |
+
|
152 |
+
|
153 |
+
class ViTSelfAttention(nn.Module):
|
154 |
+
def __init__(self, config: ViTConfig) -> None:
|
155 |
+
super().__init__()
|
156 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
157 |
+
raise ValueError(
|
158 |
+
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
159 |
+
f"heads {config.num_attention_heads}."
|
160 |
+
)
|
161 |
+
|
162 |
+
self.num_attention_heads = config.num_attention_heads
|
163 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
164 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
165 |
+
|
166 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
167 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
168 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
169 |
+
|
170 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
171 |
+
|
172 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
173 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
174 |
+
x = x.view(new_x_shape)
|
175 |
+
return x.permute(0, 2, 1, 3)
|
176 |
+
|
177 |
+
def forward(
|
178 |
+
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
179 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
180 |
+
mixed_query_layer = self.query(hidden_states)
|
181 |
+
|
182 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
183 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
184 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
185 |
+
|
186 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
187 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
188 |
+
|
189 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
190 |
+
|
191 |
+
# Normalize the attention scores to probabilities.
|
192 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
193 |
+
|
194 |
+
# This is actually dropping out entire tokens to attend to, which might
|
195 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
196 |
+
attention_probs = self.dropout(attention_probs)
|
197 |
+
|
198 |
+
# Mask heads if we want to
|
199 |
+
if head_mask is not None:
|
200 |
+
attention_probs = attention_probs * head_mask
|
201 |
+
|
202 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
203 |
+
|
204 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
205 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
206 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
207 |
+
|
208 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
209 |
+
|
210 |
+
return outputs
|
211 |
+
|
212 |
+
|
213 |
+
class ViTSelfOutput(nn.Module):
|
214 |
+
"""
|
215 |
+
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
|
216 |
+
layernorm applied before each block.
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(self, config: ViTConfig) -> None:
|
220 |
+
super().__init__()
|
221 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
222 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
223 |
+
|
224 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
225 |
+
hidden_states = self.dense(hidden_states)
|
226 |
+
hidden_states = self.dropout(hidden_states)
|
227 |
+
|
228 |
+
return hidden_states
|
229 |
+
|
230 |
+
|
231 |
+
class ViTAttention(nn.Module):
|
232 |
+
def __init__(self, config: ViTConfig) -> None:
|
233 |
+
super().__init__()
|
234 |
+
self.attention = ViTSelfAttention(config)
|
235 |
+
self.output = ViTSelfOutput(config)
|
236 |
+
self.pruned_heads = set()
|
237 |
+
|
238 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
239 |
+
if len(heads) == 0:
|
240 |
+
return
|
241 |
+
heads, index = find_pruneable_heads_and_indices(
|
242 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
243 |
+
)
|
244 |
+
|
245 |
+
# Prune linear layers
|
246 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
247 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
248 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
249 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
250 |
+
|
251 |
+
# Update hyper params and store pruned heads
|
252 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
253 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
254 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
255 |
+
|
256 |
+
def forward(
|
257 |
+
self,
|
258 |
+
hidden_states: torch.Tensor,
|
259 |
+
head_mask: Optional[torch.Tensor] = None,
|
260 |
+
output_attentions: bool = False,
|
261 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
262 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
263 |
+
|
264 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
265 |
+
|
266 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
267 |
+
return outputs
|
268 |
+
|
269 |
+
|
270 |
+
class ViTIntermediate(nn.Module):
|
271 |
+
def __init__(self, config: ViTConfig) -> None:
|
272 |
+
super().__init__()
|
273 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
274 |
+
if isinstance(config.hidden_act, str):
|
275 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
276 |
+
else:
|
277 |
+
self.intermediate_act_fn = config.hidden_act
|
278 |
+
|
279 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
280 |
+
hidden_states = self.dense(hidden_states)
|
281 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
282 |
+
|
283 |
+
return hidden_states
|
284 |
+
|
285 |
+
|
286 |
+
class ViTOutput(nn.Module):
|
287 |
+
def __init__(self, config: ViTConfig) -> None:
|
288 |
+
super().__init__()
|
289 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
290 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
291 |
+
|
292 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
293 |
+
hidden_states = self.dense(hidden_states)
|
294 |
+
hidden_states = self.dropout(hidden_states)
|
295 |
+
|
296 |
+
hidden_states = hidden_states + input_tensor
|
297 |
+
|
298 |
+
return hidden_states
|
299 |
+
|
300 |
+
|
301 |
+
def modulate(x, shift, scale):
|
302 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
303 |
+
|
304 |
+
|
305 |
+
class ViTLayer(nn.Module):
|
306 |
+
"""This corresponds to the Block class in the timm implementation."""
|
307 |
+
|
308 |
+
def __init__(self, config: ViTConfig) -> None:
|
309 |
+
super().__init__()
|
310 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
311 |
+
self.seq_len_dim = 1
|
312 |
+
self.attention = ViTAttention(config)
|
313 |
+
self.intermediate = ViTIntermediate(config)
|
314 |
+
self.output = ViTOutput(config)
|
315 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
316 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
317 |
+
|
318 |
+
self.adaLN_modulation = nn.Sequential(
|
319 |
+
nn.SiLU(),
|
320 |
+
nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
|
321 |
+
)
|
322 |
+
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
323 |
+
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
324 |
+
|
325 |
+
def forward(
|
326 |
+
self,
|
327 |
+
hidden_states: torch.Tensor,
|
328 |
+
adaln_input: torch.Tensor = None,
|
329 |
+
head_mask: Optional[torch.Tensor] = None,
|
330 |
+
output_attentions: bool = False,
|
331 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
332 |
+
shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
333 |
+
|
334 |
+
self_attention_outputs = self.attention(
|
335 |
+
modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
|
336 |
+
head_mask,
|
337 |
+
output_attentions=output_attentions,
|
338 |
+
)
|
339 |
+
attention_output = self_attention_outputs[0]
|
340 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
341 |
+
|
342 |
+
# first residual connection
|
343 |
+
hidden_states = attention_output + hidden_states
|
344 |
+
|
345 |
+
# in ViT, layernorm is also applied after self-attention
|
346 |
+
layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
|
347 |
+
layer_output = self.intermediate(layer_output)
|
348 |
+
|
349 |
+
# second residual connection is done here
|
350 |
+
layer_output = self.output(layer_output, hidden_states)
|
351 |
+
|
352 |
+
outputs = (layer_output,) + outputs
|
353 |
+
|
354 |
+
return outputs
|
355 |
+
|
356 |
+
|
357 |
+
class ViTEncoder(nn.Module):
|
358 |
+
def __init__(self, config: ViTConfig) -> None:
|
359 |
+
super().__init__()
|
360 |
+
self.config = config
|
361 |
+
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
|
362 |
+
self.gradient_checkpointing = False
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
hidden_states: torch.Tensor,
|
367 |
+
adaln_input: torch.Tensor = None,
|
368 |
+
head_mask: Optional[torch.Tensor] = None,
|
369 |
+
output_attentions: bool = False,
|
370 |
+
output_hidden_states: bool = False,
|
371 |
+
return_dict: bool = True,
|
372 |
+
) -> Union[tuple, BaseModelOutput]:
|
373 |
+
all_hidden_states = () if output_hidden_states else None
|
374 |
+
all_self_attentions = () if output_attentions else None
|
375 |
+
|
376 |
+
for i, layer_module in enumerate(self.layer):
|
377 |
+
if output_hidden_states:
|
378 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
379 |
+
|
380 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
381 |
+
|
382 |
+
if self.gradient_checkpointing and self.training:
|
383 |
+
layer_outputs = self._gradient_checkpointing_func(
|
384 |
+
layer_module.__call__,
|
385 |
+
hidden_states,
|
386 |
+
adaln_input,
|
387 |
+
layer_head_mask,
|
388 |
+
output_attentions,
|
389 |
+
)
|
390 |
+
else:
|
391 |
+
layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
|
392 |
+
|
393 |
+
hidden_states = layer_outputs[0]
|
394 |
+
|
395 |
+
if output_attentions:
|
396 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
397 |
+
|
398 |
+
if output_hidden_states:
|
399 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
400 |
+
|
401 |
+
if not return_dict:
|
402 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
403 |
+
return BaseModelOutput(
|
404 |
+
last_hidden_state=hidden_states,
|
405 |
+
hidden_states=all_hidden_states,
|
406 |
+
attentions=all_self_attentions,
|
407 |
+
)
|
408 |
+
|
409 |
+
|
410 |
+
class ViTPreTrainedModel(PreTrainedModel):
|
411 |
+
"""
|
412 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
413 |
+
models.
|
414 |
+
"""
|
415 |
+
|
416 |
+
config_class = ViTConfig
|
417 |
+
base_model_prefix = "vit"
|
418 |
+
main_input_name = "pixel_values"
|
419 |
+
supports_gradient_checkpointing = True
|
420 |
+
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
|
421 |
+
|
422 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
423 |
+
"""Initialize the weights"""
|
424 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
425 |
+
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
426 |
+
# `trunc_normal_cpu` not implemented in `half` issues
|
427 |
+
module.weight.data = nn.init.trunc_normal_(
|
428 |
+
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
|
429 |
+
).to(module.weight.dtype)
|
430 |
+
if module.bias is not None:
|
431 |
+
module.bias.data.zero_()
|
432 |
+
elif isinstance(module, nn.LayerNorm):
|
433 |
+
module.bias.data.zero_()
|
434 |
+
module.weight.data.fill_(1.0)
|
435 |
+
elif isinstance(module, ViTEmbeddings):
|
436 |
+
module.position_embeddings.data = nn.init.trunc_normal_(
|
437 |
+
module.position_embeddings.data.to(torch.float32),
|
438 |
+
mean=0.0,
|
439 |
+
std=self.config.initializer_range,
|
440 |
+
).to(module.position_embeddings.dtype)
|
441 |
+
|
442 |
+
module.cls_token.data = nn.init.trunc_normal_(
|
443 |
+
module.cls_token.data.to(torch.float32),
|
444 |
+
mean=0.0,
|
445 |
+
std=self.config.initializer_range,
|
446 |
+
).to(module.cls_token.dtype)
|
447 |
+
|
448 |
+
|
449 |
+
class ViTModel(ViTPreTrainedModel):
|
450 |
+
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
|
451 |
+
super().__init__(config)
|
452 |
+
self.config = config
|
453 |
+
|
454 |
+
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
|
455 |
+
self.encoder = ViTEncoder(config)
|
456 |
+
|
457 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
458 |
+
self.pooler = ViTPooler(config) if add_pooling_layer else None
|
459 |
+
|
460 |
+
# Initialize weights and apply final processing
|
461 |
+
self.post_init()
|
462 |
+
|
463 |
+
def get_input_embeddings(self) -> ViTPatchEmbeddings:
|
464 |
+
return self.embeddings.patch_embeddings
|
465 |
+
|
466 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
467 |
+
"""
|
468 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
469 |
+
class PreTrainedModel
|
470 |
+
"""
|
471 |
+
for layer, heads in heads_to_prune.items():
|
472 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
473 |
+
|
474 |
+
def forward(
|
475 |
+
self,
|
476 |
+
pixel_values: Optional[torch.Tensor] = None,
|
477 |
+
adaln_input: Optional[torch.Tensor] = None,
|
478 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
479 |
+
head_mask: Optional[torch.Tensor] = None,
|
480 |
+
output_attentions: Optional[bool] = None,
|
481 |
+
output_hidden_states: Optional[bool] = None,
|
482 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
483 |
+
return_dict: Optional[bool] = None,
|
484 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
485 |
+
r"""
|
486 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
487 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
488 |
+
"""
|
489 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
490 |
+
output_hidden_states = (
|
491 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
492 |
+
)
|
493 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
494 |
+
|
495 |
+
if pixel_values is None:
|
496 |
+
raise ValueError("You have to specify pixel_values")
|
497 |
+
|
498 |
+
# Prepare head mask if needed
|
499 |
+
# 1.0 in head_mask indicate we keep the head
|
500 |
+
# attention_probs has shape bsz x n_heads x N x N
|
501 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
502 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
503 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
504 |
+
|
505 |
+
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
|
506 |
+
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
|
507 |
+
if pixel_values.dtype != expected_dtype:
|
508 |
+
pixel_values = pixel_values.to(expected_dtype)
|
509 |
+
|
510 |
+
embedding_output = self.embeddings(
|
511 |
+
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
512 |
+
)
|
513 |
+
|
514 |
+
encoder_outputs = self.encoder(
|
515 |
+
embedding_output,
|
516 |
+
adaln_input=adaln_input,
|
517 |
+
head_mask=head_mask,
|
518 |
+
output_attentions=output_attentions,
|
519 |
+
output_hidden_states=output_hidden_states,
|
520 |
+
return_dict=return_dict,
|
521 |
+
)
|
522 |
+
sequence_output = encoder_outputs[0]
|
523 |
+
sequence_output = self.layernorm(sequence_output)
|
524 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
525 |
+
|
526 |
+
if not return_dict:
|
527 |
+
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
|
528 |
+
return head_outputs + encoder_outputs[1:]
|
529 |
+
|
530 |
+
return BaseModelOutputWithPooling(
|
531 |
+
last_hidden_state=sequence_output,
|
532 |
+
pooler_output=pooled_output,
|
533 |
+
hidden_states=encoder_outputs.hidden_states,
|
534 |
+
attentions=encoder_outputs.attentions,
|
535 |
+
)
|
536 |
+
|
537 |
+
|
538 |
+
class ViTPooler(nn.Module):
|
539 |
+
def __init__(self, config: ViTConfig):
|
540 |
+
super().__init__()
|
541 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
542 |
+
self.activation = nn.Tanh()
|
543 |
+
|
544 |
+
def forward(self, hidden_states):
|
545 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
546 |
+
# to the first token.
|
547 |
+
first_token_tensor = hidden_states[:, 0]
|
548 |
+
pooled_output = self.dense(first_token_tensor)
|
549 |
+
pooled_output = self.activation(pooled_output)
|
550 |
+
return pooled_output
|
src/models/encoder/dino_wrapper.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch.nn as nn
|
17 |
+
from transformers import ViTImageProcessor
|
18 |
+
from einops import rearrange, repeat
|
19 |
+
from .dino import ViTModel
|
20 |
+
|
21 |
+
|
22 |
+
class DinoWrapper(nn.Module):
|
23 |
+
"""
|
24 |
+
Dino v1 wrapper using huggingface transformer implementation.
|
25 |
+
"""
|
26 |
+
def __init__(self, model_name: str, freeze: bool = True):
|
27 |
+
super().__init__()
|
28 |
+
self.model, self.processor = self._build_dino(model_name)
|
29 |
+
self.camera_embedder = nn.Sequential(
|
30 |
+
nn.Linear(16, self.model.config.hidden_size, bias=True),
|
31 |
+
nn.SiLU(),
|
32 |
+
nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
|
33 |
+
)
|
34 |
+
if freeze:
|
35 |
+
self._freeze()
|
36 |
+
|
37 |
+
def forward(self, image, camera):
|
38 |
+
# image: [B, N, C, H, W]
|
39 |
+
# camera: [B, N, D]
|
40 |
+
# RGB image with [0,1] scale and properly sized
|
41 |
+
if image.ndim == 5:
|
42 |
+
image = rearrange(image, 'b n c h w -> (b n) c h w')
|
43 |
+
dtype = image.dtype
|
44 |
+
inputs = self.processor(
|
45 |
+
images=image.float(),
|
46 |
+
return_tensors="pt",
|
47 |
+
do_rescale=False,
|
48 |
+
do_resize=False,
|
49 |
+
).to(self.model.device).to(dtype)
|
50 |
+
# embed camera
|
51 |
+
N = camera.shape[1]
|
52 |
+
camera_embeddings = self.camera_embedder(camera)
|
53 |
+
camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
|
54 |
+
embeddings = camera_embeddings
|
55 |
+
# This resampling of positional embedding uses bicubic interpolation
|
56 |
+
outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
|
57 |
+
last_hidden_states = outputs.last_hidden_state
|
58 |
+
return last_hidden_states
|
59 |
+
|
60 |
+
def _freeze(self):
|
61 |
+
print(f"======== Freezing DinoWrapper ========")
|
62 |
+
self.model.eval()
|
63 |
+
for name, param in self.model.named_parameters():
|
64 |
+
param.requires_grad = False
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
|
68 |
+
import requests
|
69 |
+
try:
|
70 |
+
model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
|
71 |
+
processor = ViTImageProcessor.from_pretrained(model_name)
|
72 |
+
return model, processor
|
73 |
+
except requests.exceptions.ProxyError as err:
|
74 |
+
if proxy_error_retries > 0:
|
75 |
+
print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
|
76 |
+
import time
|
77 |
+
time.sleep(proxy_error_cooldown)
|
78 |
+
return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
|
79 |
+
else:
|
80 |
+
raise err
|
src/models/geometry/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
src/models/geometry/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (155 Bytes). View file
|
|
src/models/geometry/camera/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
|
13 |
+
class Camera(nn.Module):
|
14 |
+
def __init__(self):
|
15 |
+
super(Camera, self).__init__()
|
16 |
+
pass
|
src/models/geometry/camera/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (547 Bytes). View file
|
|
src/models/geometry/camera/__pycache__/perspective_camera.cpython-310.pyc
ADDED
Binary file (1.43 kB). View file
|
|
src/models/geometry/camera/perspective_camera.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from . import Camera
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
|
15 |
+
if near_plane is None:
|
16 |
+
near_plane = n
|
17 |
+
return np.array(
|
18 |
+
[[n / x, 0, 0, 0],
|
19 |
+
[0, n / -x, 0, 0],
|
20 |
+
[0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
|
21 |
+
[0, 0, -1, 0]]).astype(np.float32)
|
22 |
+
|
23 |
+
|
24 |
+
class PerspectiveCamera(Camera):
|
25 |
+
def __init__(self, fovy=49.0, device='cuda'):
|
26 |
+
super(PerspectiveCamera, self).__init__()
|
27 |
+
self.device = device
|
28 |
+
focal = np.tan(fovy / 180.0 * np.pi * 0.5)
|
29 |
+
self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)
|
30 |
+
|
31 |
+
def project(self, points_bxnx4):
|
32 |
+
out = torch.matmul(
|
33 |
+
points_bxnx4,
|
34 |
+
torch.transpose(self.proj_mtx, 1, 2))
|
35 |
+
return out
|
src/models/geometry/render/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class Renderer():
|
4 |
+
def __init__(self):
|
5 |
+
pass
|
6 |
+
|
7 |
+
def forward(self):
|
8 |
+
pass
|
src/models/geometry/render/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (565 Bytes). View file
|
|
src/models/geometry/render/__pycache__/neural_render.cpython-310.pyc
ADDED
Binary file (5.85 kB). View file
|
|
src/models/geometry/render/__pycache__/util.cpython-310.pyc
ADDED
Binary file (15.1 kB). View file
|
|
src/models/geometry/render/neural_render.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import nvdiffrast.torch as dr
|
12 |
+
from . import Renderer
|
13 |
+
from . import util
|
14 |
+
from . import renderutils as ru
|
15 |
+
_FG_LUT = None
|
16 |
+
|
17 |
+
|
18 |
+
def interpolate(attr, rast, attr_idx, rast_db=None):
|
19 |
+
return dr.interpolate(
|
20 |
+
attr.contiguous(), rast, attr_idx, rast_db=rast_db,
|
21 |
+
diff_attrs=None if rast_db is None else 'all')
|
22 |
+
|
23 |
+
|
24 |
+
def xfm_points(points, matrix, use_python=True):
|
25 |
+
'''Transform points.
|
26 |
+
Args:
|
27 |
+
points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
28 |
+
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
29 |
+
use_python: Use PyTorch's torch.matmul (for validation)
|
30 |
+
Returns:
|
31 |
+
Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
32 |
+
'''
|
33 |
+
out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
|
34 |
+
if torch.is_anomaly_enabled():
|
35 |
+
assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
|
36 |
+
return out
|
37 |
+
|
38 |
+
|
39 |
+
def dot(x, y):
|
40 |
+
return torch.sum(x * y, -1, keepdim=True)
|
41 |
+
|
42 |
+
|
43 |
+
def compute_vertex_normal(v_pos, t_pos_idx):
|
44 |
+
i0 = t_pos_idx[:, 0]
|
45 |
+
i1 = t_pos_idx[:, 1]
|
46 |
+
i2 = t_pos_idx[:, 2]
|
47 |
+
|
48 |
+
v0 = v_pos[i0, :]
|
49 |
+
v1 = v_pos[i1, :]
|
50 |
+
v2 = v_pos[i2, :]
|
51 |
+
|
52 |
+
face_normals = torch.cross(v1 - v0, v2 - v0)
|
53 |
+
|
54 |
+
# Splat face normals to vertices
|
55 |
+
v_nrm = torch.zeros_like(v_pos)
|
56 |
+
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
57 |
+
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
58 |
+
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
59 |
+
|
60 |
+
# Normalize, replace zero (degenerated) normals with some default value
|
61 |
+
v_nrm = torch.where(
|
62 |
+
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
|
63 |
+
)
|
64 |
+
v_nrm = F.normalize(v_nrm, dim=1)
|
65 |
+
assert torch.all(torch.isfinite(v_nrm))
|
66 |
+
|
67 |
+
return v_nrm
|
68 |
+
|
69 |
+
|
70 |
+
class NeuralRender(Renderer):
|
71 |
+
def __init__(self, device='cuda', camera_model=None):
|
72 |
+
super(NeuralRender, self).__init__()
|
73 |
+
self.device = device
|
74 |
+
self.ctx = dr.RasterizeCudaContext(device=device)
|
75 |
+
self.projection_mtx = None
|
76 |
+
self.camera = camera_model
|
77 |
+
|
78 |
+
# ==============================================================================================
|
79 |
+
# pixel shader
|
80 |
+
# ==============================================================================================
|
81 |
+
# def shade(
|
82 |
+
# self,
|
83 |
+
# gb_pos,
|
84 |
+
# gb_geometric_normal,
|
85 |
+
# gb_normal,
|
86 |
+
# gb_tangent,
|
87 |
+
# gb_texc,
|
88 |
+
# gb_texc_deriv,
|
89 |
+
# view_pos,
|
90 |
+
# ):
|
91 |
+
|
92 |
+
# ################################################################################
|
93 |
+
# # Texture lookups
|
94 |
+
# ################################################################################
|
95 |
+
# breakpoint()
|
96 |
+
# # Separate kd into alpha and color, default alpha = 1
|
97 |
+
# alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
|
98 |
+
# kd = kd[..., 0:3]
|
99 |
+
|
100 |
+
# ################################################################################
|
101 |
+
# # Normal perturbation & normal bend
|
102 |
+
# ################################################################################
|
103 |
+
|
104 |
+
# perturbed_nrm = None
|
105 |
+
|
106 |
+
# gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)
|
107 |
+
|
108 |
+
# ################################################################################
|
109 |
+
# # Evaluate BSDF
|
110 |
+
# ################################################################################
|
111 |
+
|
112 |
+
# assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
|
113 |
+
# bsdf = material['bsdf'] if bsdf is None else bsdf
|
114 |
+
# if bsdf == 'pbr':
|
115 |
+
# if isinstance(lgt, light.EnvironmentLight):
|
116 |
+
# shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
|
117 |
+
# else:
|
118 |
+
# assert False, "Invalid light type"
|
119 |
+
# elif bsdf == 'diffuse':
|
120 |
+
# if isinstance(lgt, light.EnvironmentLight):
|
121 |
+
# shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False)
|
122 |
+
# else:
|
123 |
+
# assert False, "Invalid light type"
|
124 |
+
# elif bsdf == 'normal':
|
125 |
+
# shaded_col = (gb_normal + 1.0)*0.5
|
126 |
+
# elif bsdf == 'tangent':
|
127 |
+
# shaded_col = (gb_tangent + 1.0)*0.5
|
128 |
+
# elif bsdf == 'kd':
|
129 |
+
# shaded_col = kd
|
130 |
+
# elif bsdf == 'ks':
|
131 |
+
# shaded_col = ks
|
132 |
+
# else:
|
133 |
+
# assert False, "Invalid BSDF '%s'" % bsdf
|
134 |
+
|
135 |
+
# # Return multiple buffers
|
136 |
+
# buffers = {
|
137 |
+
# 'shaded' : torch.cat((shaded_col, alpha), dim=-1),
|
138 |
+
# 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
|
139 |
+
# 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1)
|
140 |
+
# }
|
141 |
+
# return buffers
|
142 |
+
|
143 |
+
# ==============================================================================================
|
144 |
+
# Render a depth slice of the mesh (scene), some limitations:
|
145 |
+
# - Single mesh
|
146 |
+
# - Single light
|
147 |
+
# - Single material
|
148 |
+
# ==============================================================================================
|
149 |
+
def render_layer(
|
150 |
+
self,
|
151 |
+
rast,
|
152 |
+
rast_deriv,
|
153 |
+
mesh,
|
154 |
+
view_pos,
|
155 |
+
resolution,
|
156 |
+
spp,
|
157 |
+
msaa
|
158 |
+
):
|
159 |
+
|
160 |
+
# Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
|
161 |
+
rast_out_s = rast
|
162 |
+
rast_out_deriv_s = rast_deriv
|
163 |
+
|
164 |
+
################################################################################
|
165 |
+
# Interpolate attributes
|
166 |
+
################################################################################
|
167 |
+
|
168 |
+
# Interpolate world space position
|
169 |
+
gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())
|
170 |
+
|
171 |
+
# Compute geometric normals. We need those because of bent normals trick (for bump mapping)
|
172 |
+
v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
|
173 |
+
v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
|
174 |
+
v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
|
175 |
+
face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
|
176 |
+
face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
|
177 |
+
gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())
|
178 |
+
|
179 |
+
# Compute tangent space
|
180 |
+
assert mesh.v_nrm is not None and mesh.v_tng is not None
|
181 |
+
gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
|
182 |
+
gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents
|
183 |
+
|
184 |
+
# Texture coordinate
|
185 |
+
# assert mesh.v_tex is not None
|
186 |
+
# gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)
|
187 |
+
perturbed_nrm = None
|
188 |
+
gb_normal = ru.prepare_shading_normal(gb_pos, view_pos[:,None,None,:], perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)
|
189 |
+
|
190 |
+
return gb_pos, gb_normal
|
191 |
+
|
192 |
+
def render_mesh(
|
193 |
+
self,
|
194 |
+
mesh_v_pos_bxnx3,
|
195 |
+
mesh_t_pos_idx_fx3,
|
196 |
+
mesh,
|
197 |
+
camera_mv_bx4x4,
|
198 |
+
camera_pos,
|
199 |
+
mesh_v_feat_bxnxd,
|
200 |
+
resolution=256,
|
201 |
+
spp=1,
|
202 |
+
device='cuda',
|
203 |
+
hierarchical_mask=False
|
204 |
+
):
|
205 |
+
assert not hierarchical_mask
|
206 |
+
|
207 |
+
mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
|
208 |
+
v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
|
209 |
+
v_pos_clip = self.camera.project(v_pos) # Projection in the camera
|
210 |
+
|
211 |
+
# view_pos = torch.linalg.inv(mtx_in)[:, :3, 3]
|
212 |
+
view_pos = camera_pos
|
213 |
+
v_nrm = mesh.v_nrm #compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates
|
214 |
+
|
215 |
+
# Render the image,
|
216 |
+
# Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
|
217 |
+
num_layers = 1
|
218 |
+
mask_pyramid = None
|
219 |
+
assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
|
220 |
+
|
221 |
+
mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos [org_pos, clip space pose for rasterization]
|
222 |
+
|
223 |
+
layers = []
|
224 |
+
with dr.DepthPeeler(self.ctx, v_pos_clip, mesh.t_pos_idx.int(), [resolution * spp, resolution * spp]) as peeler:
|
225 |
+
for _ in range(num_layers):
|
226 |
+
rast, db = peeler.rasterize_next_layer()
|
227 |
+
gb_pos, gb_normal = self.render_layer(rast, db, mesh, view_pos, resolution, spp, msaa=False)
|
228 |
+
|
229 |
+
with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
|
230 |
+
for _ in range(num_layers):
|
231 |
+
rast, db = peeler.rasterize_next_layer()
|
232 |
+
gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
|
233 |
+
|
234 |
+
hard_mask = torch.clamp(rast[..., -1:], 0, 1)
|
235 |
+
antialias_mask = dr.antialias(
|
236 |
+
hard_mask.clone().contiguous(), rast, v_pos_clip,
|
237 |
+
mesh_t_pos_idx_fx3)
|
238 |
+
|
239 |
+
depth = gb_feat[..., -2:-1]
|
240 |
+
ori_mesh_feature = gb_feat[..., :-4]
|
241 |
+
|
242 |
+
normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
|
243 |
+
normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
|
244 |
+
# normal = F.normalize(normal, dim=-1)
|
245 |
+
# normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
|
246 |
+
return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal
|
247 |
+
|
248 |
+
def render_mesh_light(
|
249 |
+
self,
|
250 |
+
mesh_v_pos_bxnx3,
|
251 |
+
mesh_t_pos_idx_fx3,
|
252 |
+
mesh,
|
253 |
+
camera_mv_bx4x4,
|
254 |
+
mesh_v_feat_bxnxd,
|
255 |
+
resolution=256,
|
256 |
+
spp=1,
|
257 |
+
device='cuda',
|
258 |
+
hierarchical_mask=False
|
259 |
+
):
|
260 |
+
assert not hierarchical_mask
|
261 |
+
|
262 |
+
mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
|
263 |
+
v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
|
264 |
+
v_pos_clip = self.camera.project(v_pos) # Projection in the camera
|
265 |
+
|
266 |
+
v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates
|
267 |
+
|
268 |
+
# Render the image,
|
269 |
+
# Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
|
270 |
+
num_layers = 1
|
271 |
+
mask_pyramid = None
|
272 |
+
assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
|
273 |
+
mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos
|
274 |
+
|
275 |
+
with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
|
276 |
+
for _ in range(num_layers):
|
277 |
+
rast, db = peeler.rasterize_next_layer()
|
278 |
+
gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
|
279 |
+
|
280 |
+
hard_mask = torch.clamp(rast[..., -1:], 0, 1)
|
281 |
+
antialias_mask = dr.antialias(
|
282 |
+
hard_mask.clone().contiguous(), rast, v_pos_clip,
|
283 |
+
mesh_t_pos_idx_fx3)
|
284 |
+
|
285 |
+
depth = gb_feat[..., -2:-1]
|
286 |
+
ori_mesh_feature = gb_feat[..., :-4]
|
287 |
+
|
288 |
+
normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
|
289 |
+
normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
|
290 |
+
normal = F.normalize(normal, dim=-1)
|
291 |
+
normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
|
292 |
+
|
293 |
+
return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
|
src/models/geometry/render/renderutils/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
|
11 |
+
__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
|
src/models/geometry/render/renderutils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (612 Bytes). View file
|
|
src/models/geometry/render/renderutils/__pycache__/bsdf.cpython-310.pyc
ADDED
Binary file (4.48 kB). View file
|
|
src/models/geometry/render/renderutils/__pycache__/loss.cpython-310.pyc
ADDED
Binary file (1.22 kB). View file
|
|
src/models/geometry/render/renderutils/__pycache__/ops.cpython-310.pyc
ADDED
Binary file (18.8 kB). View file
|
|
src/models/geometry/render/renderutils/bsdf.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
|
13 |
+
NORMAL_THRESHOLD = 0.1
|
14 |
+
|
15 |
+
################################################################################
|
16 |
+
# Vector utility functions
|
17 |
+
################################################################################
|
18 |
+
|
19 |
+
def _dot(x, y):
|
20 |
+
return torch.sum(x*y, -1, keepdim=True)
|
21 |
+
|
22 |
+
def _reflect(x, n):
|
23 |
+
return 2*_dot(x, n)*n - x
|
24 |
+
|
25 |
+
def _safe_normalize(x):
|
26 |
+
return torch.nn.functional.normalize(x, dim = -1)
|
27 |
+
|
28 |
+
def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
|
29 |
+
# Swap normal direction for backfacing surfaces
|
30 |
+
if two_sided_shading:
|
31 |
+
smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
|
32 |
+
geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)
|
33 |
+
|
34 |
+
t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
|
35 |
+
return torch.lerp(geom_nrm, smooth_nrm, t)
|
36 |
+
|
37 |
+
|
38 |
+
def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
|
39 |
+
smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm))
|
40 |
+
if opengl:
|
41 |
+
shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
|
42 |
+
else:
|
43 |
+
shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
|
44 |
+
return _safe_normalize(shading_nrm)
|
45 |
+
|
46 |
+
def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
|
47 |
+
smooth_nrm = _safe_normalize(smooth_nrm)
|
48 |
+
smooth_tng = _safe_normalize(smooth_tng)
|
49 |
+
view_vec = _safe_normalize(view_pos - pos)
|
50 |
+
shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
|
51 |
+
return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
|
52 |
+
|
53 |
+
################################################################################
|
54 |
+
# Simple lambertian diffuse BSDF
|
55 |
+
################################################################################
|
56 |
+
|
57 |
+
def bsdf_lambert(nrm, wi):
|
58 |
+
return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi
|
59 |
+
|
60 |
+
################################################################################
|
61 |
+
# Frostbite diffuse
|
62 |
+
################################################################################
|
63 |
+
|
64 |
+
def bsdf_frostbite(nrm, wi, wo, linearRoughness):
|
65 |
+
wiDotN = _dot(wi, nrm)
|
66 |
+
woDotN = _dot(wo, nrm)
|
67 |
+
|
68 |
+
h = _safe_normalize(wo + wi)
|
69 |
+
wiDotH = _dot(wi, h)
|
70 |
+
|
71 |
+
energyBias = 0.5 * linearRoughness
|
72 |
+
energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness
|
73 |
+
f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness
|
74 |
+
f0 = 1.0
|
75 |
+
|
76 |
+
wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)
|
77 |
+
woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)
|
78 |
+
res = wiScatter * woScatter * energyFactor
|
79 |
+
return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))
|
80 |
+
|
81 |
+
################################################################################
|
82 |
+
# Phong specular, loosely based on mitsuba implementation
|
83 |
+
################################################################################
|
84 |
+
|
85 |
+
def bsdf_phong(nrm, wo, wi, N):
|
86 |
+
dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
|
87 |
+
dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
|
88 |
+
return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)
|
89 |
+
|
90 |
+
################################################################################
|
91 |
+
# PBR's implementation of GGX specular
|
92 |
+
################################################################################
|
93 |
+
|
94 |
+
specular_epsilon = 1e-4
|
95 |
+
|
96 |
+
def bsdf_fresnel_shlick(f0, f90, cosTheta):
|
97 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
98 |
+
return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0
|
99 |
+
|
100 |
+
def bsdf_ndf_ggx(alphaSqr, cosTheta):
|
101 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
102 |
+
d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
|
103 |
+
return alphaSqr / (d * d * math.pi)
|
104 |
+
|
105 |
+
def bsdf_lambda_ggx(alphaSqr, cosTheta):
|
106 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
107 |
+
cosThetaSqr = _cosTheta * _cosTheta
|
108 |
+
tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
|
109 |
+
res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
|
110 |
+
return res
|
111 |
+
|
112 |
+
def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
|
113 |
+
lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
|
114 |
+
lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
|
115 |
+
return 1 / (1 + lambdaI + lambdaO)
|
116 |
+
|
117 |
+
def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
|
118 |
+
_alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
|
119 |
+
alphaSqr = _alpha * _alpha
|
120 |
+
|
121 |
+
h = _safe_normalize(wo + wi)
|
122 |
+
woDotN = _dot(wo, nrm)
|
123 |
+
wiDotN = _dot(wi, nrm)
|
124 |
+
woDotH = _dot(wo, h)
|
125 |
+
nDotH = _dot(nrm, h)
|
126 |
+
|
127 |
+
D = bsdf_ndf_ggx(alphaSqr, nDotH)
|
128 |
+
G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
|
129 |
+
F = bsdf_fresnel_shlick(col, 1, woDotH)
|
130 |
+
|
131 |
+
w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)
|
132 |
+
|
133 |
+
frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
|
134 |
+
return torch.where(frontfacing, w, torch.zeros_like(w))
|
135 |
+
|
136 |
+
def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
|
137 |
+
wo = _safe_normalize(view_pos - pos)
|
138 |
+
wi = _safe_normalize(light_pos - pos)
|
139 |
+
|
140 |
+
spec_str = arm[..., 0:1] # x component
|
141 |
+
roughness = arm[..., 1:2] # y component
|
142 |
+
metallic = arm[..., 2:3] # z component
|
143 |
+
ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
|
144 |
+
kd = kd * (1.0 - metallic)
|
145 |
+
|
146 |
+
if BSDF == 0:
|
147 |
+
diffuse = kd * bsdf_lambert(nrm, wi)
|
148 |
+
else:
|
149 |
+
diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
|
150 |
+
specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
|
151 |
+
return diffuse + specular
|
src/models/geometry/render/renderutils/c_src/bsdf.cu
ADDED
@@ -0,0 +1,710 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "common.h"
|
13 |
+
#include "bsdf.h"
|
14 |
+
|
15 |
+
#define SPECULAR_EPSILON 1e-4f
|
16 |
+
|
17 |
+
//------------------------------------------------------------------------
|
18 |
+
// Lambert functions
|
19 |
+
|
20 |
+
__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
|
21 |
+
{
|
22 |
+
return max(dot(nrm, wi) / M_PI, 0.0f);
|
23 |
+
}
|
24 |
+
|
25 |
+
__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
|
26 |
+
{
|
27 |
+
if (dot(nrm, wi) > 0.0f)
|
28 |
+
bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
|
29 |
+
}
|
30 |
+
|
31 |
+
//------------------------------------------------------------------------
|
32 |
+
// Fresnel Schlick
|
33 |
+
|
34 |
+
__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
|
35 |
+
{
|
36 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
37 |
+
float scale = powf(1.0f - _cosTheta, 5.0f);
|
38 |
+
return f0 * (1.0f - scale) + f90 * scale;
|
39 |
+
}
|
40 |
+
|
41 |
+
__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
|
42 |
+
{
|
43 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
44 |
+
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
45 |
+
d_f0 += d_out * (1.0 - scale);
|
46 |
+
d_f90 += d_out * scale;
|
47 |
+
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
48 |
+
{
|
49 |
+
d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
|
50 |
+
}
|
51 |
+
}
|
52 |
+
|
53 |
+
__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
|
54 |
+
{
|
55 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
56 |
+
float scale = powf(1.0f - _cosTheta, 5.0f);
|
57 |
+
return f0 * (1.0f - scale) + f90 * scale;
|
58 |
+
}
|
59 |
+
|
60 |
+
__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
|
61 |
+
{
|
62 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
63 |
+
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
64 |
+
d_f0 += d_out * (1.0 - scale);
|
65 |
+
d_f90 += d_out * scale;
|
66 |
+
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
67 |
+
{
|
68 |
+
d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
|
69 |
+
}
|
70 |
+
}
|
71 |
+
|
72 |
+
//------------------------------------------------------------------------
|
73 |
+
// Frostbite diffuse
|
74 |
+
|
75 |
+
__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
|
76 |
+
{
|
77 |
+
float wiDotN = dot(wi, nrm);
|
78 |
+
float woDotN = dot(wo, nrm);
|
79 |
+
if (wiDotN > 0.0f && woDotN > 0.0f)
|
80 |
+
{
|
81 |
+
vec3f h = safeNormalize(wo + wi);
|
82 |
+
float wiDotH = dot(wi, h);
|
83 |
+
|
84 |
+
float energyBias = 0.5f * linearRoughness;
|
85 |
+
float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
86 |
+
float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
87 |
+
float f0 = 1.f;
|
88 |
+
|
89 |
+
float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
|
90 |
+
float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
91 |
+
|
92 |
+
return wiScatter * woScatter * energyFactor;
|
93 |
+
}
|
94 |
+
else return 0.0f;
|
95 |
+
}
|
96 |
+
|
97 |
+
__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
|
98 |
+
{
|
99 |
+
float wiDotN = dot(wi, nrm);
|
100 |
+
float woDotN = dot(wo, nrm);
|
101 |
+
|
102 |
+
if (wiDotN > 0.0f && woDotN > 0.0f)
|
103 |
+
{
|
104 |
+
vec3f h = safeNormalize(wo + wi);
|
105 |
+
float wiDotH = dot(wi, h);
|
106 |
+
|
107 |
+
float energyBias = 0.5f * linearRoughness;
|
108 |
+
float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
109 |
+
float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
110 |
+
float f0 = 1.f;
|
111 |
+
|
112 |
+
float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
|
113 |
+
float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
114 |
+
|
115 |
+
// -------------- BWD --------------
|
116 |
+
// Backprop: return wiScatter * woScatter * energyFactor;
|
117 |
+
float d_wiScatter = d_out * woScatter * energyFactor;
|
118 |
+
float d_woScatter = d_out * wiScatter * energyFactor;
|
119 |
+
float d_energyFactor = d_out * wiScatter * woScatter;
|
120 |
+
|
121 |
+
// Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
122 |
+
float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
|
123 |
+
bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);
|
124 |
+
|
125 |
+
// Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);
|
126 |
+
float d_wiDotN = 0.0f;
|
127 |
+
bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);
|
128 |
+
|
129 |
+
// Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
130 |
+
float d_energyBias = d_f90;
|
131 |
+
float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
|
132 |
+
d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;
|
133 |
+
|
134 |
+
// Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
135 |
+
d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;
|
136 |
+
|
137 |
+
// Backprop: float energyBias = 0.5f * linearRoughness;
|
138 |
+
d_linearRoughness += 0.5 * d_energyBias;
|
139 |
+
|
140 |
+
// Backprop: float wiDotH = dot(wi, h);
|
141 |
+
vec3f d_h(0);
|
142 |
+
bwdDot(wi, h, d_wi, d_h, d_wiDotH);
|
143 |
+
|
144 |
+
// Backprop: vec3f h = safeNormalize(wo + wi);
|
145 |
+
vec3f d_wo_wi(0);
|
146 |
+
bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
|
147 |
+
d_wi += d_wo_wi; d_wo += d_wo_wi;
|
148 |
+
|
149 |
+
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
150 |
+
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
151 |
+
}
|
152 |
+
}
|
153 |
+
|
154 |
+
//------------------------------------------------------------------------
|
155 |
+
// Ndf GGX
|
156 |
+
|
157 |
+
__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
|
158 |
+
{
|
159 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
160 |
+
float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
|
161 |
+
return alphaSqr / (d * d * M_PI);
|
162 |
+
}
|
163 |
+
|
164 |
+
__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
165 |
+
{
|
166 |
+
// Torch only back propagates if clamp doesn't trigger
|
167 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
168 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
169 |
+
d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
170 |
+
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
171 |
+
{
|
172 |
+
d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
173 |
+
}
|
174 |
+
}
|
175 |
+
|
176 |
+
//------------------------------------------------------------------------
|
177 |
+
// Lambda GGX
|
178 |
+
|
179 |
+
__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
|
180 |
+
{
|
181 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
182 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
183 |
+
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
184 |
+
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
185 |
+
return res;
|
186 |
+
}
|
187 |
+
|
188 |
+
__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
189 |
+
{
|
190 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
191 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
192 |
+
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
193 |
+
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
194 |
+
|
195 |
+
d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
|
196 |
+
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
197 |
+
d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
|
198 |
+
}
|
199 |
+
|
200 |
+
//------------------------------------------------------------------------
|
201 |
+
// Masking GGX
|
202 |
+
|
203 |
+
__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
|
204 |
+
{
|
205 |
+
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
206 |
+
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
207 |
+
return 1.0f / (1.0f + lambdaI + lambdaO);
|
208 |
+
}
|
209 |
+
|
210 |
+
__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
|
211 |
+
{
|
212 |
+
// FWD eval
|
213 |
+
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
214 |
+
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
215 |
+
|
216 |
+
// BWD eval
|
217 |
+
float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
|
218 |
+
bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
|
219 |
+
bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
|
220 |
+
}
|
221 |
+
|
222 |
+
//------------------------------------------------------------------------
|
223 |
+
// GGX specular
|
224 |
+
|
225 |
+
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
|
226 |
+
{
|
227 |
+
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
228 |
+
float alphaSqr = _alpha * _alpha;
|
229 |
+
|
230 |
+
vec3f h = safeNormalize(wo + wi);
|
231 |
+
float woDotN = dot(wo, nrm);
|
232 |
+
float wiDotN = dot(wi, nrm);
|
233 |
+
float woDotH = dot(wo, h);
|
234 |
+
float nDotH = dot(nrm, h);
|
235 |
+
|
236 |
+
float D = fwdNdfGGX(alphaSqr, nDotH);
|
237 |
+
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
238 |
+
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
239 |
+
vec3f w = F * D * G * 0.25 / woDotN;
|
240 |
+
|
241 |
+
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
242 |
+
return frontfacing ? w : 0.0f;
|
243 |
+
}
|
244 |
+
|
245 |
+
__device__ void bwdPbrSpecular(
|
246 |
+
const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
|
247 |
+
vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
|
248 |
+
{
|
249 |
+
///////////////////////////////////////////////////////////////////////
|
250 |
+
// FWD eval
|
251 |
+
|
252 |
+
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
253 |
+
float alphaSqr = _alpha * _alpha;
|
254 |
+
|
255 |
+
vec3f h = safeNormalize(wo + wi);
|
256 |
+
float woDotN = dot(wo, nrm);
|
257 |
+
float wiDotN = dot(wi, nrm);
|
258 |
+
float woDotH = dot(wo, h);
|
259 |
+
float nDotH = dot(nrm, h);
|
260 |
+
|
261 |
+
float D = fwdNdfGGX(alphaSqr, nDotH);
|
262 |
+
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
263 |
+
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
264 |
+
vec3f w = F * D * G * 0.25 / woDotN;
|
265 |
+
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
266 |
+
|
267 |
+
if (frontfacing)
|
268 |
+
{
|
269 |
+
///////////////////////////////////////////////////////////////////////
|
270 |
+
// BWD eval
|
271 |
+
|
272 |
+
vec3f d_F = d_out * D * G * 0.25f / woDotN;
|
273 |
+
float d_D = sum(d_out * F * G * 0.25f / woDotN);
|
274 |
+
float d_G = sum(d_out * F * D * 0.25f / woDotN);
|
275 |
+
|
276 |
+
float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
|
277 |
+
|
278 |
+
vec3f d_f90(0);
|
279 |
+
float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
|
280 |
+
bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
|
281 |
+
bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
|
282 |
+
bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
|
283 |
+
|
284 |
+
vec3f d_h(0);
|
285 |
+
bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
|
286 |
+
bwdDot(wo, h, d_wo, d_h, d_woDotH);
|
287 |
+
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
288 |
+
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
289 |
+
|
290 |
+
vec3f d_h_unnorm(0);
|
291 |
+
bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
|
292 |
+
d_wo += d_h_unnorm;
|
293 |
+
d_wi += d_h_unnorm;
|
294 |
+
|
295 |
+
if (alpha > min_roughness * min_roughness)
|
296 |
+
d_alpha += d_alphaSqr * 2 * alpha;
|
297 |
+
}
|
298 |
+
}
|
299 |
+
|
300 |
+
//------------------------------------------------------------------------
|
301 |
+
// Full PBR BSDF
|
302 |
+
|
303 |
+
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
|
304 |
+
{
|
305 |
+
vec3f wo = safeNormalize(view_pos - pos);
|
306 |
+
vec3f wi = safeNormalize(light_pos - pos);
|
307 |
+
|
308 |
+
float alpha = arm.y * arm.y;
|
309 |
+
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
310 |
+
vec3f diff_col = kd * (1.0f - arm.z);
|
311 |
+
|
312 |
+
float diff = 0.0f;
|
313 |
+
if (BSDF == 0)
|
314 |
+
diff = fwdLambert(nrm, wi);
|
315 |
+
else
|
316 |
+
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
317 |
+
vec3f diffuse = diff_col * diff;
|
318 |
+
vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
|
319 |
+
|
320 |
+
return diffuse + specular;
|
321 |
+
}
|
322 |
+
|
323 |
+
__device__ void bwdPbrBSDF(
|
324 |
+
const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
|
325 |
+
vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
|
326 |
+
{
|
327 |
+
////////////////////////////////////////////////////////////////////////
|
328 |
+
// FWD
|
329 |
+
vec3f _wi = light_pos - pos;
|
330 |
+
vec3f _wo = view_pos - pos;
|
331 |
+
vec3f wi = safeNormalize(_wi);
|
332 |
+
vec3f wo = safeNormalize(_wo);
|
333 |
+
|
334 |
+
float alpha = arm.y * arm.y;
|
335 |
+
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
336 |
+
vec3f diff_col = kd * (1.0f - arm.z);
|
337 |
+
float diff = 0.0f;
|
338 |
+
if (BSDF == 0)
|
339 |
+
diff = fwdLambert(nrm, wi);
|
340 |
+
else
|
341 |
+
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
342 |
+
|
343 |
+
////////////////////////////////////////////////////////////////////////
|
344 |
+
// BWD
|
345 |
+
|
346 |
+
float d_alpha(0);
|
347 |
+
vec3f d_spec_col(0), d_wi(0), d_wo(0);
|
348 |
+
bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
349 |
+
|
350 |
+
float d_diff = sum(diff_col * d_out);
|
351 |
+
if (BSDF == 0)
|
352 |
+
bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
|
353 |
+
else
|
354 |
+
bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);
|
355 |
+
|
356 |
+
// Backprop: diff_col = kd * (1.0f - arm.z)
|
357 |
+
vec3f d_diff_col = d_out * diff;
|
358 |
+
d_kd += d_diff_col * (1.0f - arm.z);
|
359 |
+
d_arm.z -= sum(d_diff_col * kd);
|
360 |
+
|
361 |
+
// Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
|
362 |
+
d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
|
363 |
+
d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
|
364 |
+
d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
|
365 |
+
|
366 |
+
// Backprop: alpha = arm.y * arm.y
|
367 |
+
d_arm.y += d_alpha * 2 * arm.y;
|
368 |
+
|
369 |
+
// Backprop: vec3f wi = safeNormalize(light_pos - pos);
|
370 |
+
vec3f d__wi(0);
|
371 |
+
bwdSafeNormalize(_wi, d__wi, d_wi);
|
372 |
+
d_light_pos += d__wi;
|
373 |
+
d_pos -= d__wi;
|
374 |
+
|
375 |
+
// Backprop: vec3f wo = safeNormalize(view_pos - pos);
|
376 |
+
vec3f d__wo(0);
|
377 |
+
bwdSafeNormalize(_wo, d__wo, d_wo);
|
378 |
+
d_view_pos += d__wo;
|
379 |
+
d_pos -= d__wo;
|
380 |
+
}
|
381 |
+
|
382 |
+
//------------------------------------------------------------------------
|
383 |
+
// Kernels
|
384 |
+
|
385 |
+
__global__ void LambertFwdKernel(LambertKernelParams p)
|
386 |
+
{
|
387 |
+
// Calculate pixel position.
|
388 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
389 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
390 |
+
unsigned int pz = blockIdx.z;
|
391 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
392 |
+
return;
|
393 |
+
|
394 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
395 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
396 |
+
|
397 |
+
float res = fwdLambert(nrm, wi);
|
398 |
+
|
399 |
+
p.out.store(px, py, pz, res);
|
400 |
+
}
|
401 |
+
|
402 |
+
__global__ void LambertBwdKernel(LambertKernelParams p)
|
403 |
+
{
|
404 |
+
// Calculate pixel position.
|
405 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
406 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
407 |
+
unsigned int pz = blockIdx.z;
|
408 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
409 |
+
return;
|
410 |
+
|
411 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
412 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
413 |
+
float d_out = p.out.fetch1(px, py, pz);
|
414 |
+
|
415 |
+
vec3f d_nrm(0), d_wi(0);
|
416 |
+
bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
|
417 |
+
|
418 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
419 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
420 |
+
}
|
421 |
+
|
422 |
+
__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
|
423 |
+
{
|
424 |
+
// Calculate pixel position.
|
425 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
426 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
427 |
+
unsigned int pz = blockIdx.z;
|
428 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
429 |
+
return;
|
430 |
+
|
431 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
432 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
433 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
434 |
+
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
435 |
+
|
436 |
+
float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
|
437 |
+
|
438 |
+
p.out.store(px, py, pz, res);
|
439 |
+
}
|
440 |
+
|
441 |
+
__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
|
442 |
+
{
|
443 |
+
// Calculate pixel position.
|
444 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
445 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
446 |
+
unsigned int pz = blockIdx.z;
|
447 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
448 |
+
return;
|
449 |
+
|
450 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
451 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
452 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
453 |
+
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
454 |
+
float d_out = p.out.fetch1(px, py, pz);
|
455 |
+
|
456 |
+
float d_linearRoughness = 0.0f;
|
457 |
+
vec3f d_nrm(0), d_wi(0), d_wo(0);
|
458 |
+
bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);
|
459 |
+
|
460 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
461 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
462 |
+
p.wo.store_grad(px, py, pz, d_wo);
|
463 |
+
p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
|
464 |
+
}
|
465 |
+
|
466 |
+
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
|
467 |
+
{
|
468 |
+
// Calculate pixel position.
|
469 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
470 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
471 |
+
unsigned int pz = blockIdx.z;
|
472 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
473 |
+
return;
|
474 |
+
|
475 |
+
vec3f f0 = p.f0.fetch3(px, py, pz);
|
476 |
+
vec3f f90 = p.f90.fetch3(px, py, pz);
|
477 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
478 |
+
|
479 |
+
vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
|
480 |
+
p.out.store(px, py, pz, res);
|
481 |
+
}
|
482 |
+
|
483 |
+
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
|
484 |
+
{
|
485 |
+
// Calculate pixel position.
|
486 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
487 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
488 |
+
unsigned int pz = blockIdx.z;
|
489 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
490 |
+
return;
|
491 |
+
|
492 |
+
vec3f f0 = p.f0.fetch3(px, py, pz);
|
493 |
+
vec3f f90 = p.f90.fetch3(px, py, pz);
|
494 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
495 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
496 |
+
|
497 |
+
vec3f d_f0(0), d_f90(0);
|
498 |
+
float d_cosTheta(0);
|
499 |
+
bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
|
500 |
+
|
501 |
+
p.f0.store_grad(px, py, pz, d_f0);
|
502 |
+
p.f90.store_grad(px, py, pz, d_f90);
|
503 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
504 |
+
}
|
505 |
+
|
506 |
+
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
|
507 |
+
{
|
508 |
+
// Calculate pixel position.
|
509 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
510 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
511 |
+
unsigned int pz = blockIdx.z;
|
512 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
513 |
+
return;
|
514 |
+
|
515 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
516 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
517 |
+
float res = fwdNdfGGX(alphaSqr, cosTheta);
|
518 |
+
|
519 |
+
p.out.store(px, py, pz, res);
|
520 |
+
}
|
521 |
+
|
522 |
+
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
|
523 |
+
{
|
524 |
+
// Calculate pixel position.
|
525 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
526 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
527 |
+
unsigned int pz = blockIdx.z;
|
528 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
529 |
+
return;
|
530 |
+
|
531 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
532 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
533 |
+
float d_out = p.out.fetch1(px, py, pz);
|
534 |
+
|
535 |
+
float d_alphaSqr(0), d_cosTheta(0);
|
536 |
+
bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
537 |
+
|
538 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
539 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
540 |
+
}
|
541 |
+
|
542 |
+
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
|
543 |
+
{
|
544 |
+
// Calculate pixel position.
|
545 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
546 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
547 |
+
unsigned int pz = blockIdx.z;
|
548 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
549 |
+
return;
|
550 |
+
|
551 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
552 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
553 |
+
float res = fwdLambdaGGX(alphaSqr, cosTheta);
|
554 |
+
|
555 |
+
p.out.store(px, py, pz, res);
|
556 |
+
}
|
557 |
+
|
558 |
+
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
|
559 |
+
{
|
560 |
+
// Calculate pixel position.
|
561 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
562 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
563 |
+
unsigned int pz = blockIdx.z;
|
564 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
565 |
+
return;
|
566 |
+
|
567 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
568 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
569 |
+
float d_out = p.out.fetch1(px, py, pz);
|
570 |
+
|
571 |
+
float d_alphaSqr(0), d_cosTheta(0);
|
572 |
+
bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
573 |
+
|
574 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
575 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
576 |
+
}
|
577 |
+
|
578 |
+
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
|
579 |
+
{
|
580 |
+
// Calculate pixel position.
|
581 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
582 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
583 |
+
unsigned int pz = blockIdx.z;
|
584 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
585 |
+
return;
|
586 |
+
|
587 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
588 |
+
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
589 |
+
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
590 |
+
float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
|
591 |
+
|
592 |
+
p.out.store(px, py, pz, res);
|
593 |
+
}
|
594 |
+
|
595 |
+
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
|
596 |
+
{
|
597 |
+
// Calculate pixel position.
|
598 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
599 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
600 |
+
unsigned int pz = blockIdx.z;
|
601 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
602 |
+
return;
|
603 |
+
|
604 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
605 |
+
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
606 |
+
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
607 |
+
float d_out = p.out.fetch1(px, py, pz);
|
608 |
+
|
609 |
+
float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
|
610 |
+
bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
|
611 |
+
|
612 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
613 |
+
p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
|
614 |
+
p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
|
615 |
+
}
|
616 |
+
|
617 |
+
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
|
618 |
+
{
|
619 |
+
// Calculate pixel position.
|
620 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
621 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
622 |
+
unsigned int pz = blockIdx.z;
|
623 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
624 |
+
return;
|
625 |
+
|
626 |
+
vec3f col = p.col.fetch3(px, py, pz);
|
627 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
628 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
629 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
630 |
+
float alpha = p.alpha.fetch1(px, py, pz);
|
631 |
+
|
632 |
+
vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
|
633 |
+
|
634 |
+
p.out.store(px, py, pz, res);
|
635 |
+
}
|
636 |
+
|
637 |
+
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
|
638 |
+
{
|
639 |
+
// Calculate pixel position.
|
640 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
641 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
642 |
+
unsigned int pz = blockIdx.z;
|
643 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
644 |
+
return;
|
645 |
+
|
646 |
+
vec3f col = p.col.fetch3(px, py, pz);
|
647 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
648 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
649 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
650 |
+
float alpha = p.alpha.fetch1(px, py, pz);
|
651 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
652 |
+
|
653 |
+
float d_alpha(0);
|
654 |
+
vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
|
655 |
+
bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
656 |
+
|
657 |
+
p.col.store_grad(px, py, pz, d_col);
|
658 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
659 |
+
p.wo.store_grad(px, py, pz, d_wo);
|
660 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
661 |
+
p.alpha.store_grad(px, py, pz, d_alpha);
|
662 |
+
}
|
663 |
+
|
664 |
+
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
|
665 |
+
{
|
666 |
+
// Calculate pixel position.
|
667 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
668 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
669 |
+
unsigned int pz = blockIdx.z;
|
670 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
671 |
+
return;
|
672 |
+
|
673 |
+
vec3f kd = p.kd.fetch3(px, py, pz);
|
674 |
+
vec3f arm = p.arm.fetch3(px, py, pz);
|
675 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
676 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
677 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
678 |
+
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
679 |
+
|
680 |
+
vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);
|
681 |
+
|
682 |
+
p.out.store(px, py, pz, res);
|
683 |
+
}
|
684 |
+
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
|
685 |
+
{
|
686 |
+
// Calculate pixel position.
|
687 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
688 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
689 |
+
unsigned int pz = blockIdx.z;
|
690 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
691 |
+
return;
|
692 |
+
|
693 |
+
vec3f kd = p.kd.fetch3(px, py, pz);
|
694 |
+
vec3f arm = p.arm.fetch3(px, py, pz);
|
695 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
696 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
697 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
698 |
+
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
699 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
700 |
+
|
701 |
+
vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
|
702 |
+
bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
|
703 |
+
|
704 |
+
p.kd.store_grad(px, py, pz, d_kd);
|
705 |
+
p.arm.store_grad(px, py, pz, d_arm);
|
706 |
+
p.pos.store_grad(px, py, pz, d_pos);
|
707 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
708 |
+
p.view_pos.store_grad(px, py, pz, d_view_pos);
|
709 |
+
p.light_pos.store_grad(px, py, pz, d_light_pos);
|
710 |
+
}
|