diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0910954583c8d200503de7c8671ada5fac0a0ae7
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,10 @@
+S-Lab License 1.0
+
+Copyright 2023 S-Lab
+
+Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
\ No newline at end of file
diff --git a/README.md b/README.md
index f3afc1d0c84cfee28a5828e478f2328462103a92..3ead504c33d1748ca942986bf9d73201e10a0c89 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,256 @@
+
+
+
+LN3Diff: Scalable Latent Neural Fields Diffusion for Speedy 3D Generation
+
+
+
+
+ S-Lab, Nanyang Technological University1;
+
+
+ Wangxuan Institute of Computer Technology, Peking University2;
+
+
+ Shanghai Artificial Intelligence Laboratory 3
+
+
+
+
+
+
+
+
+
+
+LN3Diff is a feedforward 3D diffusion model that creates high-quality 3D object mesh from text within 8 V100-SECONDS.
+
+
+
+
+
+
+
+
+ |
+
+
+ |
+
+
+ |
+
+
+ |
+
+
+ |
+
+
+
+
+ A standing hund. |
+ An UFO space aircraft. |
+ A sailboat with mast. |
+ An 18th century cannon. |
+ A blue plastic chair. |
+
+
+
+
+
+
+For more visual results, go checkout our
project page :page_with_curl:
+
+
+Codes coming soon :facepunch:
+
+
+This repository contains the official implementation of LN3Diff:
+Scalable Latent Neural Fields Diffusion for Speedy 3D Generation
+
+
+
---
-license: other
-license_name: ntu-slab-license
-license_link: https://github.com/NIRVANALAN/LN3Diff/blob/main/LICENSE
----
+
+
+
+
+## :mega: Updates
+
+[03/2024] Initial release.
+
+[04/2024] Inference and training codes on Objaverse, ShapeNet and FFHQ are released, including pre-trained model and training dataset.
+
+
+## :dromedary_camel: TODO
+
+- [x] Release the inference and training code.
+- [x] Release the pre-trained checkpoints of ShapeNet and FFHQ.
+- [x] Release the pre-trained checkpoints of T23D Objaverse model trained with 30K+ instances dataset.
+- [x] Release the stage-1 VAE of Objaverse trained with 80K+ instances dataset.
+- [ ] Add Gradio demo.
+- [ ] Polish the dataset preparation and training doc.
+- [ ] add metrics evaluation scripts and samples.
+- [ ] Lint the code.
+- [ ] Release the new T23D Objaverse model trained with 80K+ instances dataset.
+
+
+
+## :handshake: Citation
+If you find our work useful for your research, please consider citing the paper:
+```
+@misc{lan2024ln3diff,
+title={LN3Diff: Scalable Latent Neural Fields Diffusion for Speedy 3D Generation},
+author={Yushi Lan and Fangzhou Hong and Shuai Yang and Shangchen Zhou and Xuyi Meng and Bo Dai and Xingang Pan and Chen Change Loy},
+year={2024},
+eprint={2403.12019},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
+}
+```
+
+## :desktop_computer: Requirements
+
+NVIDIA GPUs are required for this project.
+We conduct all the training on NVIDIA V100-32GiB (ShapeNet, FFHQ) and NVIDIA A100-80GiB (Objaverse).
+We have test the inference codes on NVIDIA V100.
+We recommend using anaconda to manage the python environments.
+
+The environment can be created via ```conda env create -f environment_ln3diff.yml```, and activated via ```conda activate ln3diff```.
+If you want to reuse your own PyTorch environment, install the following packages in your environment:
+
+```
+# first, check whether you have installed pytorch (>=2.0) and xformer.
+conda install -c conda-forge openexr-python git
+pip install openexr lpips imageio kornia opencv-python tensorboard tqdm timm ffmpeg einops beartype imageio[ffmpeg] blobfile ninja lmdb webdataset opencv-python click torchdiffeq transformers
+pip install git+https://github.com/nupurkmr9/vision-aided-gan.
+```
+
+## :running_woman: Inference
+
+### Download Models
+
+The pretrained stage-1 VAE and stage-2 LDM can be downloaded via [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/yushi001_e_ntu_edu_sg/ErdRV9hCYvlBioObT1v_LZ4Bnwye3sv6p5qiVZPNhI9coQ?e=nJgp8t).
+
+Put the downloaded checkpoints under ```checkpoints``` folder for inference. The checkpoints directory layout should be
+
+ checkpoints
+ ├── ffhq
+ │ └── model_joint_denoise_rec_model1580000.pt
+ ├── objaverse
+ │ ├── model_rec1680000.pt
+ │ └── model_joint_denoise_rec_model2310000.pt
+ ├── shapenet
+ │ └── car
+ │ └── model_joint_denoise_rec_model1580000.pt
+ │ └── chair
+ │ └── model_joint_denoise_rec_model2030000.pt
+ │ └── plane
+ │ └── model_joint_denoise_rec_model770000.pt
+ └── ...
+
+
+
+### Inference Commands
+
+Note that to extract the mesh, 24GiB VRAM is required.
+
+#### Stage-1 VAE 3D reconstruction
+
+For (Objaverse) stage-1 VAE 3D reconstruction and extract VAE latents for diffusion learning, please run
+
+```bash
+bash shell_scripts/final_release/inference/sample_obajverse.sh
+```
+
+which shall give the following result:
+
+
+The marching-cube extracted mesh can be visualized with Blender/MeshLab:
+
+
+
+**We upload the pre-extracted vae latents at [here](https://entuedu-my.sharepoint.com/:f:/g/personal/yushi001_e_ntu_edu_sg/EnXixldDrKhDtrcuPM4vjQYBv06uY58F1mF7f7KVdZ19lQ?e=nXQNdm), which contains the correponding VAE latents (with shape 32x32x12) of 76K G-buffer Objaverse objects. Feel free to use them in your own task.**
+
+For more G-buffer Objaverse examples, download the [demo data](https://entuedu-my.sharepoint.com/:f:/g/personal/yushi001_e_ntu_edu_sg/EoyzVJbMyBhLoKFJbbsq6bYBi1paLwQxIDjTkO1KjI4b1g?e=sJc3rQ).
+
+
+#### Stage-2 Text-to-3D
+
+We train 3D latent diffusion model on top of the stage-1 extracted latents.
+For the following bash inference file, to extract mesh from the generated tri-plane, set ```--export_mesh True```. To change the text prompt, set the ```prompt``` variable. For unconditional sampling, set the cfg guidance ```unconditional_guidance_scale=0```. Feel free to tune the cfg guidance scale to trade off diversity and fidelity.
+
+Note that the diffusion sampling batch size is set to ```4```, which costs around 16GiB VRAM. The mesh extraction of a single instance costs 24GiB VRAM.
+
+For text-to-3D on Objaverse, run
+
+```bash
+bash shell_scripts/final_release/inference/sample_obajverse.sh
+```
+
+For text-to-3D on ShapeNet, run one of the following commands (which conducts T23D on car, chair and plane.):
+```bash
+bash shell_scripts/final_release/inference/sample_shapenet_car_t23d.sh
+```
+
+```bash
+bash shell_scripts/final_release/inference/sample_shapenet_chair_t23d.sh
+```
+
+```bash
+bash shell_scripts/final_release/inference/sample_shapenet_plane_t23d.sh
+```
+
+For text-to-3D on FFHQ, run
+
+```bash
+bash shell_scripts/final_release/inference/sample_ffhq_t23d.sh
+```
+
+
+## :running_woman: Training
+
+### Dataset
+
+For Objaverse, we use the rendering provided by [G-buffer Objaverse](https://aigc3d.github.io/gobjaverse/). A demo subset for stage-1 VAE reconstruction can be downloaded from [here](https://entuedu-my.sharepoint.com/:u:/g/personal/yushi001_e_ntu_edu_sg/Eb6LX2x-EgJLpiHbhRxsN9ABnEaSyjG-tsVBcUr_dQ5dnQ?e=JXWQo1). Note that for Objaverse training, we pre-process the raw data into [wds-dataset](https://github.com/webdataset/webdataset) shards for fast and flexible loading. The sample shard data can be found in [here](https://entuedu-my.sharepoint.com/:f:/g/personal/yushi001_e_ntu_edu_sg/ErtZQgnEH5ZItDqdUaiVbJgBe4nhZveJemQRqDW6Xwp7Zg?e=Zqt6Ss).
+
+For ShapeNet, we render our own data with foreground mask for training, which can be downloaded from [here](https://entuedu-my.sharepoint.com/:f:/g/personal/yushi001_e_ntu_edu_sg/EijBXIC_bUNOo0L3wnJKRqoBCqVnhhT_BReYRc1tc_0lrA?e=VQwWOZ). For training, we convert the raw data to LMDB for faster data loading. The pre-processed LMDB file can be downloaded from [here](https://entuedu-my.sharepoint.com/:f:/g/personal/yushi001_e_ntu_edu_sg/Ev7L8Als8K9JtLtj1G23Cc0BTNDbhCQPadxNLLVS7mV2FQ?e=C5woyE).
+
+
+For FFHQ, we use the pre-processed dataset from [EG3D](https://github.com/NVlabs/eg3d) and compress it into LMDB, which can also be found in the onedrive link above.
+
+
+### Training Commands
+
+Coming soon.
+
+
+## :newspaper_roll: License
+
+Distributed under the S-Lab License. See `LICENSE` for more information.
+
+
+## Contact
+
+If you have any question, please feel free to contact us via `lanyushi15@gmail.com` or Github issues.
\ No newline at end of file
diff --git a/assets/ffhq_eval_pose.pt b/assets/ffhq_eval_pose.pt
new file mode 100644
index 0000000000000000000000000000000000000000..63522b86ca52d24d43ed236190f48b1d57618a0b
Binary files /dev/null and b/assets/ffhq_eval_pose.pt differ
diff --git a/assets/objv_eval_pose.pt b/assets/objv_eval_pose.pt
new file mode 100644
index 0000000000000000000000000000000000000000..00b3ed2f06ff4c24dd60a03ad93da4f2ee35e1cc
Binary files /dev/null and b/assets/objv_eval_pose.pt differ
diff --git a/assets/shapenet_eval_pose.pt b/assets/shapenet_eval_pose.pt
new file mode 100644
index 0000000000000000000000000000000000000000..3c4e124925416fd6c0b8c64ac4fc49aac354a1d9
Binary files /dev/null and b/assets/shapenet_eval_pose.pt differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000.json
new file mode 100755
index 0000000000000000000000000000000000000000..3b1b331a95c7aff3f3f13345caab7999a29cb94e
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.64323258,
+ 0.0,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 1.2196297E-07,
+ 1.00000012,
+ 0.0
+ ],
+ "y": [
+ 0.188542932,
+ 0.0,
+ -0.982064962
+ ],
+ "z": [
+ -0.9820651,
+ 1.2196297E-07,
+ -0.188542932
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000.png
new file mode 100755
index 0000000000000000000000000000000000000000..c36b27b680ab36a8c66db393a6570cd6cd7082c7
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..908d52cea989a098b5c0e3ddb75abc461b9bbc50
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..08dba0b3c450c170a638cfefeba487d03a74b79b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..9e4141f7d3a8e9c7d7781e8185d5fa6b60390505
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..9d84e79ce2b00a2375a859021d4bbf9801ebf575
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..6f484ddfe5bb734181ee1dab931f35a3764e5f13
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00000/00000_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001.json
new file mode 100755
index 0000000000000000000000000000000000000000..28424c306820bc019515733470fc0618b0c272ee
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.58724082,
+ 0.425299883,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.258818865,
+ 0.9659259,
+ 2.14746585E-08
+ ],
+ "y": [
+ 0.18211852,
+ 0.0487984978,
+ -0.982064962
+ ],
+ "z": [
+ -0.948601961,
+ -0.2541769,
+ -0.188542962
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001.png
new file mode 100755
index 0000000000000000000000000000000000000000..80b4a61ee13241a5f362e868c4e2ffb46b655bdd
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..0cb858ce81e6d4b733576134eacc09f2e939a353
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..4e505b0212ba5a96e4b13396e9d686691f85470a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..9b51d466ffc7cb4492f36b9e649a6103d2eacce2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..9afaee292725668d0673c4d22964c597e246fe79
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..54279aba4c66ed879976cf9144eb67c65bd731a8
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00001/00001_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002.json
new file mode 100755
index 0000000000000000000000000000000000000000..7a812273141b145d616e9a5284d84d550ede7aca
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.42308116,
+ 0.8216163,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.50000006,
+ 0.8660254,
+ -6.586047E-09
+ ],
+ "y": [
+ 0.163282961,
+ 0.0942714661,
+ -0.982064962
+ ],
+ "z": [
+ -0.8504932,
+ -0.4910325,
+ -0.188542932
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002.png
new file mode 100755
index 0000000000000000000000000000000000000000..fd8e073498ece40bd57d435dd328acbacbe56b7d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..dda96a85409a86a900316aeae57778445b59e847
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..120cf6572d4262d5a194e0b1bfcdfbbc8226396f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..b3669f523cbe0563694e476b030d8b1fb10932b7
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..c59b47a61ece8a47c0076f342e8dd823a2b03251
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..ffa9921fae99ec8fe31ca0175ca6a52020288572
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00002/00002_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003.json
new file mode 100755
index 0000000000000000000000000000000000000000..639780f8f7d55acb24373dbcc6030cf06a674060
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.16194093,
+ 1.16194093,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.707106769,
+ 0.707106769,
+ -5.59717162E-10
+ ],
+ "y": [
+ 0.13332,
+ 0.13332,
+ -0.982064962
+ ],
+ "z": [
+ -0.6944248,
+ -0.694424748,
+ -0.188542962
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003.png
new file mode 100755
index 0000000000000000000000000000000000000000..7d468d7b31f8a824323524ccc8022877e04621fa
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..cbcad2361f252c867dc1059350472b6cdd8c5c43
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..da0cfbeb14085847a6bceb2534fb2ee01add0b01
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..7f0d53ec21769b7a21cdb5e58a01e07bb80a6eeb
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..bbff5c514f3f45e53c8d403204aa8ae6d685c430
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..13c90cf97ff739e66f7b531de34e0f55e28bf88b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00003/00003_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004.json
new file mode 100755
index 0000000000000000000000000000000000000000..be782477c73c321fda69d0956da17f7462196d54
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 0.821616232,
+ 1.42308116,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.866025448,
+ 0.49999994,
+ -8.742944E-09
+ ],
+ "y": [
+ 0.09427146,
+ 0.163282961,
+ -0.982064962
+ ],
+ "z": [
+ -0.491032422,
+ -0.850493252,
+ -0.188542917
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004.png
new file mode 100755
index 0000000000000000000000000000000000000000..612a2503cd4aa4a2dc8999a1f53e7470635f9f6c
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..5c422804ce23a9471c12c8ac8821560bcfed1e53
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..8d0029f6c147c2ef54318e6267e251adf81e13b1
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..7fc83cb80ea20cccf299d5d4c36269c4310259f8
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..ca97a89bb7b42640d7ad1ac79a787ed5340d2723
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..6acebc0cac845a7a7b1bcd1ec6e35901194fcf8c
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00004/00004_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005.json
new file mode 100755
index 0000000000000000000000000000000000000000..9335cc70b90e45924fbc639e45004d2c3deb1bd0
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 0.425299734,
+ 1.58724082,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.9659259,
+ 0.258818924,
+ -4.48933068E-09
+ ],
+ "y": [
+ 0.0487984866,
+ 0.18211849,
+ -0.982064962
+ ],
+ "z": [
+ -0.254177,
+ -0.9486019,
+ -0.188542932
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005.png
new file mode 100755
index 0000000000000000000000000000000000000000..84b0a1e7df8d330f94edbe01f81f635ac4925ef0
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..6f5f4c8fd36e26ef3eab253c1cfdbfd2b6820b0a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..e7fb1adadba343cfcb8e659572b8d9fe4a52f50f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..320d9a4fb161b02bf3700223fcff1d4d5a058c88
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..cbec04984898d272cbec41d0aa5b293948d5dc4e
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..1d8c4eca9683ac812eb9251fcf258dde8da3531b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00005/00005_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006.json
new file mode 100755
index 0000000000000000000000000000000000000000..0c67343577bbd4f2517e8339c2eb47e8634015c0
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -7.182798E-08,
+ 1.64323258,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -1.0,
+ -4.37113847E-08,
+ 1.07384328E-15
+ ],
+ "y": [
+ -8.241472E-09,
+ 0.1885429,
+ -0.982064962
+ ],
+ "z": [
+ 4.29274181E-08,
+ -0.982064962,
+ -0.1885429
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006.png
new file mode 100755
index 0000000000000000000000000000000000000000..ae4fe13df0e8946d752f416573a97e153f4fc92d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..3d695d834170b5bdce166da1cdf6b4a905663b7a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..46982bc4e8584f50707bc5a5f7faf401a8a3557a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..3f43a6f96ba76766f2c5a04ccdb439659a664886
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..3eab7cd4fb28e8f95f45ca1243dc2e014ea6cf45
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..2ad447d0a72da718cf6a6174d61c9793a3c7c8b8
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00006/00006_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007.json
new file mode 100755
index 0000000000000000000000000000000000000000..76128d37a1bcca03ddc3049b831220e736c9ff58
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -0.4253002,
+ 1.5872407,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.965925753,
+ -0.258819222,
+ 7.957533E-09
+ ],
+ "y": [
+ -0.04879854,
+ 0.182118446,
+ -0.982064962
+ ],
+ "z": [
+ 0.2541773,
+ -0.948601842,
+ -0.1885429
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007.png
new file mode 100755
index 0000000000000000000000000000000000000000..12c877a486c847131498e5c8d3c62cf0d7de5943
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..5aee19dad18fd4122b187d8364a1dc29f2e36a00
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..7aa28b3f1fd6a1a0afe5c1bd46e23e731d0d552b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..9ba38b6a798450364d30e4643b7d733f58d33702
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..27c8974bef3d069382a860c04f708a87a6740122
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..b2fd36dbec4a0b5859966228771af2e8aed61437
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00007/00007_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008.json
new file mode 100755
index 0000000000000000000000000000000000000000..2e25080dbe9239c628e8ae17a94358a62e07d75e
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -0.8216164,
+ 1.42308116,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.8660254,
+ -0.50000006,
+ 8.585387E-09
+ ],
+ "y": [
+ -0.09427148,
+ 0.163282961,
+ -0.982064962
+ ],
+ "z": [
+ 0.491032541,
+ -0.8504932,
+ -0.188542932
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008.png
new file mode 100755
index 0000000000000000000000000000000000000000..a756bb1f16778cc68b3a3bf4a0dded27b11ed198
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..5a332f75b6b7f774af2d8bd77f2f88321122747a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..7087694c4522d316c609198eac11f47cadbc7e4e
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..2b4f111c96342c8e34c64232b75c4d80dd487090
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..eefb1c760ac566b16f53b3bd34d58929214c961e
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..30fdf829a5f2f6e0899861666e64b4fd38acd36d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00008/00008_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009.json
new file mode 100755
index 0000000000000000000000000000000000000000..f37060bce87ecc13357e25c98b31cf0411bf1dc2
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.16194093,
+ 1.16194093,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.707106769,
+ -0.707106769,
+ 5.59717162E-10
+ ],
+ "y": [
+ -0.13332,
+ 0.13332,
+ -0.982064962
+ ],
+ "z": [
+ 0.6944248,
+ -0.694424748,
+ -0.188542962
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009.png
new file mode 100755
index 0000000000000000000000000000000000000000..e2ac13fe9134034f55471878716facdd0e25f595
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..3c7c5063db53e4f979829eeb52ab7ed46b34bfa4
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..3b6426f8a2114a2ba28b41054d3b8f02037c9636
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..0fd869a2017b9b4e07caff0ef381be69369d2238
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..f327728cafa747bb122e34e620111918ba3cfc17
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..b16ed5e18e07c9e6458b59219f2815a91f77d31d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00009/00009_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010.json
new file mode 100755
index 0000000000000000000000000000000000000000..349b6041ca0a783dd8311c32260923906f49f715
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.42308128,
+ 0.821616,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.499999672,
+ -0.866025567,
+ -1.57553082E-08
+ ],
+ "y": [
+ -0.16328302,
+ 0.09427143,
+ -0.982064962
+ ],
+ "z": [
+ 0.8504934,
+ -0.491032153,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010.png
new file mode 100755
index 0000000000000000000000000000000000000000..8e9e5b5bb8e2a9c3e6d4540ef1a8e1f9bde9bdbc
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..5567adc8263d2ebfe13406df2699e3592b168b3b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..f11d356b12ce387bd71df92d15884241687a110b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..e2775bc6261cc85108306a06a22854c7ede63913
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..cd2f394c9bf12c64c4ce1f10145f0da69e5460e2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..68a0196060921bd44102dabf780f551a6cbdd12f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00010/00010_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011.json
new file mode 100755
index 0000000000000000000000000000000000000000..49d619679d24481285ec2833125c474bc1002ade
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.58724082,
+ 0.425299674,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.258818716,
+ -0.965925932,
+ -1.65035559E-08
+ ],
+ "y": [
+ -0.182118535,
+ 0.04879846,
+ -0.982064962
+ ],
+ "z": [
+ 0.948601961,
+ -0.2541768,
+ -0.188542962
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011.png
new file mode 100755
index 0000000000000000000000000000000000000000..90c269f732780189b4031d26ba7784bad2d6c2e2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..8385df987212a5b0ab83c20a3bf4adf983d2c7a5
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..91021edb0b2d6de9650f56faa61259850f7d87f8
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..2a7a1db386984bbc8aab9e0f37e38cfe72b56111
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..e5ab4ab3f994fa7596a2f31bc8abcc74f96fe74c
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..ea05975f1ef29a7a37f6d34f95f642e44a3fb51d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00011/00011_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012.json
new file mode 100755
index 0000000000000000000000000000000000000000..5adedb55a2bd1a593423af594960a01131f2cb25
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.64323258,
+ -1.43655953E-07,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 1.23958557E-07,
+ -1.00000012,
+ 1.04893623E-08
+ ],
+ "y": [
+ -0.188542947,
+ -1.04893623E-08,
+ -0.982064962
+ ],
+ "z": [
+ 0.9820651,
+ 1.2196297E-07,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012.png
new file mode 100755
index 0000000000000000000000000000000000000000..a0f9987d62c97d2d760ac56b1c8b9204337289de
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..4919b78dc4203148eeee56bef5f379fb294e4b9e
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..42e36d0e3dd87f968bc149fb192509d474d0abee
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..7fcdf300ac4e2532b6932a0e40c5d53e6fcd7e46
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..28804c69945af7db4b0991a2c4f14d314ffe9d0c
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..00fdef46ffb5775ced2e4bdf416753c1bf905180
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00012/00012_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013.json
new file mode 100755
index 0000000000000000000000000000000000000000..ab32dece65550b09bb6030a19f440e712017bab0
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.5872407,
+ -0.42530033,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.258819252,
+ -0.965925753,
+ 1.02842872E-08
+ ],
+ "y": [
+ -0.182118475,
+ -0.0487985574,
+ -0.982064962
+ ],
+ "z": [
+ 0.9486018,
+ 0.2541773,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013.png
new file mode 100755
index 0000000000000000000000000000000000000000..6b76fadfcaaa13bf61bf9fade29aa436c7220a6f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..cf352aafadc990f535b456aa8654d1162c5276d2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..3b7f7cef62f93d12827bfec5ed4a83cee63aad94
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..3682e20512f4311e32d22b4d4855c72b0c2cf904
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..72b9374b0774a85ac3d46bdaf717f1e36a75b681
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..c8ba7b10b50f299bf0605c425840f89895aafe06
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00013/00013_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014.json
new file mode 100755
index 0000000000000000000000000000000000000000..b9561eec984d9fb3151e803bb4f46f11c43b2189
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.42308068,
+ -0.821617,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.5000005,
+ -0.8660251,
+ -2.61196842E-09
+ ],
+ "y": [
+ -0.163282931,
+ -0.09427156,
+ -0.982064962
+ ],
+ "z": [
+ 0.8504929,
+ 0.491032928,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014.png
new file mode 100755
index 0000000000000000000000000000000000000000..de638c6f4d8806597ff0eee25fb29c3f621c6c21
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..239828a8daa26a3cddd2cb252fba6c247e1bb748
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..9e94adf7d7ddea907d57bb6276a0d02d0d6429be
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..eb5a92d007bf4d90e5fe9e1836c455e9dc767eef
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..817d5fc57d9edc8b43e04b64b24532d26113f236
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..80c34bfaf1a3870cd48a959aae8f838af4f54e19
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00014/00014_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015.json
new file mode 100755
index 0000000000000000000000000000000000000000..5734baf7d7b2bca80e5d1065612952475b91c5f5
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.16194069,
+ -1.161941,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.707107,
+ -0.7071067,
+ 4.072612E-09
+ ],
+ "y": [
+ -0.133319989,
+ -0.133320034,
+ -0.982064962
+ ],
+ "z": [
+ 0.694424748,
+ 0.694425,
+ -0.188542977
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015.png
new file mode 100755
index 0000000000000000000000000000000000000000..7d08125b699f4e44b5ece0c37fea4002437190fd
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..8db75aa1813a25a630716a9be98d6f58a4d97605
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..752d8e4a02469921d1fd051ebdd5a54179ea4045
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..9db8b5bf025634720b5d3851bd8e681188fd1ad8
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..59ffe83ad9ed3493eb1f99716b71d46ea14eaa30
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..6dc761dc38ee6402843aaf7567c0a2c805f7d757
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00015/00015_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016.json
new file mode 100755
index 0000000000000000000000000000000000000000..05da6734584da9e096d30d4a10673ca1634422a0
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -0.8216161,
+ -1.42308116,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.866025448,
+ -0.49999994,
+ 2.51798427E-09
+ ],
+ "y": [
+ -0.0942714661,
+ -0.16328299,
+ -0.982064962
+ ],
+ "z": [
+ 0.491032422,
+ 0.850493252,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016.png
new file mode 100755
index 0000000000000000000000000000000000000000..da20341324b9f4b84fd4ad0eac913c33ff668687
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..d29ffe0577df22c80d4b1440b69cf75067684b01
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..cb904ceeda4146bba1bcde2fa3b037b7a1b6fda5
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..a3879987606e3d15b03df0e1841c1683f91d1620
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..85fd6ca40213e587237a1b8ba086d79635d39f8c
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..61f795c0c4f00f4ff9face3dfdd6aea9fc3fe63d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00016/00016_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017.json
new file mode 100755
index 0000000000000000000000000000000000000000..f2fedbff866c3c1af59bb86fd7f71f197f5c88a9
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -0.425299078,
+ -1.587241,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.965926051,
+ -0.258818537,
+ -4.6904205E-09
+ ],
+ "y": [
+ -0.04879841,
+ -0.1821185,
+ -0.982064962
+ ],
+ "z": [
+ 0.254176617,
+ 0.9486021,
+ -0.188542932
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017.png
new file mode 100755
index 0000000000000000000000000000000000000000..be0268d0df8e465439849917bd6875fa72d62b85
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..151dfc25b9b8b2d6a32cca5c1318c0372b31f164
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..fbbf6fc5c0f60384c0f5529a237c5658b7548ac9
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..9eda72a51264a13ffa7daa590b94e8b73e981ee5
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..e6e8eea427c0ce1e963296cf10db2f29c6a35d1f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..d0fae4b0e643fc27819d38fa7a2f7a400c644e13
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00017/00017_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018.json
new file mode 100755
index 0000000000000000000000000000000000000000..b8de376711c139fe0f40c24c504bdb074783cafd
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.95953529E-08,
+ -1.64323258,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 1.00000036,
+ 1.19248815E-08,
+ 2.22708732E-16
+ ],
+ "y": [
+ 2.24835173E-09,
+ -0.188542917,
+ -0.982064962
+ ],
+ "z": [
+ -1.17110082E-08,
+ 0.9820653,
+ -0.188542917
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018.png
new file mode 100755
index 0000000000000000000000000000000000000000..5dd04412efdc573ed7a992d80665873f7e4357b8
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..71931894ae16840d230a9593f032e0098ca1b63f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..9f3a8c41dc22c5df7eddb9f2feacaf0e1250df00
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..77e93e805894882efd8cc41e68ad6a13d2aff21a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..ef3e2868e38d794992e3515aeab08df5f201e2bb
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..7bb216384a8a1075a75ad5adda3c53f5bb435e34
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00018/00018_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019.json
new file mode 100755
index 0000000000000000000000000000000000000000..85344ce33472f09937ed4aff718df0f19d99cd97
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 0.425300568,
+ -1.58724058,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.9659258,
+ 0.258819431,
+ 8.49588755E-09
+ ],
+ "y": [
+ 0.04879858,
+ -0.182118461,
+ -0.982064962
+ ],
+ "z": [
+ -0.2541775,
+ 0.948601842,
+ -0.188542932
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019.png
new file mode 100755
index 0000000000000000000000000000000000000000..22913561502a5c8ec2207f92657ddbdde0f50996
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..33e66639e547d5bd5484f96c21d3a89cfbce65ce
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..c47702dcc1f7899f5413cefc2d4c0f639ce13d88
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..6db3152a2b6d4294ce19a7a12d970d622f6cb5dd
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..0cf2dcfaed865c658936febf3656f4ade38927a1
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..b3d4c40a28dc4f87eef023de0f72c8ec9a2fbdd4
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00019/00019_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020.json
new file mode 100755
index 0000000000000000000000000000000000000000..8ee9e4ba0c47f3a9d5d76ced76b06084c524805c
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 0.8216169,
+ -1.4230808,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.8660252,
+ 0.5000004,
+ -7.682651E-09
+ ],
+ "y": [
+ 0.09427155,
+ -0.163282946,
+ -0.982064962
+ ],
+ "z": [
+ -0.4910329,
+ 0.850493,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020.png
new file mode 100755
index 0000000000000000000000000000000000000000..fdb2bbd187352a7585eafdff4b172f4c082cbab6
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..ced2ca59dfbe39c1ce0dcbe6134c60e5fe298b6f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..7eb6b2af62d9ed5e9bafda5e7705496d7a47a8be
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..a270cb8248c0ba885e38a744eca208b217d714e0
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..ca07015c9987f7b5f9e60ac042b20de3bc1bb49b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..3587977847183d404fcdd8b9e602428751722c2d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00020/00020_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021.json
new file mode 100755
index 0000000000000000000000000000000000000000..37601269b4238af1c7cf8d1932065302fcf50878
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.16194129,
+ -1.16194046,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.707106531,
+ 0.707107067,
+ 3.714943E-09
+ ],
+ "y": [
+ 0.133320048,
+ -0.133319944,
+ -0.982064962
+ ],
+ "z": [
+ -0.694425046,
+ 0.69442457,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021.png
new file mode 100755
index 0000000000000000000000000000000000000000..0495268056d27bf144e67a25dcba5214dd714768
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..224206f122fdaca47d8913101274075160274a6b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..675e79be795b33e62314dd76e08f2ece6a6b5e74
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..57855dd3e493f847cd4d528a66a7c93335041a5c
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..779c20f5ceccd9383b95a6fd47621898f7a49ea7
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..bc04f37398cbafe5733d7edb7d3c707cc21ea3da
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00021/00021_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022.json
new file mode 100755
index 0000000000000000000000000000000000000000..c450a100a123591bc338fed21d1d06f612e532da
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.4230814,
+ -0.8216159,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.50000006,
+ 0.8660256,
+ 6.3035217E-09
+ ],
+ "y": [
+ 0.163283,
+ -0.0942714438,
+ -0.982064962
+ ],
+ "z": [
+ -0.8504934,
+ 0.491032541,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022.png
new file mode 100755
index 0000000000000000000000000000000000000000..cb92008d3952fabad5496742f61d1a37eb677db9
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..7b041d7633418233b79e051e8a5fcab01c0e5af0
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..dff11d891b74f78924fabcac5b1146bfcc73386f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..3be344667b845a36976ca2672a6bcc10168273d1
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..d83cc6c278c42558ad1c77860cc3a77365019906
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..991c02acfd7832be0e61190ea5ff41c3bce887f0
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00022/00022_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023.json
new file mode 100755
index 0000000000000000000000000000000000000000..66c2bf0110c9b6bfb15a839b345801f7c51fff64
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.58724082,
+ -0.425299525,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.258819073,
+ 0.965925932,
+ 2.69960054E-08
+ ],
+ "y": [
+ 0.182118535,
+ -0.0487984642,
+ -0.982064962
+ ],
+ "z": [
+ -0.948602,
+ 0.254177123,
+ -0.188542962
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023.png
new file mode 100755
index 0000000000000000000000000000000000000000..a69ce63ccf33fb58721c35fd0a63f0b0cb4ba01b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..1dc3671e0b4ee68ccaaabd3bc0ced241eba71141
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..5a8db9ddc3a648a82e5c15844afb5e1c78b424da
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..2206a543e6f15f5d0294a1bdfce775706de4f504
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..e60f2a78ade7d447a3c38982dc074a717319c350
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..23055625d30e3d64f160d4d6700be23f0fe7e82f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00023/00023_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024.json
new file mode 100755
index 0000000000000000000000000000000000000000..2c11a3f21446b739225bd1b2d363a0c90cd05e07
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.64323258,
+ 2.873119E-07,
+ 0.315478027
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 1.2196297E-07,
+ 1.00000012,
+ 1.04893623E-08
+ ],
+ "y": [
+ 0.188542947,
+ 1.04893623E-08,
+ -0.982064962
+ ],
+ "z": [
+ -0.9820651,
+ 1.23958557E-07,
+ -0.188542947
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024.png
new file mode 100755
index 0000000000000000000000000000000000000000..04d4a6742ffdc4ad2107cdd0d271617e13278d21
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..8946bf2ec44c5ae2285145f4e9e2b609eaa50a83
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..2e62409cf9e882a71bf09745e6df628f3dc2085d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..468bc6f1e6443bf5044bc693f82c751e31561cc3
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..92ba3e5f953ca027cfd3bbd82dda4c28768619db
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..9092f3f19d33c9da15aa403ddd9559b39721fda9
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00024/00024_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025.json
new file mode 100755
index 0000000000000000000000000000000000000000..f32ce32b1783ffbc3628f630fa85fdc0b074cbe6
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 0.0,
+ 0.0,
+ 1.67324233
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -1.0,
+ 0.0,
+ 0.0
+ ],
+ "y": [
+ 0.0,
+ 1.0,
+ 1.34358856E-07
+ ],
+ "z": [
+ 0.0,
+ 1.34358856E-07,
+ -1.0
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025.png
new file mode 100755
index 0000000000000000000000000000000000000000..4ea0d335579d56279d713949c2545174cedd1af7
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..a1952501b62834806fbef2ba449c139c2405a107
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..e657cdefe5f2a072628e04947acc5d12cbe23a51
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..ea952d850f3f23ac1ac2ea08a2a799dbef45330b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..65db9be8e2cccc7d5f3d0d8575082ed99fe80bd2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..b3b29810366ece67b4da6e13c9e56655b456fca7
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00025/00025_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026.json
new file mode 100755
index 0000000000000000000000000000000000000000..b3ddc1f73ea9ea43c27c134a5a6543d2577141ba
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 0.0,
+ 0.0,
+ -1.67324233
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -1.0,
+ 0.0,
+ 0.0
+ ],
+ "y": [
+ 0.0,
+ -1.0,
+ 1.34358856E-07
+ ],
+ "z": [
+ 0.0,
+ 1.34358856E-07,
+ 1.0
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026.png
new file mode 100755
index 0000000000000000000000000000000000000000..9b38e49af4da177672519fb157d734561edf4e52
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..df3c7c10f582c5401393e100d4e13bfb07d380fc
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..9ba6872ae36fa1265d7142cfc254258ae5d31e7d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..e61f5be68619c9a54acb24c9f017506cfe6f947a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..f21589b993d8e7117c6957145c3684e90271f83d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..c80a5b3b9c83840fec4f7d1f8f317228c662d28d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00026/00026_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027.json
new file mode 100755
index 0000000000000000000000000000000000000000..dc5308bb7d96dba01227bcd3c48d3705726d4894
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.98590279,
+ 0.0,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 2.803585E-08,
+ 1.0,
+ 0.0
+ ],
+ "y": [
+ -0.03488534,
+ 0.0,
+ -0.9993913
+ ],
+ "z": [
+ -0.9993914,
+ 2.803585E-08,
+ 0.03488534
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027.png
new file mode 100755
index 0000000000000000000000000000000000000000..fe1f0d467bcc23352cbd13fc9e2c602144e8e7d5
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..df0d778ee9348becc033ba8feae06d7ff2eec174
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..8dab1b9f60cb6956e153007028336cdebed84cab
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..91c0c9df3a336da5b3e6bc92a82f27460149b6e2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..139e535db98006f461cb9d94b33b4255d026eba2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..da04fa359d6d37d9d87fd8faf45954b30d9c6011
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00027/00027_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028.json
new file mode 100755
index 0000000000000000000000000000000000000000..ae65d80cbe44e6ab86626c19856dda804ae1abcd
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.7198422,
+ 0.9929514,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.5,
+ 0.8660254,
+ -1.35829681E-09
+ ],
+ "y": [
+ -0.030211594,
+ -0.0174426716,
+ -0.9993913
+ ],
+ "z": [
+ -0.865498245,
+ -0.499695659,
+ 0.0348853432
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028.png
new file mode 100755
index 0000000000000000000000000000000000000000..7d7717a5263df6eeeb2613456102c884f3de7edc
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..1f42b6a42822ac916bede0a578547d10a1346998
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..a7be33e91c77c9a900052c67b30d447171bbb2fc
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..1d76caa31bf64f0824a0ccda22fb042ac5f75463
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..1eb0a6d6fd6d8f96f7fb1b994862805bae8dcbfe
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..35b27d8667395d4c731dfe3569b68aefa469139f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00028/00028_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029.json
new file mode 100755
index 0000000000000000000000000000000000000000..49d1c5639f5fec83abca1ef171e4e388eccd89ce
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 0.992951334,
+ 1.71984231,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.8660254,
+ 0.49999997,
+ -1.137825E-09
+ ],
+ "y": [
+ -0.0174426716,
+ -0.0302115958,
+ -0.9993913
+ ],
+ "z": [
+ -0.499695629,
+ -0.8654983,
+ 0.0348853432
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029.png
new file mode 100755
index 0000000000000000000000000000000000000000..4b2875cfd5aff7b5d189392144b4ef16aa40ace2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..f3d6ec9d45d69e32ef1c06a4309e9a1ca5db1c3a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..ba6838fe63624f5ca6a6f180f2abf330cae350c8
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..19d91e60e791dfff403955e05e1c691c0940840f
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..5e47936b4db3a244a18a652184b310c0388248d6
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..1980973b0d8438b2d96ec93049034089274ea1fd
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00029/00029_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030.json
new file mode 100755
index 0000000000000000000000000000000000000000..81c7d3753fb6e31899643f5fcfbe1b3b142f130f
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -8.68065655E-08,
+ 1.98590279,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -1.0,
+ -4.37113918E-08,
+ 4.93610766E-17
+ ],
+ "y": [
+ 1.52488688E-09,
+ -0.0348853469,
+ -0.9993913
+ ],
+ "z": [
+ 4.36847856E-08,
+ -0.9993913,
+ 0.0348853469
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030.png
new file mode 100755
index 0000000000000000000000000000000000000000..7901fae2ffca69c7293bc27035e9f735cecb534a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..0db041cc8d6d802e70f5c7641ab48878c00e39c9
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..b27d25db181c718cfb526640db6a76c0dfc89c02
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..1711f7a2a7ae2549a61ba57a5c1c34508bee4375
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..d0271259728dbafd138cb18d792aa010fe9f5f33
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..ab2d429572f4485fbbeeebbf30864c349eb5e882
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00030/00030_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031.json
new file mode 100755
index 0000000000000000000000000000000000000000..25116bba9ae4261c551447e9e6e2908600fc4812
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -0.9929515,
+ 1.7198422,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.8660253,
+ -0.50000006,
+ 1.34770461E-09
+ ],
+ "y": [
+ 0.0174426753,
+ -0.030211594,
+ -0.9993913
+ ],
+ "z": [
+ 0.499695748,
+ -0.865498245,
+ 0.0348853469
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031.png
new file mode 100755
index 0000000000000000000000000000000000000000..4bb350d0622849580694cec4e9fd8a9141ffa428
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..8fdbe158646aea037c751a66490b4207aa44b302
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..0eb84f321a2ee97b02b574882f9a265ccebe5a1b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..53783ab25183d81cdaa03983b45ebd3f4e902c58
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..5ecf4d132d0dd64f1d2efd94af6e64da2c62e30d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..8f96633c7114ac99de04c459dcb36c80feb8c0d5
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00031/00031_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032.json
new file mode 100755
index 0000000000000000000000000000000000000000..5e14bcd41505851424eb520ad6781e47bfab7055
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.71984243,
+ 0.992951035,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -0.499999821,
+ -0.8660255,
+ -2.37919018E-10
+ ],
+ "y": [
+ 0.0302115921,
+ -0.0174426623,
+ -0.9993913
+ ],
+ "z": [
+ 0.865498364,
+ -0.49969548,
+ 0.03488534
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032.png
new file mode 100755
index 0000000000000000000000000000000000000000..677b9fb1bc86691683b334e529c7183c38a053ca
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..55bcd948c0f7b1254f24721c7a6152ba2c472c70
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..0e75adfa5b42da8a088f376f21cf3904ef6d5054
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..e9257c8a420285c5c246e8c8a007671a46b47a99
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..4828bae7475a599dc7e92d9c89fb14785fbc41ca
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..20127d3b17c31101509a94f6afa4bf240e01140a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00032/00032_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033.json
new file mode 100755
index 0000000000000000000000000000000000000000..5ce0c32a06bc2c2de7225381a648edf25230a784
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.98590279,
+ -1.73613131E-07,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 1.96643555E-07,
+ -1.0,
+ 1.62417335E-09
+ ],
+ "y": [
+ 0.0348853432,
+ 4.25795044E-09,
+ -0.9993913
+ ],
+ "z": [
+ 0.9993914,
+ 1.965976E-07,
+ 0.0348853432
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033.png
new file mode 100755
index 0000000000000000000000000000000000000000..5fb74d7667cef61562123b12fe7f1dbdd3dca30e
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..d30339be6e3fa20e69c6d348c0c1bd66ced02613
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..533a74c6ca952074a8b9629e42029ab6fbe418f0
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..b0949a277991dc0d19a534466cf154cdfdc235a3
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..fa3d44438eafc32e3467a8f719595287c80f473a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..448e12308de30b2f345aecaa18be528e472a735e
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00033/00033_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034.json
new file mode 100755
index 0000000000000000000000000000000000000000..981207097888606714e3419edd364aaa1a5b3579
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -1.71984172,
+ -0.9929522,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.500000358,
+ -0.866025269,
+ -2.40818832E-09
+ ],
+ "y": [
+ 0.0302115884,
+ 0.0174426865,
+ -0.9993913
+ ],
+ "z": [
+ 0.8654981,
+ 0.499696046,
+ 0.0348853469
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034.png
new file mode 100755
index 0000000000000000000000000000000000000000..8afa369cf4e375bbd5c3cc993708d26742755710
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..d2ebc94299523271f6908e629ebc723a4559afea
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..5a308fec2f88c55ba72478e407a42d059bec9f32
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..a57eeccee07ebd611dbc01cc98f1f53ad4965a9e
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..92258928f124c66846a9cb4741130db3a14f6d5e
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..3d073be5ff1be2bdf55c69d42b659045f90811ca
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00034/00034_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035.json
new file mode 100755
index 0000000000000000000000000000000000000000..135a1d5c6f2b7e07dbe3cc4c91832705c2eccfb2
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ -0.9929512,
+ -1.71984231,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.8660254,
+ -0.4999999,
+ -1.93233674E-09
+ ],
+ "y": [
+ 0.0174426679,
+ 0.0302115958,
+ -0.9993913
+ ],
+ "z": [
+ 0.499695569,
+ 0.865498245,
+ 0.0348853432
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035.png
new file mode 100755
index 0000000000000000000000000000000000000000..15d55bafb60df57dd1a7f43c99ff68c2b260cc47
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..c7b6c6d8df3521fb56cab713f8ff31a8229b480d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..db3dcde4eb16773dc6c5c39607aad34612a5abbe
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..fb05f9364b82d15d2f0a8ff209683dfb55b6e4f2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..27d8fa7ac9b6e7cae6821095534418221604c7c2
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..9fa65af2bbcd3f677c302884123487433843a73d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00035/00035_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036.json
new file mode 100755
index 0000000000000000000000000000000000000000..d32f07ac2b5960d3c390db4a44c72d6885b0893e
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 2.3681654E-08,
+ -1.98590279,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 1.0,
+ 1.19248806E-08,
+ 1.41831333E-17
+ ],
+ "y": [
+ -4.16003565E-10,
+ 0.0348853469,
+ -0.9993913
+ ],
+ "z": [
+ -1.19176224E-08,
+ 0.9993914,
+ 0.0348853469
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036.png
new file mode 100755
index 0000000000000000000000000000000000000000..0dff390e48bd1f82d4ca234f7fd937108d25bdc8
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..a88988d7b420b4f462cc04ebf6a077766cebf66b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..3426f8fc96a06c691e6b3617113caaf2bff47d16
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..57aa1a19c2172bd0f0312f7db6ffde4ce2a4596b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..1a0d44446a8168cd9c097a186e83b9687f40fcb6
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..8ac4f2f9fde133f12f385fe2d63bfcf6f8a73b4b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00036/00036_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037.json
new file mode 100755
index 0000000000000000000000000000000000000000..5ea4fda75ea7f5d20ed03442ff60ba63b380be3b
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 0.9929521,
+ -1.71984184,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.86602515,
+ 0.500000358,
+ 1.38204725E-09
+ ],
+ "y": [
+ -0.01744268,
+ 0.0302115828,
+ -0.9993913
+ ],
+ "z": [
+ -0.499696016,
+ 0.865498,
+ 0.03488534
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037.png
new file mode 100755
index 0000000000000000000000000000000000000000..91e7e5d3a17c418568faf2ee5d0b88735b5b3907
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..039d2f5033ba83588c212e2903d838b961f91d67
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..766ac88a2ca2a62476eead14d93e01eabf0a79ec
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..f859b637d6c0867e82f49a04f411b35bd0eb29f0
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..f5305a16e2417be6df76e9d62ae429cde7806eb5
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..f20c4789cfeb2ad76067fb5003091830ec62b34a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00037/00037_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038.json
new file mode 100755
index 0000000000000000000000000000000000000000..481ffbc2c53d7db4491eddba9a1b3d620133a483
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.71984255,
+ -0.9929509,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ 0.499999762,
+ 0.866025567,
+ -2.685876E-10
+ ],
+ "y": [
+ -0.0302115921,
+ 0.01744266,
+ -0.9993913
+ ],
+ "z": [
+ -0.8654984,
+ 0.49969542,
+ 0.0348853357
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038.png
new file mode 100755
index 0000000000000000000000000000000000000000..1fd7801eabfcadcb32efba62a8b22f70e9ff05ff
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..a379a89784a4b8ff9dbaf612833762bdcb31cc4b
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..254f7b956f260c5e8c73bc6c136467401eafe606
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..a77440c545c49f56d1b27646049b1c9b293e117a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..4296ee8ad3f542cc82436642afba7175e374f88c
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..ffd4aa1ec3e2d60dab62dcfa79fe406286870385
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00038/00038_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039.json b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039.json
new file mode 100755
index 0000000000000000000000000000000000000000..3d3d6480e63d6af81a01823f43e64574dfad5445
--- /dev/null
+++ b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039.json
@@ -0,0 +1,47 @@
+{
+ "max_depth": 5.0,
+ "bbox": [
+ [
+ -0.330194056,
+ -0.449999958,
+ -0.263895959
+ ],
+ [
+ 0.330194056,
+ 0.450000018,
+ 0.263895959
+ ]
+ ],
+ "origin": [
+ 1.98590279,
+ 3.47226262E-07,
+ -0.06932109
+ ],
+ "x_fov": 0.691150367,
+ "y_fov": 0.691150367,
+ "x": [
+ -1.406178E-07,
+ 1.0,
+ -1.00960407E-09
+ ],
+ "y": [
+ -0.03488534,
+ -6.89172763E-09,
+ -0.9993913
+ ],
+ "z": [
+ -0.9993914,
+ -1.40479926E-07,
+ 0.03488534
+ ],
+ "scale": [
+ 0.0023696092,
+ 0.0023696092,
+ 0.0023696092
+ ],
+ "offset": [
+ 0.0,
+ -0.4037283,
+ -0.06950388
+ ]
+}
\ No newline at end of file
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039.png
new file mode 100755
index 0000000000000000000000000000000000000000..898a09e18f70b0da46239adf2fc7360b7490711d
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_albedo.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_albedo.png
new file mode 100755
index 0000000000000000000000000000000000000000..5fe1497b079d63bb72d76a0bf4d2459c64360687
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_albedo.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_hdr.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_hdr.exr
new file mode 100755
index 0000000000000000000000000000000000000000..392a65e338a4aedd5ac4ca0e7380ffa491c33d7a
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_hdr.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_mr.png b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_mr.png
new file mode 100755
index 0000000000000000000000000000000000000000..b330f8a01e35421cc6c48b531e2528046a8289d0
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_mr.png differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_nd.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_nd.exr
new file mode 100755
index 0000000000000000000000000000000000000000..08fce31eb03c7a2911ffed900274a7023b3728cc
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_nd.exr differ
diff --git a/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_ng.exr b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_ng.exr
new file mode 100755
index 0000000000000000000000000000000000000000..4760bf7c1462207d56e31e6271ec15aa6be3e4ef
Binary files /dev/null and b/assets/stage1_vae_reconstruction/Objaverse/Animals/0/10120/campos_512_v4/00039/00039_ng.exr differ
diff --git a/assets/stage1_vae_reconstruction/reconstruction_result/mesh-visualization.png b/assets/stage1_vae_reconstruction/reconstruction_result/mesh-visualization.png
new file mode 100644
index 0000000000000000000000000000000000000000..37ba8c892d8191ab439755d03c055a72976ab0ab
Binary files /dev/null and b/assets/stage1_vae_reconstruction/reconstruction_result/mesh-visualization.png differ
diff --git a/assets/stage1_vae_reconstruction/reconstruction_result/vae-stage1-demo-result.mp4 b/assets/stage1_vae_reconstruction/reconstruction_result/vae-stage1-demo-result.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9041b6111535664270e74fea048afbd50cdc63f1
Binary files /dev/null and b/assets/stage1_vae_reconstruction/reconstruction_result/vae-stage1-demo-result.mp4 differ
diff --git a/assets/t23d/blue-plastic-chair.gif b/assets/t23d/blue-plastic-chair.gif
new file mode 100644
index 0000000000000000000000000000000000000000..a9b84e335c1ee3cf952a0e2ef80e15f1b8a9bc74
Binary files /dev/null and b/assets/t23d/blue-plastic-chair.gif differ
diff --git a/assets/t23d/cannon.gif b/assets/t23d/cannon.gif
new file mode 100644
index 0000000000000000000000000000000000000000..4c71fb158fe56e959b5ee3a15c2b1859fc324cd4
Binary files /dev/null and b/assets/t23d/cannon.gif differ
diff --git a/assets/t23d/mast.gif b/assets/t23d/mast.gif
new file mode 100644
index 0000000000000000000000000000000000000000..3334d399af4cd369ec16a246ae2c2317fafea9b8
Binary files /dev/null and b/assets/t23d/mast.gif differ
diff --git a/assets/t23d/standing-hund.gif b/assets/t23d/standing-hund.gif
new file mode 100644
index 0000000000000000000000000000000000000000..fc107a9d15538e1338f7031f2b3b75c4d62a6838
Binary files /dev/null and b/assets/t23d/standing-hund.gif differ
diff --git a/assets/t23d/ufo.gif b/assets/t23d/ufo.gif
new file mode 100644
index 0000000000000000000000000000000000000000..a4f1bc2bd3a43eeeb568083998e51f660ffb4dd0
Binary files /dev/null and b/assets/t23d/ufo.gif differ
diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cldm/__pycache__/cldm.cpython-39.pyc b/cldm/__pycache__/cldm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cbd589f96da49d34451252ee0d47379d48d5ddc
Binary files /dev/null and b/cldm/__pycache__/cldm.cpython-39.pyc differ
diff --git a/cldm/cldm.py b/cldm/cldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfbb9200f4d56fc0542db6a11a35ecdc7c8b656a
--- /dev/null
+++ b/cldm/cldm.py
@@ -0,0 +1,456 @@
+import torch
+import torch as th
+import torch.nn as nn
+
+from ldm.modules.diffusionmodules.util import (
+ conv_nd,
+ linear,
+ zero_module,
+ timestep_embedding,
+)
+
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from ldm.modules.attention_compat import SpatialTransformer
+# from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from guided_diffusion.unet import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+# from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.util import log_txt_as_img, exists # , instantiate_from_config
+# from ldm.models.diffusion.ddim import DDIMSampler
+from pdb import set_trace as st
+
+
+class ControlledUnetModel(UNetModel):
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, get_attr='', **kwargs):
+
+ if get_attr != '': # not breaking the forward hooks
+ return getattr(self, get_attr)
+
+ hs = []
+ with torch.no_grad(): # fix middle_block, SD
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.roll_out:
+ x = rearrange(x, 'b (n c) h w->b c h (n w)', n=3) # torch.Size([84, 4, 32, 96])
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+
+ assert control is not None
+ # if control is not None:
+ h += control.pop()
+
+ for i, module in enumerate(self.output_blocks):
+ if only_mid_control or control is None:
+ h = torch.cat([h, hs.pop()], dim=1)
+ else:
+ # st()
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
+ h = module(h, emb, context)
+
+ h = h.type(x.dtype)
+ h = self.out(h)
+ if self.roll_out:
+ return rearrange(h, 'b c h (n w) -> b (n c) h w', n=3)
+ return h
+
+
+class ControlNet(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ hint_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ # * new keys introduced in LDM
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ roll_out=False,
+ ):
+ super().__init__()
+ self.roll_out = roll_out
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.dims = dims
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ # self.use_checkpoint = use_checkpoint
+ self.use_checkpoint = False
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+ self.input_hint_block = TimestepEmbedSequential( # f=8
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 32, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 96, 96, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+ )
+
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self.zero_convs.append(self.make_zero_conv(ch))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ self.zero_convs.append(self.make_zero_conv(ch))
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self.middle_block_out = self.make_zero_conv(ch)
+ self._feature_size += ch
+
+ def make_zero_conv(self, channels):
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+ def forward(self, x, hint, timesteps, context, **kwargs):
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb) # time condition embedding
+ guided_hint = self.input_hint_block(hint, emb, context) # B 320 8 8, if input resolution = 64
+
+ if self.roll_out:
+ x = rearrange(x, 'b (n c) h w->b c h (n w)', n=3) # torch.Size([84, 4, 32, 96])
+ guided_hint = repeat(guided_hint, 'b c h w -> b c h (n w)', n=3) # torch.Size([84, 4, 32, 96])
+
+ outs = []
+
+ h = x.type(self.dtype)
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+ if guided_hint is not None: # f=8, shall send in 128x128 img_sr
+ h = module(h, emb, context) # B 320 16 16
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context)
+ outs.append(zero_conv(h, emb, context))
+
+ h = self.middle_block(h, emb, context)
+ outs.append(self.middle_block_out(h, emb, context))
+
+ return outs
+
+# ! do not support PL here
+# class ControlLDM(LatentDiffusion):
+
+# def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
+# super().__init__(*args, **kwargs)
+# self.control_model = instantiate_from_config(control_stage_config)
+# self.control_key = control_key
+# self.only_mid_control = only_mid_control
+# self.control_scales = [1.0] * 13
+
+# @torch.no_grad()
+# def get_input(self, batch, k, bs=None, *args, **kwargs):
+# x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
+# control = batch[self.control_key]
+# if bs is not None:
+# control = control[:bs]
+# control = control.to(self.device)
+# control = einops.rearrange(control, 'b h w c -> b c h w')
+# control = control.to(memory_format=torch.contiguous_format).float()
+# return x, dict(c_crossattn=[c], c_concat=[control])
+
+# def apply_model(self, x_noisy, t, cond, *args, **kwargs):
+# assert isinstance(cond, dict)
+# diffusion_model = self.model.diffusion_model
+
+# cond_txt = torch.cat(cond['c_crossattn'], 1)
+
+# if cond['c_concat'] is None:
+# eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
+# else:
+# control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
+# control = [c * scale for c, scale in zip(control, self.control_scales)]
+# eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
+
+# return eps
+
+# @torch.no_grad()
+# def get_unconditional_conditioning(self, N):
+# return self.get_learned_conditioning([""] * N)
+
+# @torch.no_grad()
+# def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
+# quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+# plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
+# use_ema_scope=True,
+# **kwargs):
+# use_ddim = ddim_steps is not None
+
+# log = dict()
+# z, c = self.get_input(batch, self.first_stage_key, bs=N)
+# c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
+# N = min(z.shape[0], N)
+# n_row = min(z.shape[0], n_row)
+# log["reconstruction"] = self.decode_first_stage(z)
+# log["control"] = c_cat * 2.0 - 1.0
+# log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
+
+# if plot_diffusion_rows:
+# # get diffusion row
+# diffusion_row = list()
+# z_start = z[:n_row]
+# for t in range(self.num_timesteps):
+# if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+# t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+# t = t.to(self.device).long()
+# noise = torch.randn_like(z_start)
+# z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+# diffusion_row.append(self.decode_first_stage(z_noisy))
+
+# diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+# diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+# diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+# diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+# log["diffusion_row"] = diffusion_grid
+
+# if sample:
+# # get denoise row
+# samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+# batch_size=N, ddim=use_ddim,
+# ddim_steps=ddim_steps, eta=ddim_eta)
+# x_samples = self.decode_first_stage(samples)
+# log["samples"] = x_samples
+# if plot_denoise_rows:
+# denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+# log["denoise_row"] = denoise_grid
+
+# if unconditional_guidance_scale > 1.0:
+# uc_cross = self.get_unconditional_conditioning(N)
+# uc_cat = c_cat # torch.zeros_like(c_cat)
+# uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+# samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+# batch_size=N, ddim=use_ddim,
+# ddim_steps=ddim_steps, eta=ddim_eta,
+# unconditional_guidance_scale=unconditional_guidance_scale,
+# unconditional_conditioning=uc_full,
+# )
+# x_samples_cfg = self.decode_first_stage(samples_cfg)
+# log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+# return log
+
+# @torch.no_grad()
+# def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+# ddim_sampler = DDIMSampler(self)
+# b, c, h, w = cond["c_concat"][0].shape
+# shape = (self.channels, h // 8, w // 8)
+# samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
+# return samples, intermediates
+
+# def configure_optimizers(self):
+# lr = self.learning_rate
+# params = list(self.control_model.parameters())
+# if not self.sd_locked:
+# params += list(self.model.diffusion_model.output_blocks.parameters())
+# params += list(self.model.diffusion_model.out.parameters())
+# opt = torch.optim.AdamW(params, lr=lr)
+# return opt
+
+# def low_vram_shift(self, is_diffusing):
+# if is_diffusing:
+# self.model = self.model.cuda()
+# self.control_model = self.control_model.cuda()
+# self.first_stage_model = self.first_stage_model.cpu()
+# self.cond_stage_model = self.cond_stage_model.cpu()
+# else:
+# self.model = self.model.cpu()
+# self.control_model = self.control_model.cpu()
+# self.first_stage_model = self.first_stage_model.cuda()
+# self.cond_stage_model = self.cond_stage_model.cuda()
diff --git a/cldm/ddim_hacked.py b/cldm/ddim_hacked.py
new file mode 100644
index 0000000000000000000000000000000000000000..25b1bc947272ad14d7f7e5e4d1809005253b63d0
--- /dev/null
+++ b/cldm/ddim_hacked.py
@@ -0,0 +1,317 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+ ucg_schedule=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ model_t = self.model.apply_model(x, t, c)
+ model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ num_reference_steps = timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False, callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
diff --git a/cldm/hack.py b/cldm/hack.py
new file mode 100644
index 0000000000000000000000000000000000000000..56661b131bc1e3833528fcd3d6e92ec3a957ff21
--- /dev/null
+++ b/cldm/hack.py
@@ -0,0 +1,111 @@
+import torch
+import einops
+
+import ldm.modules.encoders.modules
+import ldm.modules.attention_compat
+
+from transformers import logging
+from ldm.modules.attention_compat import default
+
+
+def disable_verbosity():
+ logging.set_verbosity_error()
+ print('logging improved.')
+ return
+
+
+def enable_sliced_attention():
+ ldm.modules.attention_compat.CrossAttention.forward = _hacked_sliced_attentin_forward
+ print('Enabled sliced_attention.')
+ return
+
+
+def hack_everything(clip_skip=0):
+ disable_verbosity()
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
+ print('Enabled clip hacks.')
+ return
+
+
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+ PAD = self.tokenizer.pad_token_id
+ EOS = self.tokenizer.eos_token_id
+ BOS = self.tokenizer.bos_token_id
+
+ def tokenize(t):
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
+
+ def transformer_encode(t):
+ if self.clip_skip > 1:
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
+ else:
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
+
+ def split(x):
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
+
+ def pad(x, p, i):
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+ raw_tokens_list = tokenize(text)
+ tokens_list = []
+
+ for raw_tokens in raw_tokens_list:
+ raw_tokens_123 = split(raw_tokens)
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+ tokens_list.append(raw_tokens_123)
+
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
+ y = transformer_encode(feed)
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
+
+ return z
+
+
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ limit = k.shape[0]
+ att_step = 1
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+ q_chunks.reverse()
+ k_chunks.reverse()
+ v_chunks.reverse()
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ del k, q, v
+ for i in range(0, limit, att_step):
+ q_buffer = q_chunks.pop()
+ k_buffer = k_chunks.pop()
+ v_buffer = v_chunks.pop()
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+ del k_buffer, q_buffer
+ # attention, what we cannot get enough of, by chunks
+
+ sim_buffer = sim_buffer.softmax(dim=-1)
+
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+ del v_buffer
+ sim[i:i + att_step, :, :] = sim_buffer
+
+ del sim_buffer
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(sim)
diff --git a/cldm/logger.py b/cldm/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8803846f2a8979f87f3cf9ea5b12869439e62f
--- /dev/null
+++ b/cldm/logger.py
@@ -0,0 +1,76 @@
+import os
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+
+class ImageLogger(Callback):
+ def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+ log_images_kwargs=None):
+ super().__init__()
+ self.rescale = rescale
+ self.batch_freq = batch_frequency
+ self.max_images = max_images
+ if not increase_log_steps:
+ self.log_steps = [self.batch_freq]
+ self.clamp = clamp
+ self.disabled = disabled
+ self.log_on_batch_idx = log_on_batch_idx
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+ self.log_first_step = log_first_step
+
+ @rank_zero_only
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
+ root = os.path.join(save_dir, "image_log", split)
+ for k in images:
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
+ if self.rescale:
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid * 255).astype(np.uint8)
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ Image.fromarray(grid).save(path)
+
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
+ check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
+ hasattr(pl_module, "log_images") and
+ callable(pl_module.log_images) and
+ self.max_images > 0):
+ logger = type(pl_module.logger)
+
+ is_train = pl_module.training
+ if is_train:
+ pl_module.eval()
+
+ with torch.no_grad():
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+ for k in images:
+ N = min(images[k].shape[0], self.max_images)
+ images[k] = images[k][:N]
+ if isinstance(images[k], torch.Tensor):
+ images[k] = images[k].detach().cpu()
+ if self.clamp:
+ images[k] = torch.clamp(images[k], -1., 1.)
+
+ self.log_local(pl_module.logger.save_dir, split, images,
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+ if is_train:
+ pl_module.train()
+
+ def check_frequency(self, check_idx):
+ return check_idx % self.batch_freq == 0
+
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+ if not self.disabled:
+ self.log_img(pl_module, batch, batch_idx, split="train")
diff --git a/cldm/model.py b/cldm/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed3c31ac145b78907c7f771d1d8db6fb32d92ed
--- /dev/null
+++ b/cldm/model.py
@@ -0,0 +1,28 @@
+import os
+import torch
+
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+ return d.get('state_dict', d)
+
+
+def load_state_dict(ckpt_path, location='cpu'):
+ _, extension = os.path.splitext(ckpt_path)
+ if extension.lower() == ".safetensors":
+ import safetensors.torch
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+ else:
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+ state_dict = get_state_dict(state_dict)
+ print(f'Loaded state_dict from [{ckpt_path}]')
+ return state_dict
+
+
+def create_model(config_path):
+ config = OmegaConf.load(config_path)
+ model = instantiate_from_config(config.model).cpu()
+ print(f'Loaded model config from [{config_path}]')
+ return model
diff --git a/datasets/README.md b/datasets/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a16511d23965c88dac294147b26060c6daa6b59c
--- /dev/null
+++ b/datasets/README.md
@@ -0,0 +1 @@
+## TODO
\ No newline at end of file
diff --git a/datasets/__pycache__/eg3d_dataset.cpython-39.pyc b/datasets/__pycache__/eg3d_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44787fe949d7cd302fb8d5f7c3d81c1507e0087c
Binary files /dev/null and b/datasets/__pycache__/eg3d_dataset.cpython-39.pyc differ
diff --git a/datasets/__pycache__/g_buffer_objaverse.cpython-39.pyc b/datasets/__pycache__/g_buffer_objaverse.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..374670e435f643ff7d81f0655b34ed3c3cac5ce0
Binary files /dev/null and b/datasets/__pycache__/g_buffer_objaverse.cpython-39.pyc differ
diff --git a/datasets/__pycache__/shapenet.cpython-39.pyc b/datasets/__pycache__/shapenet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88d24afc7478845b36af729fe04b6c0447b19373
Binary files /dev/null and b/datasets/__pycache__/shapenet.cpython-39.pyc differ
diff --git a/datasets/eg3d_dataset.py b/datasets/eg3d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce59c6b53a0d57528948a8be1b6a8b066ae3ef98
--- /dev/null
+++ b/datasets/eg3d_dataset.py
@@ -0,0 +1,601 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Streaming images and labels from datasets created with dataset_tool.py."""
+
+import cv2
+import os
+import numpy as np
+import zipfile
+import PIL.Image
+import json
+import torch
+import dnnlib
+from torchvision import transforms
+
+from pdb import set_trace as st
+
+from .shapenet import LMDBDataset_MV_Compressed, decompress_array
+
+try:
+ import pyspng
+except ImportError:
+ pyspng = None
+
+#----------------------------------------------------------------------------
+
+
+# copide from eg3d/train.py
+def init_dataset_kwargs(data,
+ class_name='datasets.eg3d_dataset.ImageFolderDataset',
+ reso_gt=128):
+ # try:
+ # if data == 'None':
+ # dataset_kwargs = dnnlib.EasyDict({}) #
+ # dataset_kwargs.name = 'eg3d_dataset'
+ # dataset_kwargs.resolution = 128
+ # dataset_kwargs.use_labels = False
+ # dataset_kwargs.max_size = 70000
+ # return dataset_kwargs, 'eg3d_dataset'
+
+ dataset_kwargs = dnnlib.EasyDict(class_name=class_name,
+ reso_gt=reso_gt,
+ path=data,
+ use_labels=True,
+ max_size=None,
+ xflip=False)
+ dataset_obj = dnnlib.util.construct_class_by_name(
+ **dataset_kwargs) # Subclass of training.dataset.Dataset.
+ dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
+ dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
+ dataset_kwargs.max_size = len(
+ dataset_obj) # Be explicit about dataset size.
+
+ return dataset_kwargs, dataset_obj.name
+ # except IOError as err:
+ # raise click.ClickException(f'--data: {err}')
+
+
+class Dataset(torch.utils.data.Dataset):
+
+ def __init__(
+ self,
+ name, # Name of the dataset.
+ raw_shape, # Shape of the raw image data (NCHW).
+ reso_gt=128,
+ max_size=None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
+ use_labels=False, # Enable conditioning labels? False = label dimension is zero.
+ xflip=False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
+ random_seed=0, # Random seed to use when applying max_size.
+ ):
+ self._name = name
+ self._raw_shape = list(raw_shape)
+ self._use_labels = use_labels
+ self._raw_labels = None
+ self._label_shape = None
+
+ # self.reso_gt = 128
+ self.reso_gt = reso_gt # ! hard coded
+ self.reso_encoder = 224
+
+ # Apply max_size.
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
+ # self._raw_idx = np.arange(self.__len__(), dtype=np.int64)
+ if (max_size is not None) and (self._raw_idx.size > max_size):
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
+
+ # Apply xflip.
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
+ if xflip:
+ self._raw_idx = np.tile(self._raw_idx, 2)
+ self._xflip = np.concatenate(
+ [self._xflip, np.ones_like(self._xflip)])
+
+ # dino encoder normalizer
+ self.normalize_for_encoder_input = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ transforms.Resize(size=(self.reso_encoder, self.reso_encoder),
+ antialias=True), # type: ignore
+ ])
+
+ self.normalize_for_gt = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ transforms.Resize(size=(self.reso_gt, self.reso_gt),
+ antialias=True), # type: ignore
+ ])
+
+ def _get_raw_labels(self):
+ if self._raw_labels is None:
+ self._raw_labels = self._load_raw_labels(
+ ) if self._use_labels else None
+ if self._raw_labels is None:
+ self._raw_labels = np.zeros([self._raw_shape[0], 0],
+ dtype=np.float32)
+ assert isinstance(self._raw_labels, np.ndarray)
+ # assert self._raw_labels.shape[0] == self._raw_shape[0]
+ assert self._raw_labels.dtype in [np.float32, np.int64]
+ if self._raw_labels.dtype == np.int64:
+ assert self._raw_labels.ndim == 1
+ assert np.all(self._raw_labels >= 0)
+ self._raw_labels_std = self._raw_labels.std(0)
+ return self._raw_labels
+
+ def close(self): # to be overridden by subclass
+ pass
+
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
+ raise NotImplementedError
+
+ def _load_raw_labels(self): # to be overridden by subclass
+ raise NotImplementedError
+
+ def __getstate__(self):
+ return dict(self.__dict__, _raw_labels=None)
+
+ def __del__(self):
+ try:
+ self.close()
+ except:
+ pass
+
+ def __len__(self):
+ return self._raw_idx.size
+ # return self._get_raw_labels().shape[0]
+
+ def __getitem__(self, idx):
+ # print(self._raw_idx[idx], idx)
+
+ matte = self._load_raw_matte(self._raw_idx[idx])
+ assert isinstance(matte, np.ndarray)
+ assert list(matte.shape)[1:] == self.image_shape[1:]
+ if self._xflip[idx]:
+ assert matte.ndim == 1 # CHW
+ matte = matte[:, :, ::-1]
+ # matte_orig = matte.copy().astype(np.float32) / 255
+ matte_orig = matte.copy().astype(np.float32) # segmentation version
+ # assert matte_orig.max() == 1
+ matte = np.transpose(matte,
+ # (1, 2, 0)).astype(np.float32) / 255 # [0,1] range
+ (1, 2, 0)).astype(np.float32) # [0,1] range
+ matte = cv2.resize(matte, (self.reso_gt, self.reso_gt),
+ interpolation=cv2.INTER_NEAREST)
+ assert matte.min() >= 0 and matte.max(
+ ) <= 1, f'{matte.min(), matte.max()}'
+
+ if matte.ndim == 3: # H, W
+ matte = matte[..., 0]
+
+ image = self._load_raw_image(self._raw_idx[idx])
+
+ assert isinstance(image, np.ndarray)
+ assert list(image.shape) == self.image_shape
+ assert image.dtype == np.uint8
+ if self._xflip[idx]:
+ assert image.ndim == 3 # CHW
+ image = image[:, :, ::-1]
+
+ # blending
+ # blending = True
+ blending = False
+ if blending:
+ image = image * matte_orig + (1 - matte_orig) * cv2.GaussianBlur(
+ image, (5, 5), cv2.BORDER_DEFAULT)
+ # image = image * matte_orig
+
+ image = np.transpose(image, (1, 2, 0)).astype(
+ np.float32
+ ) / 255 # H W C for torchvision process, normalize to [0,1]
+
+ image_sr = torch.from_numpy(image)[..., :3].permute(
+ 2, 0, 1) * 2 - 1 # normalize to [-1,1]
+ image_to_encoder = self.normalize_for_encoder_input(image)
+
+ image_gt = cv2.resize(image, (self.reso_gt, self.reso_gt),
+ interpolation=cv2.INTER_AREA)
+ image_gt = torch.from_numpy(image_gt)[..., :3].permute(
+ 2, 0, 1) * 2 - 1 # normalize to [-1,1]
+
+ return dict(
+ c=self.get_label(idx),
+ img_to_encoder=image_to_encoder, # 224
+ img_sr=image_sr, # 512
+ img=image_gt, # [-1,1] range
+ # depth=torch.zeros_like(image_gt)[0, ...] # type: ignore
+ depth=matte,
+ depth_mask=matte,
+ # depth_mask=matte > 0,
+ # alpha=matte,
+ ) # return dict here
+
+ def get_label(self, idx):
+ label = self._get_raw_labels()[self._raw_idx[idx]]
+ if label.dtype == np.int64:
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
+ onehot[label] = 1
+ label = onehot
+ return label.copy()
+
+ def get_details(self, idx):
+ d = dnnlib.EasyDict()
+ d.raw_idx = int(self._raw_idx[idx])
+ d.xflip = (int(self._xflip[idx]) != 0)
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
+ return d
+
+ def get_label_std(self):
+ return self._raw_labels_std
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def image_shape(self):
+ return list(self._raw_shape[1:])
+
+ @property
+ def num_channels(self):
+ assert len(self.image_shape) == 3 # CHW
+ return self.image_shape[0]
+
+ @property
+ def resolution(self):
+ assert len(self.image_shape) == 3 # CHW
+ assert self.image_shape[1] == self.image_shape[2]
+ return self.image_shape[1]
+
+ @property
+ def label_shape(self):
+ if self._label_shape is None:
+ raw_labels = self._get_raw_labels()
+ if raw_labels.dtype == np.int64:
+ self._label_shape = [int(np.max(raw_labels)) + 1]
+ else:
+ self._label_shape = raw_labels.shape[1:]
+ return list(self._label_shape)
+
+ @property
+ def label_dim(self):
+ assert len(self.label_shape) == 1
+ return self.label_shape[0]
+
+ @property
+ def has_labels(self):
+ return any(x != 0 for x in self.label_shape)
+
+ @property
+ def has_onehot_labels(self):
+ return self._get_raw_labels().dtype == np.int64
+
+
+#----------------------------------------------------------------------------
+
+
+class ImageFolderDataset(Dataset):
+
+ def __init__(
+ self,
+ path, # Path to directory or zip.
+ resolution=None, # Ensure specific resolution, None = highest available.
+ reso_gt=128,
+ **super_kwargs, # Additional arguments for the Dataset base class.
+ ):
+ self._path = path
+ # self._matte_path = path.replace('unzipped_ffhq_512',
+ # 'unzipped_ffhq_matte')
+ self._matte_path = path.replace('unzipped_ffhq_512',
+ 'ffhq_512_seg')
+ self._zipfile = None
+
+ if os.path.isdir(self._path):
+ self._type = 'dir'
+ self._all_fnames = {
+ os.path.relpath(os.path.join(root, fname), start=self._path)
+ for root, _dirs, files in os.walk(self._path)
+ for fname in files
+ }
+ elif self._file_ext(self._path) == '.zip':
+ self._type = 'zip'
+ self._all_fnames = set(self._get_zipfile().namelist())
+ else:
+ raise IOError('Path must point to a directory or zip')
+
+ PIL.Image.init()
+ self._image_fnames = sorted(
+ fname for fname in self._all_fnames
+ if self._file_ext(fname) in PIL.Image.EXTENSION)
+ if len(self._image_fnames) == 0:
+ raise IOError('No image files found in the specified path')
+
+ name = os.path.splitext(os.path.basename(self._path))[0]
+ raw_shape = [len(self._image_fnames)] + list(
+ self._load_raw_image(0).shape)
+ # raw_shape = [len(self._image_fnames)] + list(
+ # self._load_raw_image(0).shape)
+ if resolution is not None and (raw_shape[2] != resolution
+ or raw_shape[3] != resolution):
+ raise IOError('Image files do not match the specified resolution')
+ super().__init__(name=name,
+ raw_shape=raw_shape,
+ reso_gt=reso_gt,
+ **super_kwargs)
+
+ @staticmethod
+ def _file_ext(fname):
+ return os.path.splitext(fname)[1].lower()
+
+ def _get_zipfile(self):
+ assert self._type == 'zip'
+ if self._zipfile is None:
+ self._zipfile = zipfile.ZipFile(self._path)
+ return self._zipfile
+
+ def _open_file(self, fname):
+ if self._type == 'dir':
+ return open(os.path.join(self._path, fname), 'rb')
+ if self._type == 'zip':
+ return self._get_zipfile().open(fname, 'r')
+ return None
+
+ def _open_matte_file(self, fname):
+ if self._type == 'dir':
+ return open(os.path.join(self._matte_path, fname), 'rb')
+ # if self._type == 'zip':
+ # return self._get_zipfile().open(fname, 'r')
+ # return None
+
+ def close(self):
+ try:
+ if self._zipfile is not None:
+ self._zipfile.close()
+ finally:
+ self._zipfile = None
+
+ def __getstate__(self):
+ return dict(super().__getstate__(), _zipfile=None)
+
+ def _load_raw_image(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ with self._open_file(fname) as f:
+ if pyspng is not None and self._file_ext(fname) == '.png':
+ image = pyspng.load(f.read())
+ else:
+ image = np.array(PIL.Image.open(f))
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis] # HW => HWC
+ image = image.transpose(2, 0, 1) # HWC => CHW
+ return image
+
+ def _load_raw_matte(self, raw_idx):
+ # ! from seg version
+ fname = self._image_fnames[raw_idx]
+ with self._open_matte_file(fname) as f:
+ if pyspng is not None and self._file_ext(fname) == '.png':
+ image = pyspng.load(f.read())
+ else:
+ image = np.array(PIL.Image.open(f))
+ # if image.max() != 1:
+ image = (image > 0).astype(np.float32) # process segmentation
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis] # HW => HWC
+ image = image.transpose(2, 0, 1) # HWC => CHW
+ return image
+
+ def _load_raw_matte_orig(self, raw_idx):
+ fname = self._image_fnames[raw_idx]
+ with self._open_matte_file(fname) as f:
+ if pyspng is not None and self._file_ext(fname) == '.png':
+ image = pyspng.load(f.read())
+ else:
+ image = np.array(PIL.Image.open(f))
+ st() # process segmentation
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis] # HW => HWC
+ image = image.transpose(2, 0, 1) # HWC => CHW
+ return image
+
+ def _load_raw_labels(self):
+ fname = 'dataset.json'
+ if fname not in self._all_fnames:
+ return None
+ with self._open_file(fname) as f:
+ # st()
+ labels = json.load(f)['labels']
+ if labels is None:
+ return None
+ labels = dict(labels)
+ labels_ = []
+ for fname, _ in labels.items():
+ # if 'mirror' not in fname:
+ labels_.append(labels[fname])
+ labels = labels_
+ # !
+ # labels = [
+ # labels[fname.replace('\\', '/')] for fname in self._image_fnames
+ # ]
+ labels = np.array(labels)
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
+ self._raw_labels = labels
+ return labels
+
+
+#----------------------------------------------------------------------------
+
+
+# class ImageFolderDatasetUnzipped(ImageFolderDataset):
+
+# def __init__(self, path, resolution=None, **super_kwargs):
+# super().__init__(path, resolution, **super_kwargs)
+
+
+# class ImageFolderDatasetPose(ImageFolderDataset):
+
+# def __init__(
+# self,
+# path, # Path to directory or zip.
+# resolution=None, # Ensure specific resolution, None = highest available.
+# **super_kwargs, # Additional arguments for the Dataset base class.
+# ):
+# super().__init__(path, resolution, **super_kwargs)
+# # only return labels
+
+# def __len__(self):
+# return self._raw_idx.size
+# # return self._get_raw_labels().shape[0]
+
+# def __getitem__(self, idx):
+# # image = self._load_raw_image(self._raw_idx[idx])
+# # assert isinstance(image, np.ndarray)
+# # assert list(image.shape) == self.image_shape
+# # assert image.dtype == np.uint8
+# # if self._xflip[idx]:
+# # assert image.ndim == 3 # CHW
+# # image = image[:, :, ::-1]
+# return dict(c=self.get_label(idx), ) # return dict here
+
+
+class ImageFolderDatasetLMDB(ImageFolderDataset):
+ def __init__(self, path, resolution=None, reso_gt=128, **super_kwargs):
+ super().__init__(path, resolution, reso_gt, **super_kwargs)
+
+ def __getitem__(self, idx):
+ # print(self._raw_idx[idx], idx)
+
+ matte = self._load_raw_matte(self._raw_idx[idx])
+ assert isinstance(matte, np.ndarray)
+ assert list(matte.shape)[1:] == self.image_shape[1:]
+ if self._xflip[idx]:
+ assert matte.ndim == 1 # CHW
+ matte = matte[:, :, ::-1]
+ # matte_orig = matte.copy().astype(np.float32) / 255
+ matte_orig = matte.copy().astype(np.float32) # segmentation version
+ assert matte_orig.max() <= 1 # some ffhq images are dirty, so may be all zero
+ matte = np.transpose(matte,
+ # (1, 2, 0)).astype(np.float32) / 255 # [0,1] range
+ (1, 2, 0)).astype(np.float32) # [0,1] range
+
+ # ! load 512 matte
+ # matte = cv2.resize(matte, (self.reso_gt, self.reso_gt),
+ # interpolation=cv2.INTER_NEAREST)
+
+ assert matte.min() >= 0 and matte.max(
+ ) <= 1, f'{matte.min(), matte.max()}'
+
+ if matte.ndim == 3: # H, W
+ matte = matte[..., 0]
+
+ image = self._load_raw_image(self._raw_idx[idx])
+
+ assert isinstance(image, np.ndarray)
+ assert list(image.shape) == self.image_shape
+ assert image.dtype == np.uint8
+ if self._xflip[idx]:
+ assert image.ndim == 3 # CHW
+ image = image[:, :, ::-1]
+
+ # blending
+ # blending = True
+ # blending = False
+ # if blending:
+ # image = image * matte_orig + (1 - matte_orig) * cv2.GaussianBlur(
+ # image, (5, 5), cv2.BORDER_DEFAULT)
+ # image = image * matte_orig
+
+ # image = np.transpose(image, (1, 2, 0)).astype(
+ # np.float32
+ # ) / 255 # H W C for torchvision process, normalize to [0,1]
+
+ # image_sr = torch.from_numpy(image)[..., :3].permute(
+ # 2, 0, 1) * 2 - 1 # normalize to [-1,1]
+ # image_to_encoder = self.normalize_for_encoder_input(image)
+
+ # image_gt = cv2.resize(image, (self.reso_gt, self.reso_gt),
+ # interpolation=cv2.INTER_AREA)
+ # image_gt = torch.from_numpy(image_gt)[..., :3].permute(
+ # 2, 0, 1) * 2 - 1 # normalize to [-1,1]
+
+ return dict(
+ c=self.get_label(idx),
+ # img_to_encoder=image_to_encoder, # 224
+ # img_sr=image_sr, # 512
+ img=image, # [-1,1] range
+ # depth=torch.zeros_like(image_gt)[0, ...] # type: ignore
+ # depth=matte,
+ depth_mask=matte,
+ ) # return dict here
+
+class LMDBDataset_MV_Compressed_eg3d(LMDBDataset_MV_Compressed):
+
+ def __init__(self,
+ lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=True,
+ **kwargs):
+ super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
+ **kwargs)
+
+ self.normalize_for_encoder_input = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ transforms.Resize(size=(self.reso_encoder, self.reso_encoder),
+ antialias=True), # type: ignore
+ ])
+
+ self.normalize_for_gt = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ transforms.Resize(size=(self.reso, self.reso),
+ antialias=True), # type: ignore
+ ])
+
+ def __getitem__(self, idx):
+ # sample = super(LMDBDataset).__getitem__(idx)
+
+ # do gzip uncompress online
+ with self.env.begin(write=False) as txn:
+ img_key = f'{idx}-img'.encode('utf-8')
+ image = self.load_image_fn(txn.get(img_key))
+
+ depth_key = f'{idx}-depth_mask'.encode('utf-8')
+ # depth = decompress_array(txn.get(depth_key), (512,512), np.float32)
+ depth = decompress_array(txn.get(depth_key), (64,64), np.float32)
+
+ c_key = f'{idx}-c'.encode('utf-8')
+ c = decompress_array(txn.get(c_key), (25, ), np.float32)
+
+ # ! post processing, e.g., normalizing
+ depth = cv2.resize(depth, (self.reso, self.reso),
+ interpolation=cv2.INTER_NEAREST)
+
+ image = np.transpose(image, (1, 2, 0)).astype(
+ np.float32
+ ) / 255 # H W C for torchvision process, normalize to [0,1]
+
+ image_sr = torch.from_numpy(image)[..., :3].permute(
+ 2, 0, 1) * 2 - 1 # normalize to [-1,1]
+ image_to_encoder = self.normalize_for_encoder_input(image)
+
+ image_gt = cv2.resize(image, (self.reso, self.reso),
+ interpolation=cv2.INTER_AREA)
+ image_gt = torch.from_numpy(image_gt)[..., :3].permute(
+ 2, 0, 1) * 2 - 1 # normalize to [-1,1]
+
+
+ return {
+ 'img_to_encoder': image_to_encoder, # 224
+ 'img_sr': image_sr, # 512
+ 'img': image_gt, # [-1,1] range
+ 'c': c,
+ 'depth': depth,
+ 'depth_mask': depth,
+ }
diff --git a/datasets/g_buffer_objaverse.py b/datasets/g_buffer_objaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..04a0309848757a35121bd70c17ac4c0f3e5686cc
--- /dev/null
+++ b/datasets/g_buffer_objaverse.py
@@ -0,0 +1,3000 @@
+import os
+import collections
+import math
+import time
+import itertools
+import pickle
+from typing import Any
+import lmdb
+import cv2
+import imageio
+import numpy as np
+from PIL import Image
+import Imath
+import OpenEXR
+from pdb import set_trace as st
+from pathlib import Path
+import torchvision
+
+from einops import rearrange, repeat
+from functools import partial
+import io
+import gzip
+import random
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from torch.utils.data.distributed import DistributedSampler
+from pathlib import Path
+import lz4.frame
+
+import torch.multiprocessing
+
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+from utils.general_utils import PILtoTorch, matrix_to_quaternion
+
+from guided_diffusion import logger
+import json
+
+import webdataset as wds
+
+from .shapenet import LMDBDataset, LMDBDataset_MV_Compressed, decompress_and_open_image_gzip, decompress_array
+from kiui.op import safe_normalize
+
+from utils.gs_utils.graphics_utils import getWorld2View2, getProjectionMatrix, getView2World
+
+
+def fov2focal(fov, pixels):
+ return pixels / (2 * math.tan(fov / 2))
+
+
+def focal2fov(focal, pixels):
+ return 2 * math.atan(pixels / (2 * focal))
+
+
+def resize_depth_mask(depth_to_resize, resolution):
+ depth_resized = cv2.resize(depth_to_resize, (resolution, resolution),
+ interpolation=cv2.INTER_LANCZOS4)
+ # interpolation=cv2.INTER_AREA)
+ return depth_resized, depth_resized > 0 # type: ignore
+
+
+def resize_depth_mask_Tensor(depth_to_resize, resolution):
+
+ if depth_to_resize.shape[-1] != resolution:
+ depth_resized = torch.nn.functional.interpolate(
+ input=depth_to_resize.unsqueeze(1),
+ size=(resolution, resolution),
+ mode='bilinear',
+ align_corners=False,
+ ).squeeze(1)
+ else:
+ depth_resized = depth_to_resize
+
+ return depth_resized, depth_resized > 0 # type: ignore
+
+
+def load_dataset(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ # shuffle=True,
+ num_workers=6,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ use_lmdb=False,
+ use_wds=False,
+ use_lmdb_compressed=False,
+ infi_sampler=True):
+ # st()
+ # dataset_cls = {
+ # 'input_rec': MultiViewDataset,
+ # 'nv': NovelViewDataset,
+ # }[trainer_name]
+ # st()
+ if use_wds:
+ return load_wds_data(file_path, reso, reso_encoder, batch_size,
+ num_workers)
+
+ if use_lmdb:
+ logger.log('using LMDB dataset')
+ # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
+
+ if use_lmdb_compressed:
+ if 'nv' in trainer_name:
+ dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ else:
+ dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ else:
+ if 'nv' in trainer_name:
+ dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ else:
+ dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
+
+ # dataset = dataset_cls(file_path)
+ else:
+ if 'nv' in trainer_name:
+ dataset_cls = NovelViewObjverseDataset
+ else:
+ dataset_cls = MultiViewObjverseDataset # 1.5-2iter/s
+
+ dataset = dataset_cls(file_path,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize,
+ dataset_size=dataset_size)
+
+ logger.log('dataset_cls: {}, dataset size: {}'.format(
+ trainer_name, len(dataset)))
+
+ loader = DataLoader(dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ pin_memory=True,
+ persistent_workers=num_workers > 0,
+ shuffle=False)
+ return loader
+
+
+def load_data(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ # shuffle=True,
+ num_workers=6,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ use_lmdb=False,
+ use_wds=False,
+ use_lmdb_compressed=False,
+ plucker_embedding=False,
+ infi_sampler=True):
+ # st()
+ # dataset_cls = {
+ # 'input_rec': MultiViewDataset,
+ # 'nv': NovelViewDataset,
+ # }[trainer_name]
+ # st()
+ # if use_lmdb:
+ # logger.log('using LMDB dataset')
+ # # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
+ # if 'nv' in trainer_name:
+ # dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ # else:
+ # dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ # # dataset = dataset_cls(file_path)
+
+ if use_lmdb:
+ logger.log('using LMDB dataset')
+ # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
+
+ if use_lmdb_compressed:
+ if 'nv' in trainer_name:
+ dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ else:
+ dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ else:
+ if 'nv' in trainer_name:
+ dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ else:
+ dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
+
+ else:
+ if 'nv' in trainer_name:
+ dataset_cls = NovelViewObjverseDataset # 1.5-2iter/s
+ else:
+ dataset_cls = MultiViewObjverseDataset
+
+ dataset = dataset_cls(file_path,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize,
+ dataset_size=dataset_size,
+ plucker_embedding=plucker_embedding)
+
+ logger.log('dataset_cls: {}, dataset size: {}'.format(
+ trainer_name, len(dataset)))
+
+ # st()
+
+ if infi_sampler:
+ train_sampler = DistributedSampler(dataset=dataset,
+ shuffle=True,
+ drop_last=True)
+
+ loader = DataLoader(dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=True,
+ pin_memory=True,
+ persistent_workers=num_workers > 0,
+ sampler=train_sampler)
+
+ while True:
+ yield from loader
+
+ # else:
+ # # loader = DataLoader(dataset,
+ # # batch_size=batch_size,
+ # # num_workers=num_workers,
+ # # drop_last=False,
+ # # pin_memory=True,
+ # # persistent_workers=num_workers > 0,
+ # # shuffle=False)
+ # st()
+ # return dataset
+
+
+def load_eval_data(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ num_workers=1,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ interval=1,
+ use_lmdb=False,
+ plucker_embedding=False,
+ load_real=False,
+ four_view_for_latent=False,
+ shuffle_across_cls=False,
+ load_extra_36_view=False,
+ gs_cam_format=False,
+ single_view_for_i23d=False,
+ **kwargs,
+):
+
+ if use_lmdb:
+ logger.log('using LMDB dataset')
+ dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ dataset = dataset_cls(file_path,
+ reso,
+ reso_encoder,
+ test=True,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize,
+ interval=interval)
+
+ elif load_real:
+ dataset = RealDataset(file_path,
+ reso,
+ reso_encoder,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ test=True,
+ imgnet_normalize=imgnet_normalize,
+ interval=interval,
+ plucker_embedding=plucker_embedding)
+
+ else:
+ dataset = MultiViewObjverseDataset(
+ file_path,
+ reso,
+ reso_encoder,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ test=True,
+ imgnet_normalize=imgnet_normalize,
+ interval=interval,
+ plucker_embedding=plucker_embedding,
+ four_view_for_latent=four_view_for_latent,
+ load_extra_36_view=load_extra_36_view,
+ shuffle_across_cls=shuffle_across_cls,
+ gs_cam_format=gs_cam_format,
+ single_view_for_i23d=single_view_for_i23d,
+ )
+
+ print('eval dataset size: {}'.format(len(dataset)))
+ # train_sampler = DistributedSampler(dataset=dataset)
+ loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ shuffle=False,
+ )
+ # sampler=train_sampler)
+ return loader
+
+
+def load_data_for_lmdb(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ # shuffle=True,
+ num_workers=6,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ shuffle_across_cls=False,
+ four_view_for_latent=False,
+ wds_split=1):
+ # st()
+ # dataset_cls = {
+ # 'input_rec': MultiViewDataset,
+ # 'nv': NovelViewDataset,
+ # }[trainer_name]
+ # if 'nv' in trainer_name:
+ # dataset_cls = NovelViewDataset
+ # else:
+ # dataset_cls = MultiViewDataset
+ # st()
+ dataset_cls = MultiViewObjverseDatasetforLMDB
+
+ dataset = dataset_cls(file_path,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize,
+ dataset_size=dataset_size,
+ shuffle_across_cls=shuffle_across_cls,
+ wds_split=wds_split,
+ four_view_for_latent=four_view_for_latent)
+
+ logger.log('dataset_cls: {}, dataset size: {}'.format(
+ trainer_name, len(dataset)))
+ # train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True)
+ loader = DataLoader(
+ dataset,
+ shuffle=False,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ # prefetch_factor=2,
+ # prefetch_factor=3,
+ pin_memory=True,
+ persistent_workers=num_workers > 0,
+ )
+ # sampler=train_sampler)
+
+ # while True:
+ # yield from loader
+ return loader, dataset.dataset_name, len(dataset)
+
+
+def load_lmdb_for_lmdb(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ # shuffle=True,
+ num_workers=6,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ trainer_name='input_rec'):
+ # st()
+ # dataset_cls = {
+ # 'input_rec': MultiViewDataset,
+ # 'nv': NovelViewDataset,
+ # }[trainer_name]
+ # if 'nv' in trainer_name:
+ # dataset_cls = NovelViewDataset
+ # else:
+ # dataset_cls = MultiViewDataset
+ # st()
+ dataset_cls = Objv_LMDBDataset_MV_Compressed_for_lmdb
+
+ dataset = dataset_cls(file_path,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize,
+ dataset_size=dataset_size)
+
+ logger.log('dataset_cls: {}, dataset size: {}'.format(
+ trainer_name, len(dataset)))
+ # train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True)
+ loader = DataLoader(
+ dataset,
+ shuffle=False,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ prefetch_factor=2,
+ # prefetch_factor=3,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ # sampler=train_sampler)
+
+ # while True:
+ # yield from loader
+ return loader, len(dataset)
+
+
+def load_memory_data(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ num_workers=1,
+ # load_depth=True,
+ preprocess=None,
+ imgnet_normalize=True,
+ **kwargs):
+ # load a single-instance into the memory to speed up training IO
+ # dataset = MultiViewObjverseDataset(file_path,
+ dataset = NovelViewObjverseDataset(file_path,
+ reso,
+ reso_encoder,
+ preprocess=preprocess,
+ load_depth=True,
+ test=False,
+ overfitting=True,
+ imgnet_normalize=imgnet_normalize,
+ overfitting_bs=batch_size,
+ **kwargs)
+ logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset)))
+ # train_sampler = DistributedSampler(dataset=dataset)
+ loader = DataLoader(
+ dataset,
+ batch_size=len(dataset),
+ num_workers=num_workers,
+ drop_last=False,
+ shuffle=False,
+ )
+
+ all_data: dict = next(
+ iter(loader)
+ ) # torchvision.utils.save_image(all_data['img'], 'gt.jpg', normalize=True, value_range=(-1,1))
+ if kwargs.get('gs_cam_format', False): # gs rendering pipeline
+ # ! load V=4 images for training in a batch.
+ while True:
+ # indices = torch.randperm(len(dataset))[:4]
+ indices = torch.randperm(
+ len(dataset))[:batch_size] # all instances
+ # indices2 = torch.randperm(len(dataset))[:] # all instances
+
+ batch_c = collections.defaultdict(dict)
+ for k in ['c', 'nv_c']:
+ for k_c, v_c in all_data[k].items():
+ batch_c[k][k_c] = torch.index_select(
+ v_c, dim=0, index=indices).reshape(
+ batch_size //
+ 4, 4, *v_c.shape[1:]).float() if isinstance(
+ v_c, torch.Tensor) else v_c.float() # float
+
+ batch_c['c']['tanfov'] = batch_c['c']['tanfov'][0][0].item()
+ batch_c['nv_c']['tanfov'] = batch_c['nv_c']['tanfov'][0][0].item()
+
+ batch_data = {}
+ for k, v in all_data.items():
+ if k not in ['c', 'nv_c']:
+ batch_data[k] = torch.index_select(
+ v, dim=0, index=indices).float() if isinstance(
+ v, torch.Tensor) else v # float
+
+ yield {
+ **batch_data,
+ **batch_c,
+ }
+
+ else:
+ while True:
+ start_idx = np.random.randint(0, len(dataset) - batch_size + 1)
+ yield {
+ k: v[start_idx:start_idx + batch_size]
+ for k, v in all_data.items()
+ }
+
+
+def read_dnormal(normald_path, cond_pos, h=None, w=None):
+ cond_cam_dis = np.linalg.norm(cond_pos, 2)
+
+ near = 0.867 #sqrt(3) * 0.5
+ near_distance = cond_cam_dis - near
+
+ normald = cv2.imread(normald_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
+ depth = normald[..., 3:]
+
+ depth[depth < near_distance] = 0
+
+ if h is not None:
+ assert w is not None
+ depth = cv2.resize(depth, (h, w)) # 512,512, 1 -> self.reso, self.reso
+
+ else:
+ depth = depth[..., 0]
+
+ return torch.from_numpy(depth).float()
+
+
+def get_intri(target_im=None, h=None, w=None, normalize=False):
+ if target_im is None:
+ assert (h is not None and w is not None)
+ else:
+ h, w = target_im.shape[:2]
+
+ fx = fy = 1422.222
+ res_raw = 1024
+ f_x = f_y = fx * h / res_raw
+ K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
+ if normalize: # center is [0.5, 0.5], eg3d renderer tradition
+ K[:6] /= h
+ # print("intr: ", K)
+ return K
+
+
+def convert_pose(C2W):
+ # https://github.com/modelscope/richdreamer/blob/c3d9a77fa15fc42dbae12c2d41d64aaec14efd37/dataset/gobjaverse/depth_warp_example.py#L402
+ flip_yz = np.eye(4)
+ flip_yz[1, 1] = -1
+ flip_yz[2, 2] = -1
+ C2W = np.matmul(C2W, flip_yz)
+ return torch.from_numpy(C2W)
+
+
+def read_camera_matrix_single(json_file):
+ with open(json_file, 'r', encoding='utf8') as reader:
+ json_content = json.load(reader)
+ '''
+ # NOTE that different from unity2blender experiments.
+ camera_matrix = np.eye(4)
+ camera_matrix[:3, 0] = np.array(json_content['x'])
+ camera_matrix[:3, 1] = -np.array(json_content['y'])
+ camera_matrix[:3, 2] = -np.array(json_content['z'])
+ camera_matrix[:3, 3] = np.array(json_content['origin'])
+
+
+ '''
+ camera_matrix = np.eye(4) # blender-based
+ camera_matrix[:3, 0] = np.array(json_content['x'])
+ camera_matrix[:3, 1] = np.array(json_content['y'])
+ camera_matrix[:3, 2] = np.array(json_content['z'])
+ camera_matrix[:3, 3] = np.array(json_content['origin'])
+ # print(camera_matrix)
+ # '''
+
+ # return convert_pose(camera_matrix)
+ return camera_matrix
+
+
+def unity2blender(normal):
+ normal_clone = normal.copy()
+ normal_clone[..., 0] = -normal[..., -1]
+ normal_clone[..., 1] = -normal[..., 0]
+ normal_clone[..., 2] = normal[..., 1]
+
+ return normal_clone
+
+
+def blender2midas(img):
+ '''Blender: rub
+ midas: lub
+ '''
+ img[..., 0] = -img[..., 0]
+ img[..., 1] = -img[..., 1]
+ img[..., -1] = -img[..., -1]
+ return img
+
+
+def current_milli_time():
+ return round(time.time() * 1000)
+
+
+# modified from ShapeNet class
+class MultiViewObjverseDataset(Dataset):
+
+ def __init__(
+ self,
+ file_path,
+ reso,
+ reso_encoder,
+ preprocess=None,
+ classes=False,
+ load_depth=False,
+ test=False,
+ scene_scale=1,
+ overfitting=False,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ overfitting_bs=-1,
+ interval=1,
+ plucker_embedding=False,
+ shuffle_across_cls=False,
+ wds_split=1, # 4 splits to accelerate preprocessing
+ four_view_for_latent=False,
+ single_view_for_i23d=False,
+ load_extra_36_view=False,
+ gs_cam_format=False,
+ **kwargs):
+ self.load_extra_36_view = load_extra_36_view
+ # st()
+ self.gs_cam_format = gs_cam_format
+ self.four_view_for_latent = four_view_for_latent # export 0 12 30 36, 4 views for reconstruction
+ self.single_view_for_i23d = single_view_for_i23d
+ self.file_path = file_path
+ self.overfitting = overfitting
+ self.scene_scale = scene_scale
+ self.reso = reso
+ self.reso_encoder = reso_encoder
+ self.classes = False
+ self.load_depth = load_depth
+ self.preprocess = preprocess
+ self.plucker_embedding = plucker_embedding
+ self.intrinsics = get_intri(h=self.reso, w=self.reso,
+ normalize=True).reshape(9)
+
+ assert not self.classes, "Not support class condition now."
+
+ dataset_name = Path(self.file_path).stem.split('_')[0]
+ self.dataset_name = dataset_name
+
+ self.zfar = 100.0
+ self.znear = 0.01
+
+ # if test:
+ # self.ins_list = sorted(os.listdir(self.file_path))[0:1] # the first 1 instance for evaluation reference.
+ # else:
+ # ! TODO, read from list?
+
+ def load_single_cls_instances(file_path):
+ ins_list = [] # the first 1 instance for evaluation reference.
+ for dict_dir in os.listdir(file_path)[:]:
+ for ins_dir in os.listdir(os.path.join(file_path, dict_dir)):
+ # self.ins_list.append(os.path.join(self.file_path, dict_dir, ins_dir,))
+ ins_list.append(
+ os.path.join(file_path, dict_dir, ins_dir,
+ 'campos_512_v4'))
+ return ins_list
+
+ # st()
+ if shuffle_across_cls:
+ self.ins_list = []
+ # for subset in ['Animals', 'Transportations_tar', 'Furnitures']:
+ # for subset in ['Furnitures']:
+ # selected subset for training
+ for subset in [ # ! around 17W instances in total. MVImageNet is the next thing to deal with? Later.
+ # 'daily-used',
+ # 'Food',
+ # 'Plants',
+ # 'Electronics',
+ # 'BuildingsOutdoor',
+ # 'Human-Shape',
+ 'Animals',
+ # 'Transportations_tar',
+ # 'Furnitures',
+ ]: # selected subset for training
+ self.ins_list += load_single_cls_instances(
+ os.path.join(self.file_path, subset))
+ # st()
+ current_time = int(current_milli_time()
+ ) # randomly shuffle given current time
+ random.seed(current_time)
+ random.shuffle(self.ins_list)
+
+ else: # preprocess single class
+ self.ins_list = load_single_cls_instances(self.file_path)
+ self.ins_list = sorted(self.ins_list)
+
+ # if test:
+ # self.ins_list = self.ins_list[0:1]
+
+ if overfitting:
+ self.ins_list = self.ins_list[:1]
+
+ self.rgb_list = []
+ self.pose_list = []
+ self.depth_list = []
+ self.data_ins_list = []
+ self.instance_data_length = -1
+
+ with open(
+ # '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json'
+ '/mnt/yslan/objaverse/richdreamer/dataset/text_captions_cap3d.json',
+ ) as f:
+ self.caption_data = json.load(f)
+
+ self.shuffle_across_cls = shuffle_across_cls
+
+ # for ins in self.ins_list[47000:]:
+ if four_view_for_latent:
+ self.wds_split_all = 1 # ! when dumping latent
+ ins_list_to_process = self.ins_list
+ else:
+ self.wds_split_all = 4
+ # self.wds_split_all = 8 # ! 8 cls in total
+ all_ins_size = len(self.ins_list)
+ ratio_size = all_ins_size // self.wds_split_all + 1
+
+ ins_list_to_process = self.ins_list[ratio_size *
+ (wds_split - 1):ratio_size *
+ wds_split]
+
+ # st()
+ for ins in ins_list_to_process:
+ # ins = os.path.join(
+ # # self.file_path, ins , 'campos_512_v4'
+ # self.file_path, ins ,
+ # # 'compos_512_v4'
+ # )
+ # cur_rgb_path = os.path.join(self.file_path, ins, 'compos_512_v4')
+ # cur_pose_path = os.path.join(self.file_path, ins, 'pose')
+
+ # st()
+ # ][:27])
+
+ if self.four_view_for_latent:
+ # cur_all_fname = [t.split('.')[0] for t in os.listdir(ins)
+ # ] # use full set for training
+ # cur_all_fname = [f'{idx:05d}' for idx in [0, 12, 30, 36]
+ # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24]
+ # cur_all_fname = [f'{idx:05d}' for idx in [7,16,24,25]
+ cur_all_fname = [f'{idx:05d}' for idx in [4,12,20,25]
+ ] # ! four views for inference
+ # cur_all_fname += [f'{idx:05d}' for idx in range(40) if idx not in [0,12,30,36]] # ! four views for inference
+ elif self.single_view_for_i23d:
+ # cur_all_fname = [f'{idx:05d}'
+ # for idx in [16]] # 20 is also fine
+ cur_all_fname = [f'{idx:05d}'
+ for idx in [2]] # ! furniture side view
+
+ else:
+ cur_all_fname = [t.split('.')[0] for t in os.listdir(ins)
+ ] # use full set for training
+
+ if shuffle_across_cls:
+ random.seed(current_time)
+ random.shuffle(cur_all_fname)
+ else:
+ cur_all_fname = sorted(cur_all_fname)
+
+ if self.instance_data_length == -1:
+ self.instance_data_length = len(cur_all_fname)
+ else:
+ try: # data missing?
+ assert len(cur_all_fname) == self.instance_data_length
+ except:
+ # with open('error_log.txt', 'a') as f:
+ # f.write(str(e) + '\n')
+ with open('missing_ins_new2.txt', 'a') as f:
+ f.write(str(Path(ins.parent)) +
+ '\n') # remove the "campos_512_v4"
+ continue
+
+ # if test: # use middle image as the novel view model input
+ # mid_index = len(cur_all_fname) // 3 * 2
+ # cur_all_fname.insert(0, cur_all_fname[mid_index])
+
+ self.pose_list += ([
+ os.path.join(ins, fname, fname + '.json')
+ for fname in cur_all_fname
+ ])
+ self.rgb_list += ([
+ os.path.join(ins, fname, fname + '.png')
+ for fname in cur_all_fname
+ ])
+
+ self.depth_list += ([
+ os.path.join(ins, fname, fname + '_nd.exr')
+ for fname in cur_all_fname
+ ])
+ self.data_ins_list += ([ins] * len(cur_all_fname))
+
+ # check
+
+ # ! setup normalizataion
+ transformations = [
+ transforms.ToTensor(), # [0,1] range
+ ]
+ if imgnet_normalize:
+ transformations.append(
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)) # type: ignore
+ )
+ else:
+ transformations.append(
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))) # type: ignore
+
+ self.normalize = transforms.Compose(transformations)
+
+ def get_source_cw2wT(self, source_cameras_view_to_world):
+ return matrix_to_quaternion(
+ source_cameras_view_to_world[:3, :3].transpose(0, 1))
+
+ def c_to_3dgs_format(self, pose):
+ # TODO, switch to torch version (batched later)
+
+ c2w = pose[:16].reshape(4, 4) # 3x4
+
+ # ! load cam
+ w2c = np.linalg.inv(c2w)
+ R = np.transpose(
+ w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
+ T = w2c[:3, 3]
+ fx = pose[16]
+ FovX = focal2fov(fx, 1)
+ FovY = focal2fov(fx, 1)
+
+ tanfovx = math.tan(FovX * 0.5)
+ tanfovy = math.tan(FovY * 0.5)
+
+ assert tanfovx == tanfovy
+
+ trans = np.array([0.0, 0.0, 0.0])
+ scale = 1.0
+
+ world_view_transform = torch.tensor(getWorld2View2(R, T, trans,
+ scale)).transpose(
+ 0, 1)
+ projection_matrix = getProjectionMatrix(znear=self.znear,
+ zfar=self.zfar,
+ fovX=FovX,
+ fovY=FovY).transpose(0, 1)
+ full_proj_transform = (world_view_transform.unsqueeze(0).bmm(
+ projection_matrix.unsqueeze(0))).squeeze(0)
+ camera_center = world_view_transform.inverse()[3, :3]
+
+ view_world_transform = torch.tensor(getView2World(R, T, trans,
+ scale)).transpose(
+ 0, 1)
+
+ # item.update(viewpoint_cam=[viewpoint_cam])
+ c = {}
+ c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform)
+ c.update(
+ # projection_matrix=projection_matrix, # K
+ cam_view=world_view_transform, # world_view_transform
+ cam_view_proj=full_proj_transform, # full_proj_transform
+ cam_pos=camera_center,
+ tanfov=tanfovx, # TODO, fix in the renderer
+ # orig_c2w=c2w,
+ # orig_w2c=w2c,
+ orig_pose=torch.from_numpy(pose),
+ orig_c2w=torch.from_numpy(c2w),
+ orig_w2c=torch.from_numpy(w2c),
+ # tanfovy=tanfovy,
+ )
+
+ return c # dict for gs rendering
+
+ def __len__(self):
+ return len(self.rgb_list)
+
+ def load_bbox(self, mask):
+ # st()
+ nonzero_value = torch.nonzero(mask)
+ height, width = nonzero_value.max(dim=0)[0]
+ top, left = nonzero_value.min(dim=0)[0]
+ bbox = torch.tensor([top, left, height, width], dtype=torch.float32)
+ return bbox
+
+ def __getitem__(self, idx):
+ # try:
+ data = self._read_data(idx)
+ return data
+ # except Exception as e:
+ # with open('error_log.txt', 'a') as f:
+ # f.write(str(e) + '\n')
+ # with open('error_idx.txt', 'a') as f:
+ # f.write(str(self.data_ins_list[idx]) + '\n')
+ # print(e, flush=True)
+ # return {}
+
+ def gen_rays(self, c2w):
+ # Generate rays
+ self.h = self.reso_encoder
+ self.w = self.reso_encoder
+ yy, xx = torch.meshgrid(
+ torch.arange(self.h, dtype=torch.float32) + 0.5,
+ torch.arange(self.w, dtype=torch.float32) + 0.5,
+ indexing='ij')
+
+ # normalize to 0-1 pixel range
+ yy = yy / self.h
+ xx = xx / self.w
+
+ # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
+ cx, cy, fx, fy = self.intrinsics[2], self.intrinsics[
+ 5], self.intrinsics[0], self.intrinsics[4]
+ # cx *= self.w
+ # cy *= self.h
+
+ # f_x = f_y = fx * h / res_raw
+ c2w = torch.from_numpy(c2w).float()
+
+ xx = (xx - cx) / fx
+ yy = (yy - cy) / fy
+ zz = torch.ones_like(xx)
+ dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
+ dirs /= torch.norm(dirs, dim=-1, keepdim=True)
+ dirs = dirs.reshape(-1, 3, 1)
+ del xx, yy, zz
+ # st()
+ dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
+
+ origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
+ origins = origins.view(self.h, self.w, 3)
+ dirs = dirs.view(self.h, self.w, 3)
+
+ return origins, dirs
+
+ def _read_data(self, idx):
+ rgb_fname = self.rgb_list[idx]
+ pose_fname = self.pose_list[idx]
+
+ raw_img = imageio.imread(rgb_fname)
+
+ # ! RGBD
+ alpha_mask = raw_img[..., -1:] / 255
+ raw_img = alpha_mask * raw_img[..., :3] + (
+ 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
+
+ raw_img = raw_img.astype(
+ np.uint8) # otherwise, float64 won't call ToTensor()
+
+ # return raw_img
+ # st()
+
+ if self.preprocess is None:
+ img_to_encoder = cv2.resize(raw_img,
+ (self.reso_encoder, self.reso_encoder),
+ interpolation=cv2.INTER_LANCZOS4)
+ # interpolation=cv2.INTER_AREA)
+ img_to_encoder = img_to_encoder[
+ ..., :3] #[3, reso_encoder, reso_encoder]
+ img_to_encoder = self.normalize(img_to_encoder)
+ else:
+ img_to_encoder = self.preprocess(Image.open(rgb_fname)) # clip
+
+ # return img_to_encoder
+
+ img = cv2.resize(raw_img, (self.reso, self.reso),
+ interpolation=cv2.INTER_LANCZOS4)
+
+ # interpolation=cv2.INTER_AREA)
+
+ # img_sr = cv2.resize(raw_img, (512, 512), interpolation=cv2.INTER_AREA)
+ # img_sr = cv2.resize(raw_img, (256, 256), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution
+ # img_sr = cv2.resize(raw_img, (128, 128), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution
+
+ # img_sr = cv2.resize(
+ # raw_img, (128, 128), interpolation=cv2.INTER_LANCZOS4
+ # ) # just as refinement, since eg3d uses 64->128 final resolution
+
+ # img = torch.from_numpy(img)[..., :3].permute(
+ # 2, 0, 1) / 255.0 #[3, reso, reso]
+
+ img = torch.from_numpy(img)[..., :3].permute(
+ 2, 0, 1
+ ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
+
+ # img_sr = torch.from_numpy(img_sr)[..., :3].permute(
+ # 2, 0, 1
+ # ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
+
+ c2w = read_camera_matrix_single(pose_fname) #[1, 4, 4] -> [1, 16]
+ # c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed.
+
+ # return c2w
+
+ # if self.load_depth:
+ # depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx],
+ # try:
+ depth = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso,
+ self.reso)
+ # return depth
+ # except:
+ # # print(self.depth_list[idx])
+ # raise NotImplementedError(self.depth_list[idx])
+ # if depth
+
+ # try:
+ bbox = self.load_bbox(depth > 0)
+ # except:
+ # print(rgb_fname)
+ # return {}
+ # st()
+
+ # plucker
+ rays_o, rays_d = self.gen_rays(c2w)
+ rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
+ dim=-1) # [h, w, 6]
+
+ img_to_encoder = torch.cat(
+ [img_to_encoder, rays_plucker.permute(2, 0, 1)],
+ 0).float() # concat in C dim
+
+ # ! add depth as input
+
+ normalized_depth = read_dnormal(self.depth_list[idx], c2w[:3, 3:],
+ self.reso_encoder,
+ self.reso_encoder).unsqueeze(0)
+ # normalized_depth = depth.unsqueeze(0) # min=0
+ img_to_encoder = torch.cat([img_to_encoder, normalized_depth],
+ 0) # concat in C dim
+
+ c = np.concatenate([c2w.reshape(16), self.intrinsics],
+ axis=0).reshape(25).astype(
+ np.float32) # 25, no '1' dim needed.
+
+ if self.gs_cam_format:
+ c = self.c_to_3dgs_format(c)
+ else:
+ c = torch.from_numpy(c)
+
+ ret_dict = {
+ # 'rgb_fname': rgb_fname,
+ 'img_to_encoder': img_to_encoder,
+ 'img': img,
+ 'c': c,
+ # 'img_sr': img_sr,
+ # 'ins_name': self.data_ins_list[idx]
+ }
+
+ ins = str(
+ (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent)
+ if self.shuffle_across_cls:
+ caption = self.caption_data['/'.join(ins.split('/')[1:])]
+ else:
+ caption = self.caption_data[ins]
+
+ ret_dict.update({
+ 'depth': depth,
+ 'depth_mask': depth > 0,
+ # 'depth_mask_sr': depth_mask_sr,
+ 'bbox': bbox,
+ 'caption': caption,
+ 'rays_plucker': rays_plucker, # cam embedding used in lgm
+ 'ins': ins, # placeholder
+ })
+
+ return ret_dict
+
+
+class RealDataset(Dataset):
+
+ def __init__(
+ self,
+ file_path,
+ reso,
+ reso_encoder,
+ preprocess=None,
+ classes=False,
+ load_depth=False,
+ test=False,
+ scene_scale=1,
+ overfitting=False,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ overfitting_bs=-1,
+ interval=1,
+ plucker_embedding=False,
+ shuffle_across_cls=False,
+ wds_split=1, # 4 splits to accelerate preprocessing
+ ) -> None:
+ super().__init__()
+
+ self.file_path = file_path
+ self.overfitting = overfitting
+ self.scene_scale = scene_scale
+ self.reso = reso
+ self.reso_encoder = reso_encoder
+ self.classes = False
+ self.load_depth = load_depth
+ self.preprocess = preprocess
+ self.plucker_embedding = plucker_embedding
+
+ self.rgb_list = []
+
+ all_fname = [
+ t for t in os.listdir(self.file_path)
+ if t.split('.')[1] in ['png', 'jpg']
+ ]
+ self.rgb_list += ([
+ os.path.join(self.file_path, fname) for fname in all_fname
+ ])
+ # if len(self.rgb_list) == 1:
+ # # placeholder
+ # self.rgb_list = self.rgb_list * 40
+
+ # ! setup normalizataion
+ transformations = [
+ transforms.ToTensor(), # [0,1] range
+ ]
+
+ assert imgnet_normalize
+ if imgnet_normalize:
+ transformations.append(
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)) # type: ignore
+ )
+ else:
+ transformations.append(
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))) # type: ignore
+
+ self.normalize = transforms.Compose(transformations)
+ camera = torch.load('eval_pose.pt', map_location='cpu')
+ self.eval_camera = camera
+
+ # pre-cache
+ self.calc_rays_plucker()
+
+ def gen_rays(self, c):
+ # Generate rays
+ intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
+ self.h = self.reso_encoder
+ self.w = self.reso_encoder
+ yy, xx = torch.meshgrid(
+ torch.arange(self.h, dtype=torch.float32) + 0.5,
+ torch.arange(self.w, dtype=torch.float32) + 0.5,
+ indexing='ij')
+
+ # normalize to 0-1 pixel range
+ yy = yy / self.h
+ xx = xx / self.w
+
+ # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
+ cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[
+ 0], intrinsics[4]
+ # cx *= self.w
+ # cy *= self.h
+
+ # f_x = f_y = fx * h / res_raw
+ if not isinstance(c2w, torch.Tensor):
+ c2w = torch.from_numpy(c2w)
+
+ c2w = c2w.float()
+
+ xx = (xx - cx) / fx
+ yy = (yy - cy) / fy
+ zz = torch.ones_like(xx)
+ dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
+ dirs /= torch.norm(dirs, dim=-1, keepdim=True)
+ dirs = dirs.reshape(-1, 3, 1)
+ del xx, yy, zz
+ # st()
+ dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
+
+ origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
+ origins = origins.view(self.h, self.w, 3)
+ dirs = dirs.view(self.h, self.w, 3)
+
+ return origins, dirs
+
+ def calc_rays_plucker(self):
+ all_rays_plucker = []
+
+ for c2w in self.eval_camera:
+ rays_o, rays_d = self.gen_rays(c2w)
+ rays_plucker = torch.cat(
+ [torch.cross(rays_o, rays_d, dim=-1), rays_d],
+ dim=-1) # [h, w, 6]
+ all_rays_plucker.append(rays_plucker)
+
+ self.all_rays_plucker = torch.stack(all_rays_plucker,
+ 0).permute(0, 3, 1, 2) # B 6 H W
+
+ # st()
+ pass
+
+ def __len__(self):
+ return len(self.rgb_list)
+
+ def __getitem__(self, index) -> Any:
+ # return super().__getitem__(index)
+
+ rgb_fname = self.rgb_list[index]
+ # ! preprocess, normalize
+
+ raw_img = imageio.imread(rgb_fname)
+
+ # interpolation=cv2.INTER_AREA)
+ if raw_img.shape[-1] == 4:
+ alpha_mask = raw_img[..., 3:4] / 255.0
+ bg_white = np.ones_like(alpha_mask) * 255.0
+ raw_img = raw_img[..., :3] * alpha_mask + (
+ 1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder]
+ raw_img = raw_img.astype(np.uint8)
+
+ img_to_encoder = cv2.resize(raw_img,
+ (self.reso_encoder, self.reso_encoder),
+ interpolation=cv2.INTER_LANCZOS4)
+
+ # img_to_encoder = img_to_encoder
+ img_to_encoder = self.normalize(img_to_encoder)
+
+ # ! concat plucker
+ img_to_encoder = torch.cat(
+ [img_to_encoder, self.all_rays_plucker[index]],
+ 0) # concat in C dim
+
+ # log gt
+ img = cv2.resize(raw_img, (self.reso, self.reso),
+ interpolation=cv2.INTER_LANCZOS4)
+
+ img = torch.from_numpy(img)[..., :3].permute(
+ 2, 0, 1
+ ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
+
+ ret_dict = {
+ # 'rgb_fname': rgb_fname,
+ 'img_to_encoder':
+ img_to_encoder.unsqueeze(0).repeat_interleave(40, 0),
+ 'img': img.unsqueeze(0).repeat_interleave(40, 0),
+ 'c': self.eval_camera, # TODO, get pre-calculated samples
+ 'ins': 'placeholder',
+ 'bbox': 'placeholder',
+ 'caption': 'placeholder',
+ }
+
+ # ! repeat as a intance
+
+ return ret_dict
+
+
+class NovelViewObjverseDataset(MultiViewObjverseDataset):
+ """novel view prediction version.
+ """
+
+ def __init__(self,
+ file_path,
+ reso,
+ reso_encoder,
+ preprocess=None,
+ classes=False,
+ load_depth=False,
+ test=False,
+ scene_scale=1,
+ overfitting=False,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ overfitting_bs=-1,
+ **kwargs):
+ super().__init__(file_path, reso, reso_encoder, preprocess, classes,
+ load_depth, test, scene_scale, overfitting,
+ imgnet_normalize, dataset_size, overfitting_bs,
+ **kwargs)
+
+ def __getitem__(self, idx):
+ input_view = super().__getitem__(
+ idx) # get previous input view results
+
+ # get novel view of the same instance
+ novel_view = super().__getitem__(
+ (idx // self.instance_data_length) * self.instance_data_length +
+ random.randint(0, self.instance_data_length - 1))
+
+ # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance'
+
+ input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
+ return input_view
+
+
+class MultiViewObjverseDatasetforLMDB(MultiViewObjverseDataset):
+
+ def __init__(
+ self,
+ file_path,
+ reso,
+ reso_encoder,
+ preprocess=None,
+ classes=False,
+ load_depth=False,
+ test=False,
+ scene_scale=1,
+ overfitting=False,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ overfitting_bs=-1,
+ shuffle_across_cls=False,
+ wds_split=1,
+ four_view_for_latent=False,
+ ):
+ super().__init__(file_path,
+ reso,
+ reso_encoder,
+ preprocess,
+ classes,
+ load_depth,
+ test,
+ scene_scale,
+ overfitting,
+ imgnet_normalize,
+ dataset_size,
+ overfitting_bs,
+ shuffle_across_cls=shuffle_across_cls,
+ wds_split=wds_split,
+ four_view_for_latent=four_view_for_latent)
+
+ assert self.reso == 256
+
+ with open(
+ '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json'
+ ) as f:
+ self.caption_data = json.load(f)
+ lmdb_path = '/cpfs01/user/yangpeiqing.p/yslan/data/Furnitures_uncompressed/'
+
+ # with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f:
+ # self.idx_to_ins_mapping = json.load(f)
+
+ def __len__(self):
+ return super().__len__()
+ # return 100 # for speed debug
+
+ def __getitem__(self, idx):
+ # ret_dict = super().__getitem__(idx)
+ rgb_fname = self.rgb_list[idx]
+ pose_fname = self.pose_list[idx]
+ raw_img = imageio.imread(rgb_fname) # [..., :3]
+
+ # assert raw_img.shape[-1] == 4
+
+ if raw_img.shape[-1] == 4: # ! set bg to white
+ alpha_mask = raw_img[..., -1:] / 255 # [0,1]
+ raw_img = alpha_mask * raw_img[..., :3] + (
+ 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
+ raw_img = raw_img.astype(np.uint8)
+
+ raw_img = cv2.resize(raw_img, (self.reso, self.reso),
+ interpolation=cv2.INTER_LANCZOS4)
+
+ c2w = read_camera_matrix_single(pose_fname) #[1, 4, 4] -> [1, 16]
+ c = np.concatenate([c2w.reshape(16), self.intrinsics],
+ axis=0).reshape(25).astype(
+ np.float32) # 25, no '1' dim needed.
+ c = torch.from_numpy(c)
+ # c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed.
+
+ # if self.load_depth:
+ # depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx],
+ # try:
+ depth = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso,
+ self.reso)
+ # except:
+ # # print(self.depth_list[idx])
+ # raise NotImplementedError(self.depth_list[idx])
+ # if depth
+
+ # try:
+ bbox = self.load_bbox(depth > 0)
+
+ ins = str(
+ (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent)
+ if self.shuffle_across_cls:
+ caption = self.caption_data['/'.join(ins.split('/')[1:])]
+ else:
+ caption = self.caption_data[ins]
+
+ ret_dict = {
+ 'raw_img': raw_img,
+ 'c': c,
+ 'depth': depth,
+ # 'depth_mask': depth_mask, # 64x64 here?
+ 'bbox': bbox,
+ 'ins': ins,
+ 'caption': caption,
+ # 'fname': rgb_fname,
+ }
+ return ret_dict
+
+
+class Objv_LMDBDataset_MV_Compressed(LMDBDataset_MV_Compressed):
+
+ def __init__(self,
+ lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ test=False,
+ **kwargs):
+ super().__init__(lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize,
+ dataset_size=dataset_size,
+ **kwargs)
+ self.instance_data_length = 40 # ! could save some key attributes in LMDB
+ if test:
+ self.length = self.instance_data_length
+ elif dataset_size > 0:
+ self.length = dataset_size * self.instance_data_length
+
+ # load caption data, and idx-to-ins mapping
+ with open(
+ '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json'
+ ) as f:
+ self.caption_data = json.load(f)
+ with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f:
+ self.idx_to_ins_mapping = json.load(f)
+
+ def _load_data(self, idx):
+ # '''
+ raw_img, depth, c, bbox = self._load_lmdb_data(idx)
+ # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
+
+ # resize depth and bbox
+ caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]]
+
+ return {
+ **self._post_process_sample(raw_img, depth),
+ 'c': c,
+ 'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8),
+ # 'bbox': (bbox*(self.reso/256.0)).astype(np.uint8), # TODO, double check 512 in wds?
+ 'caption': caption
+ }
+ # '''
+ # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
+ # st()
+ # return {}
+
+ def __getitem__(self, idx):
+ return self._load_data(idx)
+
+
+class Objv_LMDBDataset_MV_NoCompressed(Objv_LMDBDataset_MV_Compressed):
+
+ def __init__(self,
+ lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ test=False,
+ **kwargs):
+ super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
+ dataset_size, test, **kwargs)
+
+ def _load_data(self, idx):
+ # '''
+ raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
+
+ # resize depth and bbox
+ caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]]
+
+ return {
+ **self._post_process_sample(raw_img, depth), 'c': c,
+ 'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8),
+ 'caption': caption
+ }
+ return {}
+
+
+class Objv_LMDBDataset_NV_NoCompressed(Objv_LMDBDataset_MV_NoCompressed):
+
+ def __init__(self,
+ lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ test=False,
+ **kwargs):
+ super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
+ dataset_size, test, **kwargs)
+
+ def __getitem__(self, idx):
+ input_view = self._load_data(idx) # get previous input view results
+
+ # get novel view of the same instance
+ try:
+ novel_view = self._load_data(
+ (idx // self.instance_data_length) *
+ self.instance_data_length +
+ random.randint(0, self.instance_data_length - 1))
+ except Exception as e:
+ raise NotImplementedError(idx)
+
+ # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance'
+
+ input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
+ return input_view
+
+
+class Objv_LMDBDataset_MV_Compressed_for_lmdb(LMDBDataset_MV_Compressed):
+
+ def __init__(self,
+ lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ test=False,
+ **kwargs):
+ super().__init__(lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize,
+ dataset_size=dataset_size,
+ **kwargs)
+ self.instance_data_length = 40 # ! could save some key attributes in LMDB
+ if test:
+ self.length = self.instance_data_length
+ elif dataset_size > 0:
+ self.length = dataset_size * self.instance_data_length
+
+ # load caption data, and idx-to-ins mapping
+ with open(
+ '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json'
+ ) as f:
+ self.caption_data = json.load(f)
+ with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f:
+ self.idx_to_ins_mapping = json.load(f)
+
+ # def _load_data(self, idx):
+ # # '''
+ # raw_img, depth, c, bbox = self._load_lmdb_data(idx)
+
+ # # resize depth and bbox
+ # caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]]
+
+ # # st()
+
+ # return {
+ # **self._post_process_sample(raw_img, depth), 'c': c,
+ # 'bbox': (bbox*(self.reso/512.0)).astype(np.uint8),
+ # 'caption': caption
+ # }
+ # # '''
+ # # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
+ # # st()
+ # # return {}
+
+ def load_bbox(self, mask):
+ # st()
+ nonzero_value = torch.nonzero(mask)
+ height, width = nonzero_value.max(dim=0)[0]
+ top, left = nonzero_value.min(dim=0)[0]
+ bbox = torch.tensor([top, left, height, width], dtype=torch.float32)
+ return bbox
+
+ def __getitem__(self, idx):
+ raw_img, depth, c, bbox = self._load_lmdb_data(idx)
+ return {'raw_img': raw_img, 'depth': depth, 'c': c, 'bbox': bbox}
+
+
+class Objv_LMDBDataset_NV_Compressed(Objv_LMDBDataset_MV_Compressed):
+
+ def __init__(self,
+ lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ **kwargs):
+ super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
+ dataset_size, **kwargs)
+
+ def __getitem__(self, idx):
+ input_view = self._load_data(idx) # get previous input view results
+
+ # get novel view of the same instance
+ try:
+ novel_view = self._load_data(
+ (idx // self.instance_data_length) *
+ self.instance_data_length +
+ random.randint(0, self.instance_data_length - 1))
+ except Exception as e:
+ raise NotImplementedError(idx)
+
+ # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance'
+
+ input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
+ return input_view
+
+
+#
+
+
+# test tar loading
+def load_wds_ResampledShard(file_path,
+ batch_size,
+ num_workers,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ plucker_embedding=False,
+ decode_encode_img_only=False,
+ load_instance=False,
+ mv_input=False,
+ split_chunk_input=False,
+ duplicate_sample=True,
+ append_depth=False,
+ gs_cam_format=False,
+ orthog_duplicate=False,
+ **kwargs):
+
+ # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd']
+ class PostProcess:
+
+ def __init__(
+ self,
+ reso,
+ reso_encoder,
+ imgnet_normalize,
+ plucker_embedding,
+ decode_encode_img_only,
+ mv_input,
+ split_chunk_input,
+ duplicate_sample,
+ append_depth,
+ gs_cam_format,
+ orthog_duplicate,
+ ) -> None:
+ self.gs_cam_format = gs_cam_format
+ self.append_depth = append_depth
+ self.plucker_embedding = plucker_embedding
+ self.decode_encode_img_only = decode_encode_img_only
+ self.duplicate_sample = duplicate_sample
+ self.orthog_duplicate = orthog_duplicate
+
+ self.zfar = 100.0
+ self.znear = 0.01
+
+ transformations = []
+ if not split_chunk_input:
+ transformations.append(transforms.ToTensor())
+
+ if imgnet_normalize:
+ transformations.append(
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)) # type: ignore
+ )
+ else:
+ transformations.append(
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))) # type: ignore
+
+ self.normalize = transforms.Compose(transformations)
+
+ self.reso_encoder = reso_encoder
+ self.reso = reso
+ self.instance_data_length = 40
+ # self.pair_per_instance = 1 # compat
+ self.mv_input = mv_input
+ self.split_chunk_input = split_chunk_input # 8
+ self.chunk_size = 8 if split_chunk_input else 40
+ # st()
+ if split_chunk_input:
+ self.pair_per_instance = 1
+ else:
+ self.pair_per_instance = 4 if mv_input else 2 # check whether improves IO
+
+ def gen_rays(self, c):
+ # Generate rays
+ intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
+ self.h = self.reso_encoder
+ self.w = self.reso_encoder
+ yy, xx = torch.meshgrid(
+ torch.arange(self.h, dtype=torch.float32) + 0.5,
+ torch.arange(self.w, dtype=torch.float32) + 0.5,
+ indexing='ij')
+
+ # normalize to 0-1 pixel range
+ yy = yy / self.h
+ xx = xx / self.w
+
+ # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
+ cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[
+ 0], intrinsics[4]
+ # cx *= self.w
+ # cy *= self.h
+
+ # f_x = f_y = fx * h / res_raw
+ c2w = torch.from_numpy(c2w).float()
+
+ xx = (xx - cx) / fx
+ yy = (yy - cy) / fy
+ zz = torch.ones_like(xx)
+ dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
+ dirs /= torch.norm(dirs, dim=-1, keepdim=True)
+ dirs = dirs.reshape(-1, 3, 1)
+ del xx, yy, zz
+ # st()
+ dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
+
+ origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
+ origins = origins.view(self.h, self.w, 3)
+ dirs = dirs.view(self.h, self.w, 3)
+
+ return origins, dirs
+
+ def _post_process_batch_sample(
+ self, sample): # sample is an instance batch here
+ caption, ins = sample[-2:]
+ instance_samples = []
+
+ for instance_idx in range(sample[0].shape[0]):
+ instance_samples.append(
+ self._post_process_sample(item[instance_idx]
+ for item in sample[:-2]))
+
+ return (*instance_samples, caption, ins)
+
+ def _post_process_sample(self, data_sample):
+ # raw_img, depth, c, bbox, caption, ins = data_sample
+ raw_img, depth, c, bbox = data_sample
+
+ bbox = (bbox * (self.reso / 256)).astype(
+ np.uint8) # normalize bbox to the reso range
+
+ if raw_img.shape[-2] != self.reso_encoder:
+ img_to_encoder = cv2.resize(
+ raw_img, (self.reso_encoder, self.reso_encoder),
+ interpolation=cv2.INTER_LANCZOS4)
+ else:
+ img_to_encoder = raw_img
+
+ img_to_encoder = self.normalize(img_to_encoder)
+ if self.plucker_embedding:
+ rays_o, rays_d = self.gen_rays(c)
+ rays_plucker = torch.cat(
+ [torch.cross(rays_o, rays_d, dim=-1), rays_d],
+ dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w
+ img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0)
+
+ img = cv2.resize(raw_img, (self.reso, self.reso),
+ interpolation=cv2.INTER_LANCZOS4)
+
+ img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1
+
+ if self.decode_encode_img_only:
+ depth_reso, fg_mask_reso = depth, depth
+ else:
+ depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso)
+
+ # return {
+ # # **sample,
+ # 'img_to_encoder': img_to_encoder,
+ # 'img': img,
+ # 'depth_mask': fg_mask_reso,
+ # # 'img_sr': img_sr,
+ # 'depth': depth_reso,
+ # 'c': c,
+ # 'bbox': bbox,
+ # 'caption': caption,
+ # 'ins': ins
+ # # ! no need to load img_sr for now
+ # }
+ # if len(data_sample) == 4:
+ return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox)
+ # else:
+ # return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox, data_sample[-2], data_sample[-1])
+
+ def _post_process_sample_batch(self, data_sample):
+ # raw_img, depth, c, bbox, caption, ins = data_sample
+ raw_img, depth, c, bbox = data_sample
+
+ bbox = (bbox * (self.reso / 256)).astype(
+ np.uint8) # normalize bbox to the reso range
+
+ assert raw_img.shape[-2] == self.reso_encoder
+ # img_to_encoder = cv2.resize(
+ # raw_img, (self.reso_encoder, self.reso_encoder),
+ # interpolation=cv2.INTER_LANCZOS4)
+ # else:
+ # img_to_encoder = raw_img
+
+ raw_img = torch.from_numpy(raw_img).permute(0, 3, 1,
+ 2) / 255.0 # [0,1]
+ img_to_encoder = self.normalize(raw_img)
+
+ if self.plucker_embedding:
+ rays_plucker = []
+ for idx in range(c.shape[0]):
+ rays_o, rays_d = self.gen_rays(c[idx])
+ rays_plucker.append(
+ torch.cat(
+ [torch.cross(rays_o, rays_d, dim=-1), rays_d],
+ dim=-1).permute(2, 0, 1)) # [h, w, 6] -> 6,h,w
+ rays_plucker = torch.stack(rays_plucker, 0)
+ img_to_encoder = torch.cat([img_to_encoder, rays_plucker],
+ 1) # concat in C dim
+ if self.append_depth:
+ normalized_depth = torch.from_numpy(depth).clone().unsqueeze(
+ 1) # min=0
+ # normalized_depth -= torch.min(normalized_depth) # always 0 here
+ # normalized_depth /= torch.max(normalized_depth)
+ # normalized_depth = normalized_depth.unsqueeze(1) * 2 - 1 # normalize to [-1,1]
+ img_to_encoder = torch.cat([img_to_encoder, normalized_depth],
+ 1) # concat in C dim
+
+ # img = cv2.resize(raw_img, (self.reso, self.reso),
+ # interpolation=cv2.INTER_LANCZOS4)
+
+ # img = torch.from_numpy(raw_img).permute(2, 0, 1) / 127.5 - 1
+ # st()
+ if raw_img.shape[-1] != self.reso:
+ img = torch.nn.functional.interpolate(
+ input=raw_img,
+ size=(self.reso, self.reso),
+ mode='bilinear',
+ align_corners=False,
+ ) * 2 - 1 # [-1,1] range
+ else:
+ img = raw_img * 2 - 1
+
+ if self.decode_encode_img_only:
+ depth_reso, fg_mask_reso = depth, depth
+ else:
+ depth_reso, fg_mask_reso = resize_depth_mask_Tensor(
+ torch.from_numpy(depth), self.reso)
+
+ # if not self.gs_cam_format: # otherwise still playing with np format later
+ c = torch.from_numpy(c)
+
+ return (img_to_encoder, img, fg_mask_reso, depth_reso, c,
+ torch.from_numpy(bbox))
+
+ def rand_sample_idx(self):
+ return random.randint(0, self.instance_data_length - 1)
+
+ def rand_pair(self):
+ return (self.rand_sample_idx() for _ in range(2))
+
+ def paired_post_process(self, sample):
+ # repeat n times?
+ all_inp_list = []
+ all_nv_list = []
+ caption, ins = sample[-2:]
+ # expanded_return = []
+ for _ in range(self.pair_per_instance):
+ cano_idx, nv_idx = self.rand_pair()
+ cano_sample = self._post_process_sample(
+ item[cano_idx] for item in sample[:-2])
+ nv_sample = self._post_process_sample(item[nv_idx]
+ for item in sample[:-2])
+ all_inp_list.extend(cano_sample)
+ all_nv_list.extend(nv_sample)
+ return (*all_inp_list, *all_nv_list, caption, ins)
+ # return [cano_sample, nv_sample, caption, ins]
+ # return (*cano_sample, *nv_sample, caption, ins)
+
+ def get_source_cw2wT(self, source_cameras_view_to_world):
+ return matrix_to_quaternion(
+ source_cameras_view_to_world[:3, :3].transpose(0, 1))
+
+ def c_to_3dgs_format(self, pose):
+ # TODO, switch to torch version (batched later)
+
+ c2w = pose[:16].reshape(4, 4) # 3x4
+
+ # ! load cam
+ w2c = np.linalg.inv(c2w)
+ R = np.transpose(
+ w2c[:3, :3]
+ ) # R is stored transposed due to 'glm' in CUDA code
+ T = w2c[:3, 3]
+ fx = pose[16]
+ FovX = focal2fov(fx, 1)
+ FovY = focal2fov(fx, 1)
+
+ tanfovx = math.tan(FovX * 0.5)
+ tanfovy = math.tan(FovY * 0.5)
+
+ assert tanfovx == tanfovy
+
+ trans = np.array([0.0, 0.0, 0.0])
+ scale = 1.0
+
+ view_world_transform = torch.tensor(
+ getView2World(R, T, trans, scale)).transpose(0, 1)
+
+ world_view_transform = torch.tensor(
+ getWorld2View2(R, T, trans, scale)).transpose(0, 1)
+ projection_matrix = getProjectionMatrix(znear=self.znear,
+ zfar=self.zfar,
+ fovX=FovX,
+ fovY=FovY).transpose(0, 1)
+ full_proj_transform = (world_view_transform.unsqueeze(0).bmm(
+ projection_matrix.unsqueeze(0))).squeeze(0)
+ camera_center = world_view_transform.inverse()[3, :3]
+
+ # item.update(viewpoint_cam=[viewpoint_cam])
+ c = {}
+ #
+ c["source_cv2wT_quat"] = self.get_source_cw2wT(
+ view_world_transform)
+ c.update(
+ # projection_matrix=projection_matrix, # K
+ cam_view=world_view_transform, # world_view_transform
+ cam_view_proj=full_proj_transform, # full_proj_transform
+ cam_pos=camera_center,
+ tanfov=tanfovx, # TODO, fix in the renderer
+ orig_pose=torch.from_numpy(pose),
+ orig_c2w=torch.from_numpy(c2w),
+ orig_w2c=torch.from_numpy(w2c),
+ # tanfovy=tanfovy,
+ )
+
+ return c # dict for gs rendering
+
+ def paired_post_process_chunk(self, sample):
+ # repeat n times?
+ all_inp_list = []
+ all_nv_list = []
+ caption, ins = sample[-2:]
+ assert sample[0].shape[0] == 8 # random chunks
+ # expanded_return = []
+
+ if self.duplicate_sample:
+ processed_sample = self._post_process_sample_batch(
+ item for item in sample[:-2])
+
+ if self.orthog_duplicate:
+ indices = torch.cat([torch.randperm(8),
+ torch.randperm(8)]) # for now
+ else:
+ indices = torch.randperm(8)
+
+ shuffle_processed_sample = []
+
+ for _, item in enumerate(processed_sample):
+ shuffle_processed_sample.append(
+ torch.index_select(item, dim=0, index=indices))
+ processed_sample = shuffle_processed_sample
+
+ if not self.orthog_duplicate:
+ all_inp_list.extend(item[:4] for item in processed_sample)
+ all_nv_list.extend(item[4:] for item in processed_sample)
+ else:
+ all_inp_list.extend(item[:8] for item in processed_sample)
+ all_nv_list.extend(item[8:] for item in processed_sample)
+
+ return (*all_inp_list, *all_nv_list, caption, ins)
+
+ else:
+ processed_sample = self._post_process_sample_batch( # avoid shuffle shorten processing time
+ item[:4] for item in sample[:-2])
+
+ all_inp_list.extend(item for item in processed_sample)
+ all_nv_list.extend(
+ item for item in processed_sample) # ! placeholder
+
+ return (*all_inp_list, *all_nv_list, caption, ins)
+
+ # randomly shuffle 8 views, avoid overfitting
+
+ def single_sample_create_dict(self, sample, prefix=''):
+ # if len(sample) == 1:
+ # sample = sample[0]
+ # assert len(sample) == 6
+ img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
+
+ if self.gs_cam_format:
+ # TODO, can optimize later after model converges
+ B, V, _ = c.shape # B 4 25
+ c = rearrange(c, 'B V C -> (B V) C').cpu().numpy()
+ all_gs_c = [self.c_to_3dgs_format(pose) for pose in c]
+ c = {
+ k:
+ rearrange(torch.stack([gs_c[k] for gs_c in all_gs_c]),
+ '(B V) ... -> B V ...',
+ B=B,
+ V=V) if isinstance(all_gs_c[0][k], torch.Tensor)
+ else all_gs_c[0][k]
+ for k in all_gs_c[0].keys()
+ }
+ # c = collate_gs_c
+
+ return {
+ # **sample,
+ f'{prefix}img_to_encoder': img_to_encoder,
+ f'{prefix}img': img,
+ f'{prefix}depth_mask': fg_mask_reso,
+ f'{prefix}depth': depth_reso,
+ f'{prefix}c': c,
+ f'{prefix}bbox': bbox,
+ }
+
+ def single_instance_sample_create_dict(self, sample, prfix=''):
+ assert len(sample) == 42
+
+ inp_sample_list = [[] for _ in range(6)]
+
+ for item in sample[:40]:
+ for item_idx in range(6):
+ inp_sample_list[item_idx].append(item[0][item_idx])
+
+ inp_sample = self.single_sample_create_dict(
+ (torch.stack(item_list) for item_list in inp_sample_list),
+ prefix='')
+
+ return {
+ **inp_sample, #
+ 'caption': sample[-2],
+ 'ins': sample[-1]
+ }
+
+ def decode_zip(self, sample_pyd, shape=(256, 256)):
+ if isinstance(sample_pyd, tuple):
+ sample_pyd = sample_pyd[0]
+ assert isinstance(sample_pyd, dict)
+
+ raw_img = decompress_and_open_image_gzip(
+ sample_pyd['raw_img'],
+ is_img=True,
+ decompress=True,
+ decompress_fn=lz4.frame.decompress)
+
+ caption = sample_pyd['caption'].decode('utf-8')
+ ins = sample_pyd['ins'].decode('utf-8')
+
+ c = decompress_array(sample_pyd['c'], (
+ self.chunk_size,
+ 25,
+ ),
+ np.float32,
+ decompress=True,
+ decompress_fn=lz4.frame.decompress)
+
+ bbox = decompress_array(
+ sample_pyd['bbox'],
+ (
+ self.chunk_size,
+ 4,
+ ),
+ np.float32,
+ # decompress=False)
+ decompress=True,
+ decompress_fn=lz4.frame.decompress)
+
+ if self.decode_encode_img_only:
+ depth = np.zeros(shape=(self.chunk_size,
+ *shape)) # save loading time
+ else:
+ depth = decompress_array(sample_pyd['depth'],
+ (self.chunk_size, *shape),
+ np.float32,
+ decompress=True,
+ decompress_fn=lz4.frame.decompress)
+
+ # return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c}
+ # return raw_img, depth, c, bbox, caption, ins
+ # return raw_img, bbox, caption, ins
+ # return bbox, caption, ins
+ return raw_img, depth, c, bbox, caption, ins
+ # ! run single-instance pipeline first
+ # return raw_img[0], depth[0], c[0], bbox[0], caption, ins
+
+ def create_dict(self, sample):
+ # sample = [item[0] for item in sample] # wds wrap items in []
+ # st()
+ cano_sample_list = [[] for _ in range(6)]
+ nv_sample_list = [[] for _ in range(6)]
+ # st()
+ # bs = (len(sample)-2) // 6
+ for idx in range(0, self.pair_per_instance):
+
+ cano_sample = sample[6 * idx:6 * (idx + 1)]
+ nv_sample = sample[6 * self.pair_per_instance +
+ 6 * idx:6 * self.pair_per_instance + 6 *
+ (idx + 1)]
+
+ for item_idx in range(6):
+ cano_sample_list[item_idx].append(cano_sample[item_idx])
+ nv_sample_list[item_idx].append(nv_sample[item_idx])
+
+ # ! cycle input/output view for more pairs
+ cano_sample_list[item_idx].append(nv_sample[item_idx])
+ nv_sample_list[item_idx].append(cano_sample[item_idx])
+
+ # if self.split_chunk_input:
+ # cano_sample = self.single_sample_create_dict(
+ # (torch.cat(item_list, 0) for item_list in cano_sample_list),
+ # prefix='')
+ # nv_sample = self.single_sample_create_dict(
+ # (torch.cat(item_list, 0) for item_list in nv_sample_list),
+ # prefix='nv_')
+ # else:
+ cano_sample = self.single_sample_create_dict(
+ (torch.cat(item_list, 0) for item_list in cano_sample_list),
+ prefix='')
+ nv_sample = self.single_sample_create_dict(
+ (torch.cat(item_list, 0) for item_list in nv_sample_list),
+ prefix='nv_')
+
+ return {
+ **cano_sample,
+ **nv_sample, 'caption': sample[-2],
+ 'ins': sample[-1]
+ }
+
+ def prepare_mv_input(self, sample):
+ # sample = [item[0] for item in sample] # wds wrap items in []
+ bs = len(sample['caption']) # number of instances
+ chunk_size = sample['img'].shape[0] // bs
+
+ if self.split_chunk_input:
+ for k, v in sample.items():
+ if isinstance(v, torch.Tensor):
+ sample[k] = rearrange(v,
+ "b f c ... -> (b f) c ...",
+ f=4 if not self.orthog_duplicate
+ else 8).contiguous()
+
+ # img = rearrange(sample['img'], "(b f) c h w -> b f c h w", f=4).contiguous()
+ # gt = rearrange(sample['nv_img'], "(b f) c h w -> b c (f h) w", f=4).contiguous()
+ # img = rearrange(sample['img'], "b f c h w -> b c (f h) w", f=4).contiguous()
+ # gt = rearrange(sample['nv_img'], "b f c h w -> b c (f h) w", f=4).contiguous()
+ # torchvision.utils.save_image(img, 'inp.jpg', normalize=True)
+ # torchvision.utils.save_image(gt, 'nv.jpg', normalize=True)
+
+ # ! shift nv
+ else:
+ for k, v in sample.items():
+ if k not in ['ins', 'caption']:
+
+ rolled_idx = torch.LongTensor(
+ list(
+ itertools.chain.from_iterable(
+ list(range(i, sample['img'].shape[0], bs))
+ for i in range(bs))))
+
+ v = torch.index_select(v, dim=0, index=rolled_idx)
+ sample[k] = v
+
+ # img = sample['img']
+ # gt = sample['nv_img']
+ # torchvision.utils.save_image(img[0], 'inp.jpg', normalize=True)
+ # torchvision.utils.save_image(gt[0], 'nv.jpg', normalize=True)
+
+ for k, v in sample.items():
+ if 'nv' in k:
+ rolled_idx = torch.LongTensor(
+ list(
+ itertools.chain.from_iterable(
+ list(
+ np.roll(
+ np.arange(i * chunk_size, (i + 1) *
+ chunk_size), 4)
+ for i in range(bs)))))
+
+ v = torch.index_select(v, dim=0, index=rolled_idx)
+ sample[k] = v
+
+ # torchvision.utils.save_image(sample['nv_img'], 'nv.png', normalize=True)
+ # torchvision.utils.save_image(sample['img'], 'inp.png', normalize=True)
+
+ return sample
+
+ post_process_cls = PostProcess(
+ reso,
+ reso_encoder,
+ imgnet_normalize=imgnet_normalize,
+ plucker_embedding=plucker_embedding,
+ decode_encode_img_only=decode_encode_img_only,
+ mv_input=mv_input,
+ split_chunk_input=split_chunk_input,
+ duplicate_sample=duplicate_sample,
+ append_depth=append_depth,
+ gs_cam_format=gs_cam_format,
+ orthog_duplicate=orthog_duplicate,
+ )
+
+ # ! add shuffling
+
+ if isinstance(file_path, list): # lst of shard urls
+ all_shards = []
+ for url_path in file_path:
+ all_shards.extend(wds.shardlists.expand_source(url_path))
+ logger.log('all_shards', all_shards)
+ else:
+ all_shards = file_path # to be expanded
+
+ if not load_instance: # during reconstruction training, load pair
+ if not split_chunk_input:
+ dataset = wds.DataPipeline(
+ wds.ResampledShards(all_shards), # url_shard
+ # at this point we have an iterator over all the shards
+ wds.shuffle(50),
+ wds.split_by_worker, # if multi-node
+ wds.tarfile_to_samples(),
+ # add wds.split_by_node here if you are using multiple nodes
+ wds.shuffle(
+ 1000
+ ), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.decode(wds.autodecode.basichandlers), # TODO
+ wds.to_tuple(
+ "sample.pyd"), # extract the pyd from top level dict
+ wds.map(post_process_cls.decode_zip),
+ wds.map(post_process_cls.paired_post_process
+ ), # create input-novelview paired samples
+ # wds.map(post_process_cls._post_process_sample),
+ # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.batched(
+ 16,
+ partial=True,
+ # collation_fn=collate
+ ) # streaming more data at once, and rebatch later
+ )
+
+ else:
+ dataset = wds.DataPipeline(
+ wds.ResampledShards(all_shards), # url_shard
+ # at this point we have an iterator over all the shards
+ wds.shuffle(100),
+ wds.split_by_worker, # if multi-node
+ wds.tarfile_to_samples(),
+ # add wds.split_by_node here if you are using multiple nodes
+ wds.shuffle(
+ # 7500 if not duplicate_sample else 2500
+ # 7500 if not duplicate_sample else 5000
+ # 1000,
+ 250,
+ ), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.decode(wds.autodecode.basichandlers), # TODO
+ wds.to_tuple(
+ "sample.pyd"), # extract the pyd from top level dict
+ wds.map(post_process_cls.decode_zip),
+ wds.map(post_process_cls.paired_post_process_chunk
+ ), # create input-novelview paired samples
+ # wds.map(post_process_cls._post_process_sample),
+ # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.batched(
+ 20,
+ partial=True,
+ # collation_fn=collate
+ ) # streaming more data at once, and rebatch later
+ )
+
+ loader_shard = wds.WebLoader(
+ dataset,
+ num_workers=num_workers,
+ drop_last=False,
+ batch_size=None,
+ shuffle=False,
+ persistent_workers=num_workers
+ # > 0).unbatched().shuffle(1000).batched(batch_size).map(
+ > 0).unbatched().shuffle(250).batched(batch_size).map(
+ post_process_cls.create_dict)
+
+ if mv_input:
+ loader_shard = loader_shard.map(post_process_cls.prepare_mv_input)
+
+ else: # load single instance during test/eval
+ assert batch_size == 1
+
+ dataset = wds.DataPipeline(
+ wds.ResampledShards(all_shards), # url_shard
+ # at this point we have an iterator over all the shards
+ wds.shuffle(50),
+ wds.split_by_worker, # if multi-node
+ wds.tarfile_to_samples(),
+ # add wds.split_by_node here if you are using multiple nodes
+ wds.detshuffle(
+ 100
+ ), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.decode(wds.autodecode.basichandlers), # TODO
+ wds.to_tuple("sample.pyd"), # extract the pyd from top level dict
+ wds.map(post_process_cls.decode_zip),
+ # wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples
+ wds.map(post_process_cls._post_process_batch_sample),
+ # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.batched(
+ 2,
+ partial=True,
+ # collation_fn=collate
+ ) # streaming more data at once, and rebatch later
+ )
+
+ loader_shard = wds.WebLoader(
+ dataset,
+ num_workers=num_workers,
+ drop_last=False,
+ batch_size=None,
+ shuffle=False,
+ persistent_workers=num_workers
+ > 0).unbatched().shuffle(200).batched(batch_size).map(
+ post_process_cls.single_instance_sample_create_dict)
+
+ # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict)
+ # 1000).batched(batch_size).map(post_process_cls.create_dict)
+ # .map(collate)
+ # .map(collate)
+
+ # .batched(batch_size)
+ #
+
+ # .unbatched().shuffle(1000).batched(batch_size).map(post_process)
+ # # https://github.com/webdataset/webdataset/issues/187
+
+ # return next(iter(loader_shard))
+ #return dataset
+ return loader_shard
+
+
+# test tar loading
+def load_wds_diff_ResampledShard(file_path,
+ batch_size,
+ num_workers,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ plucker_embedding=False,
+ decode_encode_img_only=False,
+ mv_latent_dir='',
+ **kwargs):
+
+ # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd']
+ class PostProcess:
+
+ def __init__(
+ self,
+ reso,
+ reso_encoder,
+ imgnet_normalize,
+ plucker_embedding,
+ decode_encode_img_only,
+ mv_latent_dir,
+ ) -> None:
+ self.plucker_embedding = plucker_embedding
+
+ self.mv_latent_dir = mv_latent_dir
+ self.decode_encode_img_only = decode_encode_img_only
+
+ transformations = [
+ transforms.ToTensor(), # [0,1] range
+ ]
+ if imgnet_normalize:
+ transformations.append(
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)) # type: ignore
+ )
+ else:
+ transformations.append(
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))) # type: ignore
+
+ self.normalize = transforms.Compose(transformations)
+
+ self.reso_encoder = reso_encoder
+ self.reso = reso
+ self.instance_data_length = 40
+ # self.pair_per_instance = 1 # compat
+ self.pair_per_instance = 2 # check whether improves IO
+ # self.pair_per_instance = 3 # check whether improves IO
+ # self.pair_per_instance = 4 # check whether improves IO
+
+ def get_rays_kiui(self, c, opengl=True):
+ h, w = self.reso_encoder, self.reso_encoder
+ intrinsics, pose = c[16:], c[:16].reshape(4, 4)
+ # cx, cy, fx, fy = intrinsics[2], intrinsics[5]
+ fx = fy = 525 # pixel space
+ cx = cy = 256 # rendering default K
+ factor = self.reso / (cx * 2) # 128 / 512
+ fx = fx * factor
+ fy = fy * factor
+
+ x, y = torch.meshgrid(
+ torch.arange(w, device=pose.device),
+ torch.arange(h, device=pose.device),
+ indexing="xy",
+ )
+ x = x.flatten()
+ y = y.flatten()
+
+ cx = w * 0.5
+ cy = h * 0.5
+
+ # focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
+
+ camera_dirs = F.pad(
+ torch.stack(
+ [
+ (x - cx + 0.5) / fx,
+ (y - cy + 0.5) / fy * (-1.0 if opengl else 1.0),
+ ],
+ dim=-1,
+ ),
+ (0, 1),
+ value=(-1.0 if opengl else 1.0),
+ ) # [hw, 3]
+
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
+
+ rays_o = rays_o.view(h, w, 3)
+ rays_d = safe_normalize(rays_d).view(h, w, 3)
+
+ return rays_o, rays_d
+
+ def gen_rays(self, c):
+ # Generate rays
+ intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
+ self.h = self.reso_encoder
+ self.w = self.reso_encoder
+ yy, xx = torch.meshgrid(
+ torch.arange(self.h, dtype=torch.float32) + 0.5,
+ torch.arange(self.w, dtype=torch.float32) + 0.5,
+ indexing='ij')
+
+ # normalize to 0-1 pixel range
+ yy = yy / self.h
+ xx = xx / self.w
+
+ # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
+ cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[
+ 0], intrinsics[4]
+ # cx *= self.w
+ # cy *= self.h
+
+ # f_x = f_y = fx * h / res_raw
+ c2w = torch.from_numpy(c2w).float()
+
+ xx = (xx - cx) / fx
+ yy = (yy - cy) / fy
+ zz = torch.ones_like(xx)
+ dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
+ dirs /= torch.norm(dirs, dim=-1, keepdim=True)
+ dirs = dirs.reshape(-1, 3, 1)
+ del xx, yy, zz
+ # st()
+ dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
+
+ origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
+ origins = origins.view(self.h, self.w, 3)
+ dirs = dirs.view(self.h, self.w, 3)
+
+ return origins, dirs
+
+ def _post_process_sample(self, data_sample):
+ # raw_img, depth, c, bbox, caption, ins = data_sample
+ raw_img, c, caption, ins = data_sample
+
+ # bbox = (bbox*(self.reso/256)).astype(np.uint8) # normalize bbox to the reso range
+
+ # if raw_img.shape[-2] != self.reso_encoder:
+ # img_to_encoder = cv2.resize(
+ # raw_img, (self.reso_encoder, self.reso_encoder),
+ # interpolation=cv2.INTER_LANCZOS4)
+ # else:
+ # img_to_encoder = raw_img
+
+ # img_to_encoder = self.normalize(img_to_encoder)
+ # if self.plucker_embedding:
+ # rays_o, rays_d = self.gen_rays(c)
+ # rays_plucker = torch.cat(
+ # [torch.cross(rays_o, rays_d, dim=-1), rays_d],
+ # dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w
+ # img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0)
+
+ # img = cv2.resize(raw_img, (self.reso, self.reso),
+ # interpolation=cv2.INTER_LANCZOS4)
+ img = raw_img # 256x256
+
+ img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1
+
+ # load latent
+
+ latent_path = Path(self.mv_latent_dir, ins, 'latent.npy')
+ latent = np.load(latent_path)
+
+ # return (img_to_encoder, img, c, caption, ins)
+ return (latent, img, c, caption, ins)
+
+ def rand_sample_idx(self):
+ return random.randint(0, self.instance_data_length - 1)
+
+ def rand_pair(self):
+ return (self.rand_sample_idx() for _ in range(2))
+
+ def paired_post_process(self, sample):
+ # repeat n times?
+ all_inp_list = []
+ all_nv_list = []
+ caption, ins = sample[-2:]
+ # expanded_return = []
+ for _ in range(self.pair_per_instance):
+ cano_idx, nv_idx = self.rand_pair()
+ cano_sample = self._post_process_sample(
+ item[cano_idx] for item in sample[:-2])
+ nv_sample = self._post_process_sample(item[nv_idx]
+ for item in sample[:-2])
+ all_inp_list.extend(cano_sample)
+ all_nv_list.extend(nv_sample)
+ return (*all_inp_list, *all_nv_list, caption, ins)
+ # return [cano_sample, nv_sample, caption, ins]
+ # return (*cano_sample, *nv_sample, caption, ins)
+
+ # def single_sample_create_dict(self, sample, prefix=''):
+ # # if len(sample) == 1:
+ # # sample = sample[0]
+ # # assert len(sample) == 6
+ # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
+ # return {
+ # # **sample,
+ # f'{prefix}img_to_encoder': img_to_encoder,
+ # f'{prefix}img': img,
+ # f'{prefix}depth_mask': fg_mask_reso,
+ # f'{prefix}depth': depth_reso,
+ # f'{prefix}c': c,
+ # f'{prefix}bbox': bbox,
+ # }
+
+ def single_sample_create_dict(self, sample, prefix=''):
+ # if len(sample) == 1:
+ # sample = sample[0]
+ # assert len(sample) == 6
+ # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
+ # img_to_encoder, img, c, caption, ins = sample
+ # img, c, caption, ins = sample
+ latent, img, c, caption, ins = sample
+ # load latent
+ return {
+ # **sample,
+ # 'img_to_encoder': img_to_encoder,
+ 'latent': latent,
+ 'img': img,
+ 'c': c,
+ 'caption': caption,
+ 'ins': ins
+ }
+
+ def decode_zip(self, sample_pyd, shape=(256, 256)):
+ if isinstance(sample_pyd, tuple):
+ sample_pyd = sample_pyd[0]
+ assert isinstance(sample_pyd, dict)
+
+ raw_img = decompress_and_open_image_gzip(
+ sample_pyd['raw_img'],
+ is_img=True,
+ decompress=True,
+ decompress_fn=lz4.frame.decompress)
+
+ caption = sample_pyd['caption'].decode('utf-8')
+ ins = sample_pyd['ins'].decode('utf-8')
+
+ c = decompress_array(sample_pyd['c'], (25, ),
+ np.float32,
+ decompress=True,
+ decompress_fn=lz4.frame.decompress)
+
+ # bbox = decompress_array(
+ # sample_pyd['bbox'],
+ # (
+ # 40,
+ # 4,
+ # ),
+ # np.float32,
+ # # decompress=False)
+ # decompress=True,
+ # decompress_fn=lz4.frame.decompress)
+
+ # if self.decode_encode_img_only:
+ # depth = np.zeros(shape=(40, *shape)) # save loading time
+ # else:
+ # depth = decompress_array(sample_pyd['depth'], (40, *shape),
+ # np.float32,
+ # decompress=True,
+ # decompress_fn=lz4.frame.decompress)
+
+ # return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c}
+ # return raw_img, depth, c, bbox, caption, ins
+ # return raw_img, bbox, caption, ins
+ # return bbox, caption, ins
+ return raw_img, c, caption, ins
+ # ! run single-instance pipeline first
+ # return raw_img[0], depth[0], c[0], bbox[0], caption, ins
+
+ def create_dict(self, sample):
+ # sample = [item[0] for item in sample] # wds wrap items in []
+ # cano_sample_list = [[] for _ in range(6)]
+ # nv_sample_list = [[] for _ in range(6)]
+ # for idx in range(0, self.pair_per_instance):
+ # cano_sample = sample[6*idx:6*(idx+1)]
+ # nv_sample = sample[6*self.pair_per_instance+6*idx:6*self.pair_per_instance+6*(idx+1)]
+
+ # for item_idx in range(6):
+ # cano_sample_list[item_idx].append(cano_sample[item_idx])
+ # nv_sample_list[item_idx].append(nv_sample[item_idx])
+
+ # # ! cycle input/output view for more pairs
+ # cano_sample_list[item_idx].append(nv_sample[item_idx])
+ # nv_sample_list[item_idx].append(cano_sample[item_idx])
+
+ cano_sample = self.single_sample_create_dict(sample, prefix='')
+ # nv_sample = self.single_sample_create_dict((torch.cat(item_list) for item_list in nv_sample_list) , prefix='nv_')
+
+ return cano_sample
+ # return {
+ # **cano_sample,
+ # # **nv_sample,
+ # 'caption': sample[-2],
+ # 'ins': sample[-1]
+ # }
+
+ post_process_cls = PostProcess(
+ reso,
+ reso_encoder,
+ imgnet_normalize=imgnet_normalize,
+ plucker_embedding=plucker_embedding,
+ decode_encode_img_only=decode_encode_img_only,
+ mv_latent_dir=mv_latent_dir,
+ )
+
+ if isinstance(file_path, list): # lst of shard urls
+ all_shards = []
+ for url_path in file_path:
+ all_shards.extend(wds.shardlists.expand_source(url_path))
+ logger.log('all_shards', all_shards)
+ else:
+ all_shards = file_path # to be expanded
+
+ dataset = wds.DataPipeline(
+ wds.ResampledShards(all_shards), # url_shard
+ # at this point we have an iterator over all the shards
+ wds.shuffle(50),
+ wds.split_by_worker, # if multi-node
+ wds.tarfile_to_samples(),
+ # add wds.split_by_node here if you are using multiple nodes
+ wds.detshuffle(
+ 15000
+ ), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.decode(wds.autodecode.basichandlers), # TODO
+ wds.to_tuple("sample.pyd"), # extract the pyd from top level dict
+ wds.map(post_process_cls.decode_zip),
+ # wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples
+ wds.map(post_process_cls._post_process_sample),
+ # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.batched(
+ 80,
+ partial=True,
+ # collation_fn=collate
+ ) # streaming more data at once, and rebatch later
+ )
+
+ loader_shard = wds.WebLoader(
+ dataset,
+ num_workers=num_workers,
+ drop_last=False,
+ batch_size=None,
+ shuffle=False,
+ persistent_workers=num_workers
+ > 0).unbatched().shuffle(2500).batched(batch_size).map(
+ post_process_cls.create_dict)
+
+ # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict)
+ # 1000).batched(batch_size).map(post_process_cls.create_dict)
+ # .map(collate)
+ # .map(collate)
+
+ # .batched(batch_size)
+ #
+
+ # .unbatched().shuffle(1000).batched(batch_size).map(post_process)
+ # # https://github.com/webdataset/webdataset/issues/187
+
+ # return next(iter(loader_shard))
+ #return dataset
+ return loader_shard
+
+
+def load_wds_data(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ num_workers=6,
+ plucker_embedding=False,
+ decode_encode_img_only=False,
+ load_wds_diff=False,
+ load_wds_latent=False,
+ load_instance=False, # for evaluation
+ mv_input=False,
+ split_chunk_input=False,
+ duplicate_sample=True,
+ mv_latent_dir='',
+ append_depth=False,
+ gs_cam_format=False,
+ orthog_duplicate=False,
+ **args):
+
+ if load_wds_diff:
+ assert num_workers == 0 # on aliyun, worker=0 performs much much faster
+ wds_loader = load_wds_diff_ResampledShard(
+ file_path,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ reso=reso,
+ reso_encoder=reso_encoder,
+ plucker_embedding=plucker_embedding,
+ decode_encode_img_only=decode_encode_img_only,
+ mv_input=mv_input,
+ split_chunk_input=split_chunk_input,
+ append_depth=append_depth,
+ mv_latent_dir=mv_latent_dir,
+ gs_cam_format=gs_cam_format,
+ orthog_duplicate=orthog_duplicate,
+ )
+ elif load_wds_latent:
+ # for diffusion training, cache latent
+ wds_loader = load_wds_latent_ResampledShard(
+ file_path,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ reso=reso,
+ reso_encoder=reso_encoder,
+ plucker_embedding=plucker_embedding,
+ decode_encode_img_only=decode_encode_img_only,
+ mv_input=mv_input,
+ split_chunk_input=split_chunk_input,
+ )
+
+ # elif load_instance:
+ # wds_loader = load_wds_instance_ResampledShard(
+ # file_path,
+ # batch_size=batch_size,
+ # num_workers=num_workers,
+ # reso=reso,
+ # reso_encoder=reso_encoder,
+ # plucker_embedding=plucker_embedding,
+ # decode_encode_img_only=decode_encode_img_only
+ # )
+
+ else:
+ wds_loader = load_wds_ResampledShard(
+ file_path,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ reso=reso,
+ reso_encoder=reso_encoder,
+ plucker_embedding=plucker_embedding,
+ decode_encode_img_only=decode_encode_img_only,
+ load_instance=load_instance,
+ mv_input=mv_input,
+ split_chunk_input=split_chunk_input,
+ duplicate_sample=duplicate_sample,
+ append_depth=append_depth,
+ gs_cam_format=gs_cam_format,
+ orthog_duplicate=orthog_duplicate,
+ )
+
+ while True:
+ yield from wds_loader
+ # yield from wds_loader
+
+
+# test tar loading
+def load_wds_latent_ResampledShard(file_path,
+ batch_size,
+ num_workers,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ plucker_embedding=False,
+ decode_encode_img_only=False,
+ **kwargs):
+
+ # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd']
+ class PostProcess:
+
+ def __init__(
+ self,
+ reso,
+ reso_encoder,
+ imgnet_normalize,
+ plucker_embedding,
+ decode_encode_img_only,
+ ) -> None:
+ self.plucker_embedding = plucker_embedding
+ self.decode_encode_img_only = decode_encode_img_only
+
+ transformations = [
+ transforms.ToTensor(), # [0,1] range
+ ]
+ if imgnet_normalize:
+ transformations.append(
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)) # type: ignore
+ )
+ else:
+ transformations.append(
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))) # type: ignore
+
+ self.normalize = transforms.Compose(transformations)
+
+ self.reso_encoder = reso_encoder
+ self.reso = reso
+ self.instance_data_length = 40
+ # self.pair_per_instance = 1 # compat
+ self.pair_per_instance = 2 # check whether improves IO
+ # self.pair_per_instance = 3 # check whether improves IO
+ # self.pair_per_instance = 4 # check whether improves IO
+
+ def _post_process_sample(self, data_sample):
+ # raw_img, depth, c, bbox, caption, ins = data_sample
+ raw_img, c, caption, ins = data_sample
+
+ # bbox = (bbox*(self.reso/256)).astype(np.uint8) # normalize bbox to the reso range
+
+ if raw_img.shape[-2] != self.reso_encoder:
+ img_to_encoder = cv2.resize(
+ raw_img, (self.reso_encoder, self.reso_encoder),
+ interpolation=cv2.INTER_LANCZOS4)
+ else:
+ img_to_encoder = raw_img
+
+ img_to_encoder = self.normalize(img_to_encoder)
+ if self.plucker_embedding:
+ rays_o, rays_d = self.gen_rays(c)
+ rays_plucker = torch.cat(
+ [torch.cross(rays_o, rays_d, dim=-1), rays_d],
+ dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w
+ img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0)
+
+ img = cv2.resize(raw_img, (self.reso, self.reso),
+ interpolation=cv2.INTER_LANCZOS4)
+
+ img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1
+
+ return (img_to_encoder, img, c, caption, ins)
+
+ def rand_sample_idx(self):
+ return random.randint(0, self.instance_data_length - 1)
+
+ def rand_pair(self):
+ return (self.rand_sample_idx() for _ in range(2))
+
+ def paired_post_process(self, sample):
+ # repeat n times?
+ all_inp_list = []
+ all_nv_list = []
+ caption, ins = sample[-2:]
+ # expanded_return = []
+ for _ in range(self.pair_per_instance):
+ cano_idx, nv_idx = self.rand_pair()
+ cano_sample = self._post_process_sample(
+ item[cano_idx] for item in sample[:-2])
+ nv_sample = self._post_process_sample(item[nv_idx]
+ for item in sample[:-2])
+ all_inp_list.extend(cano_sample)
+ all_nv_list.extend(nv_sample)
+ return (*all_inp_list, *all_nv_list, caption, ins)
+ # return [cano_sample, nv_sample, caption, ins]
+ # return (*cano_sample, *nv_sample, caption, ins)
+
+ # def single_sample_create_dict(self, sample, prefix=''):
+ # # if len(sample) == 1:
+ # # sample = sample[0]
+ # # assert len(sample) == 6
+ # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
+ # return {
+ # # **sample,
+ # f'{prefix}img_to_encoder': img_to_encoder,
+ # f'{prefix}img': img,
+ # f'{prefix}depth_mask': fg_mask_reso,
+ # f'{prefix}depth': depth_reso,
+ # f'{prefix}c': c,
+ # f'{prefix}bbox': bbox,
+ # }
+
+ def single_sample_create_dict(self, sample, prefix=''):
+ # if len(sample) == 1:
+ # sample = sample[0]
+ # assert len(sample) == 6
+ # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
+ img_to_encoder, img, c, caption, ins = sample
+ return {
+ # **sample,
+ 'img_to_encoder': img_to_encoder,
+ 'img': img,
+ 'c': c,
+ 'caption': caption,
+ 'ins': ins
+ }
+
+ def decode_zip(self, sample_pyd, shape=(256, 256)):
+ if isinstance(sample_pyd, tuple):
+ sample_pyd = sample_pyd[0]
+ assert isinstance(sample_pyd, dict)
+
+ latent = sample_pyd['latent']
+ caption = sample_pyd['caption'].decode('utf-8')
+ c = sample_pyd['c']
+ # img = sample_pyd['img']
+ # st()
+
+ return latent, caption, c
+
+ def create_dict(self, sample):
+
+ return {
+ # **sample,
+ 'latent': sample[0],
+ 'caption': sample[1],
+ 'c': sample[2],
+ }
+
+ post_process_cls = PostProcess(
+ reso,
+ reso_encoder,
+ imgnet_normalize=imgnet_normalize,
+ plucker_embedding=plucker_embedding,
+ decode_encode_img_only=decode_encode_img_only,
+ )
+
+ if isinstance(file_path, list): # lst of shard urls
+ all_shards = []
+ for url_path in file_path:
+ all_shards.extend(wds.shardlists.expand_source(url_path))
+ logger.log('all_shards', all_shards)
+ else:
+ all_shards = file_path # to be expanded
+
+ dataset = wds.DataPipeline(
+ wds.ResampledShards(all_shards), # url_shard
+ # at this point we have an iterator over all the shards
+ wds.shuffle(50),
+ wds.split_by_worker, # if multi-node
+ wds.tarfile_to_samples(),
+ # add wds.split_by_node here if you are using multiple nodes
+ wds.detshuffle(
+ 2500
+ ), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.decode(wds.autodecode.basichandlers), # TODO
+ wds.to_tuple("sample.pyd"), # extract the pyd from top level dict
+ wds.map(post_process_cls.decode_zip),
+ # wds.map(post_process_cls._post_process_sample),
+ # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
+ wds.batched(
+ 150,
+ partial=True,
+ # collation_fn=collate
+ ) # streaming more data at once, and rebatch later
+ )
+
+ loader_shard = wds.WebLoader(
+ dataset,
+ num_workers=num_workers,
+ drop_last=False,
+ batch_size=None,
+ shuffle=False,
+ persistent_workers=num_workers
+ > 0).unbatched().shuffle(1000).batched(batch_size).map(
+ post_process_cls.create_dict)
+
+ # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict)
+ # 1000).batched(batch_size).map(post_process_cls.create_dict)
+ # .map(collate)
+ # .map(collate)
+
+ # .batched(batch_size)
+ #
+
+ # .unbatched().shuffle(1000).batched(batch_size).map(post_process)
+ # # https://github.com/webdataset/webdataset/issues/187
+
+ # return next(iter(loader_shard))
+ #return dataset
+ return loader_shard
diff --git a/datasets/shapenet.py b/datasets/shapenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f2046240084ffb33499c18b17c3c8aa2f2efe61
--- /dev/null
+++ b/datasets/shapenet.py
@@ -0,0 +1,972 @@
+import os
+import torchvision
+import pickle
+from typing import Any
+import lmdb
+import cv2
+import imageio
+import numpy as np
+from PIL import Image
+import Imath
+import OpenEXR
+from pdb import set_trace as st
+from pathlib import Path
+
+from functools import partial
+import io
+import gzip
+import random
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from torch.utils.data.distributed import DistributedSampler
+from pathlib import Path
+
+from guided_diffusion import logger
+
+def load_dataset(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ # shuffle=True,
+ num_workers=6,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ use_lmdb=False,
+ infi_sampler=True
+):
+ # st()
+ # dataset_cls = {
+ # 'input_rec': MultiViewDataset,
+ # 'nv': NovelViewDataset,
+ # }[trainer_name]
+ # st()
+ if use_lmdb:
+ logger.log('using LMDB dataset')
+ # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
+ if 'nv' in trainer_name:
+ dataset_cls = LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ else:
+ dataset_cls = LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ # dataset = dataset_cls(file_path)
+ else:
+ if 'nv' in trainer_name:
+ dataset_cls = NovelViewDataset # 1.5-2iter/s
+ else:
+ dataset_cls = MultiViewDataset
+
+ dataset = dataset_cls(file_path,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize,
+ dataset_size=dataset_size)
+
+ logger.log('dataset_cls: {}, dataset size: {}'.format(
+ trainer_name, len(dataset)))
+
+ loader = DataLoader(dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ pin_memory=True,
+ persistent_workers=num_workers > 0,
+ shuffle=False)
+ return loader
+
+
+def load_data(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ # shuffle=True,
+ num_workers=6,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ use_lmdb=False,
+ infi_sampler=True
+):
+ # st()
+ # dataset_cls = {
+ # 'input_rec': MultiViewDataset,
+ # 'nv': NovelViewDataset,
+ # }[trainer_name]
+ # st()
+ if use_lmdb:
+ logger.log('using LMDB dataset')
+ # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
+ if 'nv' in trainer_name:
+ dataset_cls = LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ else:
+ dataset_cls = LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
+ # dataset = dataset_cls(file_path)
+ else:
+ if 'nv' in trainer_name:
+ dataset_cls = NovelViewDataset # 1.5-2iter/s
+ else:
+ dataset_cls = MultiViewDataset
+
+ dataset = dataset_cls(file_path,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize,
+ dataset_size=dataset_size)
+
+ logger.log('dataset_cls: {}, dataset size: {}'.format(
+ trainer_name, len(dataset)))
+
+ # st()
+
+ if infi_sampler:
+ train_sampler = DistributedSampler(dataset=dataset,
+ shuffle=True,
+ drop_last=True)
+
+ loader = DataLoader(dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=True,
+ pin_memory=True,
+ persistent_workers=num_workers > 0,
+ sampler=train_sampler)
+
+ while True:
+ yield from loader
+
+ else:
+ # loader = DataLoader(dataset,
+ # batch_size=batch_size,
+ # num_workers=num_workers,
+ # drop_last=False,
+ # pin_memory=True,
+ # persistent_workers=num_workers > 0,
+ # shuffle=False)
+ st()
+ return dataset
+
+
+def load_eval_rays(file_path="",
+ reso=64,
+ reso_encoder=224,
+ imgnet_normalize=True):
+ dataset = MultiViewDataset(file_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=imgnet_normalize)
+ pose_list = dataset.single_pose_list
+ ray_list = []
+ for pose_fname in pose_list:
+ # c2w = dataset.get_c2w(pose_fname).reshape(1,4,4) #[1, 4, 4]
+ # rays_o, rays_d = dataset.gen_rays(c2w)
+ # ray_list.append(
+ # [rays_o.unsqueeze(0),
+ # rays_d.unsqueeze(0),
+ # c2w.reshape(-1, 16)])
+
+ c2w = dataset.get_c2w(pose_fname).reshape(16) #[1, 4, 4]
+
+ c = torch.cat([c2w, dataset.intrinsics],
+ dim=0).reshape(25) # 25, no '1' dim needed.
+ ray_list.append(c)
+
+ return ray_list
+
+
+def load_eval_data(file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ num_workers=1,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ interval=1, **kwargs):
+
+ dataset = MultiViewDataset(file_path,
+ reso,
+ reso_encoder,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ test=True,
+ imgnet_normalize=imgnet_normalize,
+ interval=interval)
+ print('eval dataset size: {}'.format(len(dataset)))
+ # train_sampler = DistributedSampler(dataset=dataset)
+ loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ shuffle=False,
+ )
+ # sampler=train_sampler)
+ return loader
+
+
+def load_memory_data(file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ num_workers=1,
+ load_depth=True,
+ preprocess=None,
+ imgnet_normalize=True):
+ # load a single-instance into the memory to speed up training IO
+ dataset = MultiViewDataset(file_path,
+ reso,
+ reso_encoder,
+ preprocess=preprocess,
+ load_depth=True,
+ test=False,
+ overfitting=True,
+ imgnet_normalize=imgnet_normalize,
+ overfitting_bs=batch_size)
+ logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset)))
+ # train_sampler = DistributedSampler(dataset=dataset)
+ loader = DataLoader(
+ dataset,
+ batch_size=len(dataset),
+ num_workers=num_workers,
+ drop_last=False,
+ shuffle=False,
+ )
+
+ all_data: dict = next(iter(loader))
+ while True:
+ start_idx = np.random.randint(0, len(dataset) - batch_size + 1)
+ yield {
+ k: v[start_idx:start_idx + batch_size]
+ for k, v in all_data.items()
+ }
+
+
+class MultiViewDataset(Dataset):
+
+ def __init__(self,
+ file_path,
+ reso,
+ reso_encoder,
+ preprocess=None,
+ classes=False,
+ load_depth=False,
+ test=False,
+ scene_scale=1,
+ overfitting=False,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ overfitting_bs=-1,
+ interval=1):
+ self.file_path = file_path
+ self.overfitting = overfitting
+ self.scene_scale = scene_scale
+ self.reso = reso
+ self.reso_encoder = reso_encoder
+ self.classes = False
+ self.load_depth = load_depth
+ self.preprocess = preprocess
+ assert not self.classes, "Not support class condition now."
+
+ # self.ins_list = os.listdir(self.file_path)
+ # if test: # TODO
+
+ dataset_name = Path(self.file_path).stem.split('_')[0]
+ self.dataset_name = dataset_name
+
+ if test:
+ # ins_list_file = Path(self.file_path).parent / f'{dataset_name}_test_list.txt' # ? in domain
+ if dataset_name == 'chair':
+ self.ins_list = sorted(os.listdir(
+ self.file_path))[1:2] # more diversity
+ else:
+ self.ins_list = sorted(os.listdir(self.file_path))[
+ 0:1] # the first 1 instance for evaluation reference.
+ else:
+ # self.ins_list = sorted(Path(self.file_path).glob('[0-8]*'))
+ # self.ins_list = Path(self.file_path).glob('*')
+ # self.ins_list = list(Path(self.file_path).glob('*'))[:dataset_size]
+
+ # ins_list_file = Path(
+ # self.file_path).parent / f'{dataset_name}s_train_list.txt'
+ # assert ins_list_file.exists(), 'add training list for ShapeNet'
+ # with open(ins_list_file, 'r') as f:
+ # self.ins_list = [name.strip() for name in f.readlines()]
+
+ # if dataset_name == 'chair':
+ ins_list_file = Path(
+ self.file_path).parent / f'{dataset_name}_train_list.txt'
+ # st()
+ assert ins_list_file.exists(), 'add training list for ShapeNet'
+ with open(ins_list_file, 'r') as f:
+ self.ins_list = [name.strip()
+ for name in f.readlines()][:dataset_size]
+ # else:
+ # self.ins_list = Path(self.file_path).glob('*')
+
+ if overfitting:
+ self.ins_list = self.ins_list[:1]
+
+ self.rgb_list = []
+ self.pose_list = []
+ self.depth_list = []
+ self.data_ins_list = []
+ self.instance_data_length = -1
+ for ins in self.ins_list:
+ cur_rgb_path = os.path.join(self.file_path, ins, 'rgb')
+ cur_pose_path = os.path.join(self.file_path, ins, 'pose')
+
+ cur_all_fname = sorted([
+ t.split('.')[0] for t in os.listdir(cur_rgb_path)
+ if 'depth' not in t
+ ][::interval])
+ if self.instance_data_length == -1:
+ self.instance_data_length = len(cur_all_fname)
+ else:
+ assert len(cur_all_fname) == self.instance_data_length
+
+ # ! check filtered data
+ # for idx in range(len(cur_all_fname)):
+ # fname = cur_all_fname[idx]
+ # if not Path(os.path.join(cur_rgb_path, fname + '.png') ).exists():
+ # cur_all_fname.remove(fname)
+
+ # del cur_all_fname[idx]
+
+ if test:
+ mid_index = len(cur_all_fname) // 3 * 2
+ cur_all_fname.insert(0, cur_all_fname[mid_index])
+
+ self.pose_list += ([
+ os.path.join(cur_pose_path, fname + '.txt')
+ for fname in cur_all_fname
+ ])
+ self.rgb_list += ([
+ os.path.join(cur_rgb_path, fname + '.png')
+ for fname in cur_all_fname
+ ])
+
+ self.depth_list += ([
+ os.path.join(cur_rgb_path, fname + '_depth0001.exr')
+ for fname in cur_all_fname
+ ])
+ self.data_ins_list += ([ins] * len(cur_all_fname))
+
+ # validate overfitting on images
+ if overfitting:
+ # bs=9
+ # self.pose_list = self.pose_list[::50//9+1]
+ # self.rgb_list = self.rgb_list[::50//9+1]
+ # self.depth_list = self.depth_list[::50//9+1]
+ # bs=6
+ # self.pose_list = self.pose_list[::50//6+1]
+ # self.rgb_list = self.rgb_list[::50//6+1]
+ # self.depth_list = self.depth_list[::50//6+1]
+ # bs=3
+ assert overfitting_bs != -1
+ # bs=1
+ # self.pose_list = self.pose_list[25:26]
+ # self.rgb_list = self.rgb_list[25:26]
+ # self.depth_list = self.depth_list[25:26]
+
+ # uniform pose sampling
+ self.pose_list = self.pose_list[::50//overfitting_bs+1]
+ self.rgb_list = self.rgb_list[::50//overfitting_bs+1]
+ self.depth_list = self.depth_list[::50//overfitting_bs+1]
+
+ # sequentially sampling pose
+ # self.pose_list = self.pose_list[25:25+overfitting_bs]
+ # self.rgb_list = self.rgb_list[25:25+overfitting_bs]
+ # self.depth_list = self.depth_list[25:25+overfitting_bs]
+
+ # duplicate the same pose
+ # self.pose_list = [self.pose_list[25]] * overfitting_bs
+ # self.rgb_list = [self.rgb_list[25]] * overfitting_bs
+ # self.depth_list = [self.depth_list[25]] * overfitting_bs
+ # self.pose_list = [self.pose_list[28]] * overfitting_bs
+ # self.rgb_list = [self.rgb_list[28]] * overfitting_bs
+ # self.depth_list = [self.depth_list[28]] * overfitting_bs
+
+ self.single_pose_list = [
+ os.path.join(cur_pose_path, fname + '.txt')
+ for fname in cur_all_fname
+ ]
+
+ # st()
+
+ # if imgnet_normalize:
+ transformations = [
+ transforms.ToTensor(), # [0,1] range
+ ]
+ if imgnet_normalize:
+ transformations.append(
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)) # type: ignore
+ )
+ else:
+ transformations.append(
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))) # type: ignore
+
+ self.normalize = transforms.Compose(transformations)
+
+ # self.normalize_normalrange = transforms.Compose([
+ # transforms.ToTensor(),# [0,1] range
+ # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ # ])
+
+ fx = fy = 525
+ cx = cy = 256 # rendering default K
+ factor = self.reso / (cx * 2) # 128 / 512
+ self.fx = fx * factor
+ self.fy = fy * factor
+ self.cx = cx * factor
+ self.cy = cy * factor
+
+ # ! fix scale for triplane ray_sampler(), here we adopt [0,1] uv range, not [0, w] img space range.
+ self.cx /= self.reso # 0.5
+ self.cy /= self.reso # 0.5
+ self.fx /= self.reso
+ self.fy /= self.reso
+
+ intrinsics = np.array([[self.fx, 0, self.cx], [0, self.fy, self.cy],
+ [0, 0, 1]]).reshape(9)
+ # self.intrinsics = torch.from_numpy(intrinsics).float()
+ self.intrinsics = intrinsics
+
+ def __len__(self):
+ return len(self.rgb_list)
+
+ def get_c2w(self, pose_fname):
+ with open(pose_fname, 'r') as f:
+ cam2world = f.readline().strip()
+ cam2world = [float(t) for t in cam2world.split(' ')]
+ c2w = torch.tensor(cam2world, dtype=torch.float32).reshape(4, 4)
+ return c2w
+
+ def gen_rays(self, c2w):
+ # Generate rays
+ self.h = self.reso
+ self.w = self.reso
+ yy, xx = torch.meshgrid(
+ torch.arange(self.h, dtype=torch.float32) + 0.5,
+ torch.arange(self.w, dtype=torch.float32) + 0.5,
+ indexing='ij')
+ xx = (xx - self.cx) / self.fx
+ yy = (yy - self.cy) / self.fy
+ zz = torch.ones_like(xx)
+ dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
+ dirs /= torch.norm(dirs, dim=-1, keepdim=True)
+ dirs = dirs.reshape(1, -1, 3, 1)
+ del xx, yy, zz
+ dirs = (c2w[:, None, :3, :3] @ dirs)[..., 0]
+
+ origins = c2w[:, None, :3, 3].expand(-1, self.h * self.w,
+ -1).contiguous()
+ origins = origins.view(-1, 3)
+ dirs = dirs.view(-1, 3)
+
+ return origins, dirs
+
+ def read_depth(self, idx):
+ depth_path = self.depth_list[idx]
+ # image_path = os.path.join(depth_fname, self.image_names[index])
+ exr = OpenEXR.InputFile(depth_path)
+ header = exr.header()
+ size = (header['dataWindow'].max.x - header['dataWindow'].min.x + 1,
+ header['dataWindow'].max.y - header['dataWindow'].min.y + 1)
+ FLOAT = Imath.PixelType(Imath.PixelType.FLOAT)
+ depth_str = exr.channel('B', FLOAT)
+ depth = np.frombuffer(depth_str,
+ dtype=np.float32).reshape(size[1],
+ size[0]) # H W
+ depth = np.nan_to_num(depth, posinf=0, neginf=0)
+ depth = depth.reshape(size)
+
+ def resize_depth_mask(depth_to_resize, resolution):
+ depth_resized = cv2.resize(depth_to_resize,
+ (resolution, resolution),
+ interpolation=cv2.INTER_LANCZOS4)
+ # interpolation=cv2.INTER_AREA)
+ return depth_resized > 0 # type: ignore
+
+ fg_mask_reso = resize_depth_mask(depth, self.reso)
+ fg_mask_sr = resize_depth_mask(depth, 128)
+
+ # depth = cv2.resize(depth, (self.reso, self.reso),
+ # interpolation=cv2.INTER_LANCZOS4)
+ # interpolation=cv2.INTER_AREA)
+ # depth_mask = depth > 0
+ # depth = np.expand_dims(depth, axis=0).reshape(size)
+ # return torch.from_numpy(depth)
+ return torch.from_numpy(depth), torch.from_numpy(
+ fg_mask_reso), torch.from_numpy(fg_mask_sr)
+
+ def load_bbox(self, mask):
+ nonzero_value = torch.nonzero(mask)
+ height, width = nonzero_value.max(dim=0)[0]
+ top, left = nonzero_value.min(dim=0)[0]
+ bbox = torch.tensor([top, left, height, width], dtype=torch.float32)
+ return bbox
+
+ def __getitem__(self, idx):
+ rgb_fname = self.rgb_list[idx]
+ pose_fname = self.pose_list[idx]
+
+ raw_img = imageio.imread(rgb_fname)
+
+ if self.preprocess is None:
+ img_to_encoder = cv2.resize(raw_img,
+ (self.reso_encoder, self.reso_encoder),
+ interpolation=cv2.INTER_LANCZOS4)
+ # interpolation=cv2.INTER_AREA)
+ img_to_encoder = img_to_encoder[
+ ..., :3] #[3, reso_encoder, reso_encoder]
+ img_to_encoder = self.normalize(img_to_encoder)
+ else:
+ img_to_encoder = self.preprocess(Image.open(rgb_fname)) # clip
+
+ img = cv2.resize(raw_img, (self.reso, self.reso),
+ interpolation=cv2.INTER_LANCZOS4)
+ # interpolation=cv2.INTER_AREA)
+
+ # img_sr = cv2.resize(raw_img, (512, 512), interpolation=cv2.INTER_AREA)
+ # img_sr = cv2.resize(raw_img, (256, 256), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution
+ # img_sr = cv2.resize(raw_img, (128, 128), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution
+ img_sr = cv2.resize(
+ raw_img, (128, 128), interpolation=cv2.INTER_LANCZOS4
+ ) # just as refinement, since eg3d uses 64->128 final resolution
+
+ # img = torch.from_numpy(img)[..., :3].permute(
+ # 2, 0, 1) / 255.0 #[3, reso, reso]
+
+ img = torch.from_numpy(img)[..., :3].permute(
+ 2, 0, 1
+ ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
+
+ img_sr = torch.from_numpy(img_sr)[..., :3].permute(
+ 2, 0, 1
+ ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
+
+ # c2w = self.get_c2w(pose_fname).reshape(1, 4, 4) #[1, 4, 4]
+ # rays_o, rays_d = self.gen_rays(c2w)
+ # return img_to_encoder, img, rays_o, rays_d, c2w.reshape(-1)
+
+ c2w = self.get_c2w(pose_fname).reshape(16) #[1, 4, 4] -> [1, 16]
+ # c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed.
+ c = torch.cat([c2w, torch.from_numpy(self.intrinsics)],
+ dim=0).reshape(25) # 25, no '1' dim needed.
+ ret_dict = {
+ # 'rgb_fname': rgb_fname,
+ 'img_to_encoder': img_to_encoder,
+ 'img': img,
+ 'c': c,
+ 'img_sr': img_sr,
+ # 'ins_name': self.data_ins_list[idx]
+ }
+ if self.load_depth:
+ depth, depth_mask, depth_mask_sr = self.read_depth(idx)
+ bbox = self.load_bbox(depth_mask)
+ ret_dict.update({
+ 'depth': depth,
+ 'depth_mask': depth_mask,
+ 'depth_mask_sr': depth_mask_sr,
+ 'bbox': bbox
+ })
+ # rays_o, rays_d = self.gen_rays(c2w)
+ # return img_to_encoder, img, c
+ return ret_dict
+
+
+class MultiViewDatasetforLMDB(MultiViewDataset):
+
+ def __init__(self,
+ file_path,
+ reso,
+ reso_encoder,
+ preprocess=None,
+ classes=False,
+ load_depth=False,
+ test=False,
+ scene_scale=1,
+ overfitting=False,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ overfitting_bs=-1):
+ super().__init__(file_path, reso, reso_encoder, preprocess, classes,
+ load_depth, test, scene_scale, overfitting,
+ imgnet_normalize, dataset_size, overfitting_bs)
+
+ def __len__(self):
+ return super().__len__()
+ # return 100 # for speed debug
+
+ def __getitem__(self, idx):
+ # ret_dict = super().__getitem__(idx)
+ rgb_fname = self.rgb_list[idx]
+ pose_fname = self.pose_list[idx]
+ raw_img = imageio.imread(rgb_fname)[..., :3]
+
+ c2w = self.get_c2w(pose_fname).reshape(16) #[1, 4, 4] -> [1, 16]
+ # c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed.
+ c = torch.cat([c2w, torch.from_numpy(self.intrinsics)],
+ dim=0).reshape(25) # 25, no '1' dim needed.
+
+ depth, depth_mask, depth_mask_sr = self.read_depth(idx)
+ bbox = self.load_bbox(depth_mask)
+ ret_dict = {
+ 'raw_img': raw_img,
+ 'c': c,
+ 'depth': depth,
+ # 'depth_mask': depth_mask, # 64x64 here?
+ 'bbox': bbox
+ }
+ return ret_dict
+
+
+def load_data_dryrun(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ # shuffle=True,
+ num_workers=6,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True):
+ # st()
+ dataset = MultiViewDataset(file_path,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize)
+ print('dataset size: {}'.format(len(dataset)))
+ # st()
+ # train_sampler = DistributedSampler(dataset=dataset)
+ loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ # shuffle=shuffle,
+ drop_last=False,
+ )
+ # sampler=train_sampler)
+
+ return loader
+
+
+class NovelViewDataset(MultiViewDataset):
+ """novel view prediction version.
+ """
+
+ def __init__(self,
+ file_path,
+ reso,
+ reso_encoder,
+ preprocess=None,
+ classes=False,
+ load_depth=False,
+ test=False,
+ scene_scale=1,
+ overfitting=False,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ overfitting_bs=-1):
+ super().__init__(file_path, reso, reso_encoder, preprocess, classes,
+ load_depth, test, scene_scale, overfitting,
+ imgnet_normalize, dataset_size, overfitting_bs)
+
+ def __getitem__(self, idx):
+ input_view = super().__getitem__(
+ idx) # get previous input view results
+
+ # get novel view of the same instance
+ novel_view = super().__getitem__(
+ (idx // self.instance_data_length) * self.instance_data_length +
+ random.randint(0, self.instance_data_length - 1))
+
+ # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance'
+
+ input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
+ return input_view
+
+
+def load_data_for_lmdb(
+ file_path="",
+ reso=64,
+ reso_encoder=224,
+ batch_size=1,
+ # shuffle=True,
+ num_workers=6,
+ load_depth=False,
+ preprocess=None,
+ imgnet_normalize=True,
+ dataset_size=-1,
+ trainer_name='input_rec'):
+ # st()
+ # dataset_cls = {
+ # 'input_rec': MultiViewDataset,
+ # 'nv': NovelViewDataset,
+ # }[trainer_name]
+ # if 'nv' in trainer_name:
+ # dataset_cls = NovelViewDataset
+ # else:
+ # dataset_cls = MultiViewDataset
+ dataset_cls = MultiViewDatasetforLMDB
+
+ dataset = dataset_cls(file_path,
+ reso,
+ reso_encoder,
+ test=False,
+ preprocess=preprocess,
+ load_depth=load_depth,
+ imgnet_normalize=imgnet_normalize,
+ dataset_size=dataset_size)
+
+ logger.log('dataset_cls: {}, dataset size: {}'.format(
+ trainer_name, len(dataset)))
+ # train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True)
+ loader = DataLoader(
+ dataset,
+ shuffle=False,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ prefetch_factor=2,
+ # prefetch_factor=3,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ # sampler=train_sampler)
+
+ # while True:
+ # yield from loader
+ return loader, dataset.dataset_name, len(dataset)
+
+
+class LMDBDataset(Dataset):
+
+ def __init__(self, lmdb_path):
+ self.env = lmdb.open(
+ lmdb_path,
+ readonly=True,
+ max_readers=32,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ self.num_samples = self.env.stat()['entries']
+ # self.start_idx = self.env.stat()['start_idx']
+ # self.end_idx = self.env.stat()['end_idx']
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ with self.env.begin(write=False) as txn:
+ key = str(idx).encode('utf-8')
+ value = txn.get(key)
+
+ sample = pickle.loads(value)
+ return sample
+
+
+def resize_depth_mask(depth_to_resize, resolution):
+ depth_resized = cv2.resize(depth_to_resize, (resolution, resolution),
+ interpolation=cv2.INTER_LANCZOS4)
+ # interpolation=cv2.INTER_AREA)
+ return depth_resized, depth_resized > 0 # type: ignore
+
+
+class LMDBDataset_MV(LMDBDataset):
+
+ def __init__(self,
+ lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=True,
+ **kwargs):
+ super().__init__(lmdb_path)
+
+ self.reso_encoder = reso_encoder
+ self.reso = reso
+
+ transformations = [
+ transforms.ToTensor(), # [0,1] range
+ ]
+ if imgnet_normalize:
+ transformations.append(
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225)) # type: ignore
+ )
+ else:
+ transformations.append(
+ transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))) # type: ignore
+
+ self.normalize = transforms.Compose(transformations)
+
+ def _post_process_sample(self, raw_img, depth):
+
+ # if raw_img.shape[-1] == 4: # ! set bg to white
+ # alpha_mask = raw_img[..., -1:] > 0
+ # raw_img = alpha_mask * raw_img[..., :3] + (1-alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
+ # raw_img = raw_img.astype(np.uint8)
+
+ # img_to_encoder = cv2.resize(sample.pop('raw_img'),
+ img_to_encoder = cv2.resize(raw_img,
+ (self.reso_encoder, self.reso_encoder),
+ interpolation=cv2.INTER_LANCZOS4)
+ # interpolation=cv2.INTER_AREA)
+ img_to_encoder = img_to_encoder[..., :
+ 3] #[3, reso_encoder, reso_encoder]
+ img_to_encoder = self.normalize(img_to_encoder)
+
+ img = cv2.resize(raw_img, (self.reso, self.reso),
+ interpolation=cv2.INTER_LANCZOS4)
+
+ if img.shape[-1] == 4:
+ alpha_mask = img[..., -1:] > 0
+ img = alpha_mask * img[..., :3] + (1-alpha_mask) * np.ones_like(img[..., :3]) * 255
+
+ img = torch.from_numpy(img)[..., :3].permute(
+ 2, 0, 1
+ ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
+
+ img_sr = torch.from_numpy(raw_img)[..., :3].permute(
+ 2, 0, 1
+ ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
+
+ # depth
+ # fg_mask_reso = resize_depth_mask(sample['depth'], self.reso)
+ depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso)
+
+ return {
+ # **sample,
+ 'img_to_encoder': img_to_encoder,
+ 'img': img,
+ 'depth_mask': fg_mask_reso,
+ 'img_sr': img_sr,
+ 'depth': depth_reso,
+ # ! no need to load img_sr for now
+ }
+
+ def __getitem__(self, idx):
+ sample = super().__getitem__(idx)
+ # do transformations online
+
+ return self._post_process_sample(sample['raw_img'], sample['depth'])
+ # return sample
+
+def load_bytes(inp_bytes, dtype, shape):
+ return np.frombuffer(inp_bytes, dtype=dtype).reshape(shape).copy()
+
+# Function to decompress an image using gzip and open with imageio
+def decompress_and_open_image_gzip(compressed_data, is_img=False):
+ # Decompress the image data using gzip
+ decompressed_data = gzip.decompress(compressed_data)
+
+ # Read the decompressed image using imageio
+ if is_img:
+ image = imageio.v3.imread(io.BytesIO(decompressed_data)).copy()
+ return image
+ return decompressed_data
+
+
+# Function to decompress an array using gzip
+def decompress_array(compressed_data, shape, dtype):
+ # Decompress the array data using gzip
+ decompressed_data = gzip.decompress(compressed_data)
+
+ # Convert the decompressed data to a NumPy array
+ # arr = np.frombuffer(decompressed_data, dtype=dtype).reshape(shape)
+
+ return load_bytes(decompressed_data, dtype, shape)
+
+
+class LMDBDataset_MV_Compressed(LMDBDataset_MV):
+
+ def __init__(self,
+ lmdb_path,
+ reso,
+ reso_encoder,
+ imgnet_normalize=True,
+ **kwargs):
+ super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
+ **kwargs)
+ with self.env.begin(write=False) as txn:
+ self.length = int(
+ txn.get('length'.encode('utf-8')).decode('utf-8')) - 40
+
+ self.load_image_fn = partial(decompress_and_open_image_gzip,
+ is_img=True)
+
+ def __len__(self):
+ return self.length
+
+ def _load_lmdb_data(self, idx):
+
+ with self.env.begin(write=False) as txn:
+ raw_img_key = f'{idx}-raw_img'.encode('utf-8')
+ raw_img = self.load_image_fn(txn.get(raw_img_key))
+
+ depth_key = f'{idx}-depth'.encode('utf-8')
+ depth = decompress_array(txn.get(depth_key), (512,512), np.float32)
+
+ c_key = f'{idx}-c'.encode('utf-8')
+ c = decompress_array(txn.get(c_key), (25, ), np.float32)
+
+ bbox_key = f'{idx}-bbox'.encode('utf-8')
+ bbox = decompress_array(txn.get(bbox_key), (4, ), np.float32)
+
+ return raw_img, depth, c, bbox
+
+ def __getitem__(self, idx):
+ # sample = super(LMDBDataset).__getitem__(idx)
+
+ # do gzip uncompress online
+ raw_img, depth, c, bbox = self._load_lmdb_data(idx)
+
+ return {
+ **self._post_process_sample(raw_img, depth), 'c': c,
+ 'bbox': bbox*(self.reso/64.0),
+ # 'depth': depth,
+ }
+
+
+class LMDBDataset_NV_Compressed(LMDBDataset_MV_Compressed):
+ def __init__(self, lmdb_path, reso, reso_encoder, imgnet_normalize=True, **kwargs):
+ super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, **kwargs)
+ self.instance_data_length = 50 #
+
+ def __getitem__(self, idx):
+ input_view = super().__getitem__(
+ idx) # get previous input view results
+
+ # get novel view of the same instance
+ try:
+ novel_view = super().__getitem__(
+ (idx // self.instance_data_length) * self.instance_data_length +
+ random.randint(0, self.instance_data_length - 1))
+ except Exception as e:
+ raise NotImplementedError(idx)
+
+ assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance'
+
+ input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
+ return input_view
\ No newline at end of file
diff --git a/dit/__init__.py b/dit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a0c6f23f5e81eff58f021f2cf16204772d4e54
--- /dev/null
+++ b/dit/__init__.py
@@ -0,0 +1 @@
+from .dit_3d import DiT_models
\ No newline at end of file
diff --git a/dit/__pycache__/__init__.cpython-39.pyc b/dit/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f24c3d159b798b5cfe6217acb856ec456f8d62c6
Binary files /dev/null and b/dit/__pycache__/__init__.cpython-39.pyc differ
diff --git a/dit/__pycache__/dit_3d.cpython-39.pyc b/dit/__pycache__/dit_3d.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abc32a40477e622801f50f097f9dd744d15b12d1
Binary files /dev/null and b/dit/__pycache__/dit_3d.cpython-39.pyc differ
diff --git a/dit/__pycache__/dit_decoder.cpython-39.pyc b/dit/__pycache__/dit_decoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57eaa3dc471ef3eb7f5a15a1e2ac921f471a2ef2
Binary files /dev/null and b/dit/__pycache__/dit_decoder.cpython-39.pyc differ
diff --git a/dit/__pycache__/dit_models.cpython-39.pyc b/dit/__pycache__/dit_models.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a88ca87524e4a947eec6cc8ad78a6a4cc4277a0
Binary files /dev/null and b/dit/__pycache__/dit_models.cpython-39.pyc differ
diff --git a/dit/__pycache__/dit_models_xformers.cpython-39.pyc b/dit/__pycache__/dit_models_xformers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67c120fd4bc3042a33ec5ca0c34119522623d5f2
Binary files /dev/null and b/dit/__pycache__/dit_models_xformers.cpython-39.pyc differ
diff --git a/dit/dit_3d.py b/dit/dit_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..324099b19f4f83dcfadafd06b5faf20f716252cf
--- /dev/null
+++ b/dit/dit_3d.py
@@ -0,0 +1,212 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+
+from pdb import set_trace as st
+
+from .dit_models import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed
+
+
+class DiT_Triplane_V1(DiT):
+ """
+ 1. merge the 3*H*W as L, and 8 as C only
+ 2. pachify, flat into 224*(224*3) with 8 channels for pachify
+ 3. unpachify accordingly
+ """
+
+ def __init__(self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=False):
+
+ input_size = (input_size, input_size*3)
+ super().__init__(input_size, patch_size, in_channels//3, hidden_size, # type: ignore
+ depth, num_heads, mlp_ratio, class_dropout_prob,
+ num_classes, learn_sigma)
+
+ def initialize_weights(self):
+ """all the same except the PE part
+ """
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1], self.x_embedder.grid_size)
+ # st()
+ self.pos_embed.data.copy_(
+ torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # ! untouched below
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ if self.y_embedder is not None:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ # TODO
+ """
+ x: (N, L, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0] # type: ignore
+ h = w = int((x.shape[1]//3)**0.5)
+ assert h * w * 3 == x.shape[1] # merge triplane 3 dims with hw
+
+ x = x.reshape(shape=(x.shape[0], h, w, 3, p, p, c))
+ x = torch.einsum('nhwzpqc->nczhpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c*3, h * p, h * p)) # type: ignore
+ return imgs # B 8*3 H W
+
+ def forward(self, x, t, y=None):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+
+ # ! merge tri-channel into w chanenl for 3D-aware TX
+ x = x.reshape(x.shape[0], -1, 3, x.shape[2], x.shape[3]) # B 8 3 H W
+ x = x.permute(0,1,3,4,2).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1]*3) # B 8 H W83
+
+ x = self.x_embedder(
+ x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ t = self.t_embedder(t) # (N, D)
+
+ if self.y_embedder is not None:
+ assert y is not None
+ y = self.y_embedder(y, self.training) # (N, D)
+ c = t + y # (N, D)
+ else:
+ c = t
+
+ for block in self.blocks:
+ x = block(x, c) # (N, T, D)
+
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ return x
+
+
+
+
+class DiT_Triplane_V1_learnedPE(DiT_Triplane_V1):
+ """
+ 1. learned PE, default cos/sin wave
+ """
+
+ def __init__(self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True):
+ super().__init__(input_size, patch_size, in_channels, hidden_size,
+ depth, num_heads, mlp_ratio, class_dropout_prob,
+ num_classes, learn_sigma)
+
+
+class DiT_Triplane_V1_fixed3DPE(DiT_Triplane_V1):
+ """
+ 1. 3D aware PE, fixed
+ """
+
+ def __init__(self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True):
+ super().__init__(input_size, patch_size, in_channels, hidden_size,
+ depth, num_heads, mlp_ratio, class_dropout_prob,
+ num_classes, learn_sigma)
+
+
+class DiT_Triplane_V1_learned3DPE(DiT_Triplane_V1):
+ """
+ 1. init with 3D aware PE, learnable
+ """
+
+ def __init__(self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True):
+ super().__init__(input_size, patch_size, in_channels, hidden_size,
+ depth, num_heads, mlp_ratio, class_dropout_prob,
+ num_classes, learn_sigma)
+
+def V1_Triplane_DiT_S_2(**kwargs):
+ return DiT_Triplane_V1(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+def V1_Triplane_DiT_S_4(**kwargs):
+ return DiT_Triplane_V1(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+def V1_Triplane_DiT_S_8(**kwargs):
+ return DiT_Triplane_V1(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+def V1_Triplane_DiT_B_8(**kwargs):
+ return DiT_Triplane_V1(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+def V1_Triplane_DiT_B_16(**kwargs): # ours cfg
+ return DiT_Triplane_V1(depth=12, hidden_size=768, patch_size=16, num_heads=12, **kwargs)
+
+DiT_models.update({
+ 'v1-T-DiT-S/2': V1_Triplane_DiT_S_2,
+ 'v1-T-DiT-S/4': V1_Triplane_DiT_S_4,
+ 'v1-T-DiT-S/8': V1_Triplane_DiT_S_8,
+ 'v1-T-DiT-B/8': V1_Triplane_DiT_B_8,
+ 'v1-T-DiT-B/16': V1_Triplane_DiT_B_16,
+})
\ No newline at end of file
diff --git a/dit/dit_decoder.py b/dit/dit_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2c350ff3e482ddc27b6d9c44d883bbb2cc3dc3
--- /dev/null
+++ b/dit/dit_decoder.py
@@ -0,0 +1,288 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+
+from einops import rearrange
+from pdb import set_trace as st
+
+# from .dit_models import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer
+
+from .dit_models_xformers import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer
+# from .dit_models import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer
+
+
+def modulate2(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+class DiTBlock2(DiTBlock):
+ """
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4, **block_kwargs):
+ super().__init__(hidden_size, num_heads, mlp_ratio, **block_kwargs)
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+ c).chunk(6, dim=-1)
+ # st()
+ x = x + gate_msa * self.attn(
+ modulate2(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp * self.mlp(
+ modulate2(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FinalLayer2(FinalLayer):
+ """
+ The final layer of DiT, basically the decoder_pred in MAE with adaLN.
+ """
+
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__(hidden_size, patch_size, out_channels)
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate2(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class DiT2(DiT):
+ # a conditional ViT
+ def __init__(self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ mixing_logit_init=-3,
+ mixed_prediction=True,
+ context_dim=False,
+ roll_out=False,
+ plane_n=3,
+ return_all_layers=False,
+ vit_blk=...):
+ super().__init__(input_size,
+ patch_size,
+ in_channels,
+ hidden_size,
+ depth,
+ num_heads,
+ mlp_ratio,
+ class_dropout_prob,
+ num_classes,
+ learn_sigma,
+ mixing_logit_init,
+ mixed_prediction,
+ context_dim,
+ roll_out,
+ vit_blk=DiTBlock2,
+ final_layer_blk=FinalLayer2)
+
+ # no t and x embedder
+ del self.x_embedder
+ del self.t_embedder
+ del self.final_layer
+ torch.cuda.empty_cache()
+ self.clip_text_proj = None
+ self.plane_n = plane_n
+ self.return_all_layers = return_all_layers
+
+ def forward(self, c, *args, **kwargs):
+ # return super().forward(x, timesteps, context, y, get_attr, **kwargs)
+ """
+ Forward pass of DiT.
+ c: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ """
+ x = self.pos_embed.repeat(
+ c.shape[0], 1, 1) # (N, T, D), where T = H * W / patch_size ** 2
+
+ if self.return_all_layers:
+ all_layers = []
+
+ # if context is not None:
+ # c = context # B 3HW C
+
+ for blk_idx, block in enumerate(self.blocks):
+ if self.roll_out:
+ if blk_idx % 2 == 0: # with-in plane self attention
+ x = rearrange(x, 'b (n l) c -> (b n) l c ', n=self.plane_n)
+ x = block(x,
+ rearrange(c,
+ 'b (n l) c -> (b n) l c ',
+ n=self.plane_n)) # (N, T, D)
+ # st()
+ if self.return_all_layers:
+ all_layers.append(x)
+ else: # global attention
+ x = rearrange(x, '(b n) l c -> b (n l) c ', n=self.plane_n)
+ x = block(x, c) # (N, T, D)
+ # st()
+ if self.return_all_layers:
+ # all merged into B dim
+ all_layers.append(
+ rearrange(x,
+ 'b (n l) c -> (b n) l c',
+ n=self.plane_n))
+ else:
+ x = block(x, c) # (N, T, D)
+
+ # x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+
+ # if self.roll_out: # move n from L to B axis
+ # x = rearrange(x, 'b (n l) c ->(b n) l c', n=3)
+
+ # x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ # if self.roll_out: # move n from L to B axis
+ # x = rearrange(x, '(b n) c h w -> b (n c) h w', n=3)
+
+ if self.return_all_layers:
+ return all_layers
+ else:
+ return x
+
+
+# class DiT2_DPT(DiT2):
+# def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1, num_classes=1000, learn_sigma=True, mixing_logit_init=-3, mixed_prediction=True, context_dim=False, roll_out=False, plane_n=3, vit_blk=...):
+# super().__init__(input_size, patch_size, in_channels, hidden_size, depth, num_heads, mlp_ratio, class_dropout_prob, num_classes, learn_sigma, mixing_logit_init, mixed_prediction, context_dim, roll_out, plane_n, vit_blk)
+# self.return_all_layers = True
+
+#################################################################################
+# DiT2 Configs #
+#################################################################################
+
+
+def DiT2_XL_2(**kwargs):
+ return DiT2(depth=28,
+ hidden_size=1152,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT2_XL_2_half(**kwargs):
+ return DiT2(depth=28 // 2,
+ hidden_size=1152,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT2_XL_4(**kwargs):
+ return DiT2(depth=28,
+ hidden_size=1152,
+ patch_size=4,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT2_XL_8(**kwargs):
+ return DiT2(depth=28,
+ hidden_size=1152,
+ patch_size=8,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT2_L_2(**kwargs):
+ return DiT2(depth=24,
+ hidden_size=1024,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT2_L_2_half(**kwargs):
+ return DiT2(depth=24 // 2,
+ hidden_size=1024,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT2_L_4(**kwargs):
+ return DiT2(depth=24,
+ hidden_size=1024,
+ patch_size=4,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT2_L_8(**kwargs):
+ return DiT2(depth=24,
+ hidden_size=1024,
+ patch_size=8,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT2_B_2(**kwargs):
+ return DiT2(depth=12,
+ hidden_size=768,
+ patch_size=2,
+ num_heads=12,
+ **kwargs)
+
+
+def DiT2_B_4(**kwargs):
+ return DiT2(depth=12,
+ hidden_size=768,
+ patch_size=4,
+ num_heads=12,
+ **kwargs)
+
+
+def DiT2_B_8(**kwargs):
+ return DiT2(depth=12,
+ hidden_size=768,
+ patch_size=8,
+ num_heads=12,
+ **kwargs)
+
+
+def DiT2_B_16(**kwargs): # ours cfg
+ return DiT2(depth=12,
+ hidden_size=768,
+ patch_size=16,
+ num_heads=12,
+ **kwargs)
+
+
+def DiT2_S_2(**kwargs):
+ return DiT2(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+
+def DiT2_S_4(**kwargs):
+ return DiT2(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+
+def DiT2_S_8(**kwargs):
+ return DiT2(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+
+DiT2_models = {
+ 'DiT2-XL/2': DiT2_XL_2,
+ 'DiT2-XL/2/half': DiT2_XL_2_half,
+ 'DiT2-XL/4': DiT2_XL_4,
+ 'DiT2-XL/8': DiT2_XL_8,
+ 'DiT2-L/2': DiT2_L_2,
+ 'DiT2-L/2/half': DiT2_L_2_half,
+ 'DiT2-L/4': DiT2_L_4,
+ 'DiT2-L/8': DiT2_L_8,
+ 'DiT2-B/2': DiT2_B_2,
+ 'DiT2-B/4': DiT2_B_4,
+ 'DiT2-B/8': DiT2_B_8,
+ 'DiT2-B/16': DiT2_B_16,
+ 'DiT2-S/2': DiT2_S_2,
+ 'DiT2-S/4': DiT2_S_4,
+ 'DiT2-S/8': DiT2_S_8,
+}
diff --git a/dit/dit_models copy.py b/dit/dit_models copy.py
new file mode 100644
index 0000000000000000000000000000000000000000..125f0444656cfc21a8d61406eaca7d6b96847dce
--- /dev/null
+++ b/dit/dit_models copy.py
@@ -0,0 +1,697 @@
+# https://github.com/facebookresearch/DiT
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+# from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
+from timm.models.vision_transformer import PatchEmbed, Mlp
+from einops import rearrange
+from pdb import set_trace as st
+
+# support flash attention and xformer acceleration
+from vit.vision_transformer import MemEffAttention as Attention
+
+# from torch.nn import LayerNorm
+# from xformers import triton
+# import xformers.triton
+# from xformers.triton import FusedLayerNorm as LayerNorm
+# from xformers.components.activations import build_activation, Activation
+# from xformers.components.feedforward import fused_mlp
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) *
+ torch.arange(start=0, end=half, dtype=torch.float32) /
+ half).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding,
+ hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0],
+ device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+class ClipProjector(nn.Module):
+
+ def __init__(self, transformer_width, embed_dim, tx_width, *args,
+ **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ '''a CLIP text encoder projector, adapted from CLIP.encode_text
+ '''
+
+ self.text_projection = nn.Parameter(
+ torch.empty(transformer_width, embed_dim))
+ nn.init.normal_(self.text_projection, std=tx_width**-0.5)
+
+ def forward(self, clip_text_x):
+ return clip_text_x @ self.text_projection
+
+
+#################################################################################
+# Core DiT Model #
+#################################################################################
+
+# class DiTBlock(nn.Module):
+# """
+# A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+# """
+
+# def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+# super().__init__()
+# nn.LayerNorm
+# self.norm1 = LayerNorm(
+# hidden_size,
+# affine=False,
+# # elementwise_affine=False,
+# eps=1e-6)
+# self.attn = Attention(hidden_size,
+# num_heads=num_heads,
+# qkv_bias=True,
+# **block_kwargs)
+# self.norm2 = LayerNorm(
+# hidden_size,
+# # elementwise_affine=False,
+# affine=False,
+# eps=1e-6)
+
+# mlp_hidden_dim = int(hidden_size * mlp_ratio)
+# approx_gelu = lambda: nn.GELU(approximate="tanh")
+
+# self.mlp = Mlp(in_features=hidden_size,
+# hidden_features=mlp_hidden_dim,
+# act_layer=approx_gelu,
+# drop=0)
+
+# # self.mlp = fused_mlp.FusedMLP(
+# # dim_model=hidden_size,
+# # dropout=0,
+# # activation=Activation.GeLU,
+# # hidden_layer_multiplier=mlp_ratio,
+# # )
+
+# self.adaLN_modulation = nn.Sequential(
+# nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+# def forward(self, x, c):
+# shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+# c).chunk(6, dim=1)
+# x = x + gate_msa.unsqueeze(1) * self.attn(
+# modulate(self.norm1(x), shift_msa, scale_msa))
+# x = x + gate_mlp.unsqueeze(1) * self.mlp(
+# modulate(self.norm2(x), shift_mlp, scale_mlp))
+# return x
+
+
+class DiTBlock(nn.Module):
+ """
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size,
+ elementwise_affine=False,
+ eps=1e-6)
+ self.attn = Attention(hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size,
+ elementwise_affine=False,
+ eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+ c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(
+ modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(
+ modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class DiTBlockRollOut(DiTBlock):
+ """
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4, **block_kwargs):
+ super().__init__(hidden_size * 3, num_heads, mlp_ratio, **block_kwargs)
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+ c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(
+ modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(
+ modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT, basically the decoder_pred in MAE with adaLN.
+ """
+
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(
+ hidden_size,
+ # self.norm_final = LayerNorm(
+ hidden_size,
+ elementwise_affine=False,)
+ # affine=False,
+ # eps=1e-6)
+ self.linear = nn.Linear(hidden_size,
+ patch_size * patch_size * out_channels,
+ bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class DiT(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ mixing_logit_init=-3,
+ mixed_prediction=True,
+ context_dim=False,
+ roll_out=False,
+ vit_blk=DiTBlock,
+ final_layer_blk=FinalLayer,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.embed_dim = hidden_size
+
+ # st()
+ self.x_embedder = PatchEmbed(input_size,
+ patch_size,
+ in_channels,
+ hidden_size,
+ bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ if num_classes > 0:
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size,
+ class_dropout_prob)
+ else:
+ self.y_embedder = None
+
+ if context_dim is not None:
+ self.clip_text_proj = ClipProjector(context_dim,
+ hidden_size,
+ tx_width=depth)
+ else:
+ self.clip_text_proj = None
+
+ self.roll_out = roll_out
+
+ num_patches = self.x_embedder.num_patches # 14*14*3
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size),
+ requires_grad=False)
+
+ # if not self.roll_out:
+ self.blocks = nn.ModuleList([
+ vit_blk(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(depth)
+ ])
+ # else:
+ # self.blocks = nn.ModuleList([
+ # DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) if idx % 2 == 0 else
+ # DiTBlockRollOut(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ # for idx in range(depth)
+ # ])
+
+ self.final_layer = final_layer_blk(hidden_size, patch_size,
+ self.out_channels)
+ self.initialize_weights()
+
+ self.mixed_prediction = mixed_prediction # This enables mixed prediction
+ if self.mixed_prediction:
+ if self.roll_out:
+ logit_ch = in_channels * 3
+ else:
+ logit_ch = in_channels
+ init = mixing_logit_init * torch.ones(
+ size=[1, logit_ch, 1, 1]) # hard coded for now
+ self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+ # st()
+ self.pos_embed.data.copy_(
+ torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ if self.y_embedder is not None:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ # p = self.x_embedder.patch_size[0]
+ p = self.patch_size
+ h = w = int(x.shape[1]**0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
+ return imgs
+
+ # def forward(self, x, t, y=None, get_attr=''):
+ def forward(self,
+ x,
+ timesteps=None,
+ context=None,
+ y=None,
+ get_attr='',
+ **kwargs):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+ # t = timesteps
+
+ if get_attr != '': # not breaking the forward hooks
+ return getattr(self, get_attr)
+
+ t = self.t_embedder(timesteps) # (N, D)
+
+ if self.roll_out: # !
+ x = rearrange(x, 'b (n c) h w->(b n) c h w', n=3)
+
+ x = self.x_embedder(
+ x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+
+ if self.roll_out: # ! roll-out in the L dim, not B dim. add condition to all tokens.
+ x = rearrange(x, '(b n) l c ->b (n l) c', n=3)
+
+ if self.y_embedder is not None:
+ assert y is not None
+ y = self.y_embedder(y, self.training) # (N, D)
+ c = t + y # (N, D)
+ elif context is not None:
+ assert context.ndim == 2
+ context = self.clip_text_proj(context)
+
+ if context.shape[0] < t.shape[
+ 0]: # same caption context for different view input of the same ID
+ context = torch.repeat_interleave(context,
+ t.shape[0] //
+ context.shape[0],
+ dim=0)
+
+ # if context.ndim == 3: # compat version from SD
+ # context = context[:, 0, :]
+ c = t + context
+ else:
+ c = t # BS 1024
+
+ for blk_idx, block in enumerate(self.blocks):
+ # if self.roll_out:
+ # if blk_idx % 2 == 0: # with-in plane self attention
+ # x = rearrange(x, 'b (n l) c -> b l (n c) ', n=3)
+ # x = block(x, torch.repeat_interleave(c, 3, 0)) # (N, T, D)
+ # else: # global attention
+ # # x = rearrange(x, '(b n) l c -> b (n l) c ', n=3)
+ # x = rearrange(x, 'b l (n c) -> b (n l) c ', n=3)
+ # x = block(x, c) # (N, T, D)
+ # else:
+ x = block(x, c) # (N, T, D)
+
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+
+ if self.roll_out: # move n from L to B axis
+ x = rearrange(x, 'b (n l) c ->(b n) l c', n=3)
+
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ if self.roll_out: # move n from L to B axis
+ x = rearrange(x, '(b n) c h w -> b (n c) h w', n=3)
+ # x = rearrange(x, 'b n) c h w -> b (n c) h w', n=3)
+
+ return x
+
+ def forward_with_cfg(self, x, t, y, cfg_scale):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[:len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ def forward_with_cfg_unconditional(self, x, t, y=None, cfg_scale=None):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ # half = x[:len(x) // 2]
+ # combined = torch.cat([half, half], dim=0)
+ combined = x
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
+ # cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ # half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ # eps = torch.cat([half_eps, half_eps], dim=0)
+ # return torch.cat([eps, rest], dim=1)
+ # st()
+ return model_out
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+
+def get_2d_sincos_pos_embed(embed_dim,
+ grid_size,
+ cls_token=False,
+ extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate(
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
+ grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
+ grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+#################################################################################
+# DiT Configs #
+#################################################################################
+
+
+def DiT_XL_2(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_XL_4(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=4,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_XL_8(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=8,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_2(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_4(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=4,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_8(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=8,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_B_2(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
+
+
+def DiT_B_4(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
+
+
+def DiT_B_8(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+
+def DiT_B_16(**kwargs): # ours cfg
+ return DiT(depth=12,
+ hidden_size=768,
+ patch_size=16,
+ num_heads=12,
+ **kwargs)
+
+
+def DiT_S_2(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+
+def DiT_S_4(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+
+def DiT_S_8(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+
+DiT_models = {
+ 'DiT-XL/2': DiT_XL_2,
+ 'DiT-XL/4': DiT_XL_4,
+ 'DiT-XL/8': DiT_XL_8,
+ 'DiT-L/2': DiT_L_2,
+ 'DiT-L/4': DiT_L_4,
+ 'DiT-L/8': DiT_L_8,
+ 'DiT-B/2': DiT_B_2,
+ 'DiT-B/4': DiT_B_4,
+ 'DiT-B/8': DiT_B_8,
+ 'DiT-B/16': DiT_B_16,
+ 'DiT-S/2': DiT_S_2,
+ 'DiT-S/4': DiT_S_4,
+ 'DiT-S/8': DiT_S_8,
+}
diff --git a/dit/dit_models.py b/dit/dit_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..e39de76536e9204a221decae2aacf1c340942457
--- /dev/null
+++ b/dit/dit_models.py
@@ -0,0 +1,508 @@
+# https://github.com/facebookresearch/DiT
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
+
+from pdb import set_trace as st
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) *
+ torch.arange(start=0, end=half, dtype=torch.float32) /
+ half).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding,
+ hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0],
+ device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+#################################################################################
+# Core DiT Model #
+#################################################################################
+
+
+class DiTBlock(nn.Module):
+ """
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size,
+ elementwise_affine=False,
+ eps=1e-6)
+ self.attn = Attention(hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ **block_kwargs)
+ self.norm2 = nn.LayerNorm(hidden_size,
+ elementwise_affine=False,
+ eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+ c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(
+ modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(
+ modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT, basically the decoder_pred in MAE with adaLN.
+ """
+
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size,
+ elementwise_affine=False,
+ eps=1e-6)
+ self.linear = nn.Linear(hidden_size,
+ patch_size * patch_size * out_channels,
+ bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class DiT(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+
+ self.x_embedder = PatchEmbed(input_size,
+ patch_size,
+ in_channels,
+ hidden_size,
+ bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ if num_classes > 0:
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size,
+ class_dropout_prob)
+ else:
+ self.y_embedder = None
+ num_patches = self.x_embedder.num_patches # 14*14*3
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size),
+ requires_grad=False)
+
+ self.blocks = nn.ModuleList([
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(depth)
+ ])
+ self.final_layer = FinalLayer(hidden_size, patch_size,
+ self.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+ # st()
+ self.pos_embed.data.copy_(
+ torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ if self.y_embedder is not None:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ p = self.x_embedder.patch_size[0]
+ h = w = int(x.shape[1]**0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
+ return imgs
+
+ def forward(self, x, t, y=None):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+
+ x = self.x_embedder(
+ x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ t = self.t_embedder(t) # (N, D)
+
+ if self.y_embedder is not None:
+ assert y is not None
+ y = self.y_embedder(y, self.training) # (N, D)
+ c = t + y # (N, D)
+ else:
+ c = t
+
+ for block in self.blocks:
+ x = block(x, c) # (N, T, D)
+
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ return x
+
+ def forward_with_cfg(self, x, t, y, cfg_scale):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[:len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ def forward_with_cfg_unconditional(self, x, t, y=None, cfg_scale=None):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ # half = x[:len(x) // 2]
+ # combined = torch.cat([half, half], dim=0)
+ combined = x
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
+ # cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ # half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ # eps = torch.cat([half_eps, half_eps], dim=0)
+ # return torch.cat([eps, rest], dim=1)
+ # st()
+ return model_out
+
+
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+
+def get_2d_sincos_pos_embed(embed_dim,
+ grid_size,
+ cls_token=False,
+ extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate(
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
+ grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
+ grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+#################################################################################
+# DiT Configs #
+#################################################################################
+
+
+def DiT_XL_2(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_XL_4(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=4,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_XL_8(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=8,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_2(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_4(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=4,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_8(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=8,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_B_2(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
+
+
+def DiT_B_4(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
+
+
+def DiT_B_8(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+def DiT_B_16(**kwargs): # ours cfg
+ return DiT(depth=12, hidden_size=768, patch_size=16, num_heads=12, **kwargs)
+
+def DiT_S_2(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+
+def DiT_S_4(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+
+def DiT_S_8(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+
+DiT_models = {
+ 'DiT-XL/2': DiT_XL_2,
+ 'DiT-XL/4': DiT_XL_4,
+ 'DiT-XL/8': DiT_XL_8,
+ 'DiT-L/2': DiT_L_2,
+ 'DiT-L/4': DiT_L_4,
+ 'DiT-L/8': DiT_L_8,
+ 'DiT-B/2': DiT_B_2,
+ 'DiT-B/4': DiT_B_4,
+ 'DiT-B/8': DiT_B_8,
+ 'DiT-B/16': DiT_B_16,
+ 'DiT-S/2': DiT_S_2,
+ 'DiT-S/4': DiT_S_4,
+ 'DiT-S/8': DiT_S_8,
+}
diff --git a/dit/dit_models_xformers.py b/dit/dit_models_xformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f51ed003c442d3f68b821c4b323aca34fcd11e26
--- /dev/null
+++ b/dit/dit_models_xformers.py
@@ -0,0 +1,664 @@
+# https://github.com/facebookresearch/DiT
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# GLIDE: https://github.com/openai/glide-text2im
+# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+# from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
+from timm.models.vision_transformer import PatchEmbed, Mlp
+from einops import rearrange
+from pdb import set_trace as st
+
+# support flash attention and xformer acceleration
+from vit.vision_transformer import MemEffAttention as Attention
+
+# from torch.nn import LayerNorm
+# from xformers import triton
+# import xformers.triton
+
+if torch.cuda.is_available():
+ from xformers.triton import FusedLayerNorm as LayerNorm
+ from xformers.components.activations import build_activation, Activation
+ from xformers.components.feedforward import fused_mlp
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) *
+ torch.arange(start=0, end=half, dtype=torch.float32) /
+ half).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class LabelEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding,
+ hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0],
+ device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+
+class ClipProjector(nn.Module):
+
+ def __init__(self, transformer_width, embed_dim, tx_width, *args,
+ **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ '''a CLIP text encoder projector, adapted from CLIP.encode_text
+ '''
+
+ self.text_projection = nn.Parameter(
+ torch.empty(transformer_width, embed_dim))
+ nn.init.normal_(self.text_projection, std=tx_width**-0.5)
+
+ def forward(self, clip_text_x):
+ return clip_text_x @ self.text_projection
+
+
+#################################################################################
+# Core DiT Model #
+#################################################################################
+
+
+class DiTBlock(nn.Module):
+ """
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ nn.LayerNorm
+ self.norm1 = LayerNorm(
+ hidden_size,
+ affine=False,
+ # elementwise_affine=False,
+ eps=1e-6)
+ self.attn = Attention(hidden_size,
+ num_heads=num_heads,
+ qkv_bias=True,
+ **block_kwargs)
+ self.norm2 = LayerNorm(
+ hidden_size,
+ # elementwise_affine=False,
+ affine=False,
+ eps=1e-6)
+ # mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ # approx_gelu = lambda: nn.GELU(approximate="tanh")
+
+ # self.mlp = Mlp(in_features=hidden_size,
+ # hidden_features=mlp_hidden_dim,
+ # act_layer=approx_gelu,
+ # drop=0)
+
+ self.mlp = fused_mlp.FusedMLP(
+ dim_model=hidden_size,
+ dropout=0,
+ activation=Activation.GeLU,
+ hidden_layer_multiplier=int(mlp_ratio),
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+ c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(
+ modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(
+ modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class DiTBlockRollOut(DiTBlock):
+ """
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+ """
+
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4, **block_kwargs):
+ super().__init__(hidden_size * 3, num_heads, mlp_ratio, **block_kwargs)
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+ c).chunk(6, dim=1)
+ x = x + gate_msa.unsqueeze(1) * self.attn(
+ modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(
+ modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT, basically the decoder_pred in MAE with adaLN.
+ """
+
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ # self.norm_final = nn.LayerNorm(hidden_size,
+ self.norm_final = LayerNorm(
+ hidden_size,
+ # elementwise_affine=False,
+ affine=False,
+ eps=1e-6)
+ self.linear = nn.Linear(hidden_size,
+ patch_size * patch_size * out_channels,
+ bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class DiT(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ input_size=32,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ mixing_logit_init=-3,
+ mixed_prediction=True,
+ context_dim=False,
+ roll_out=False,
+ vit_blk=DiTBlock,
+ final_layer_blk=FinalLayer,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+ self.embed_dim = hidden_size
+
+ self.x_embedder = PatchEmbed(input_size,
+ patch_size,
+ in_channels,
+ hidden_size,
+ bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ if num_classes > 0:
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size,
+ class_dropout_prob)
+ else:
+ self.y_embedder = None
+
+ if context_dim is not None:
+ self.clip_text_proj = ClipProjector(context_dim,
+ hidden_size,
+ tx_width=depth)
+ else:
+ self.clip_text_proj = None
+
+ self.roll_out = roll_out
+
+ num_patches = self.x_embedder.num_patches # 14*14*3
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size),
+ requires_grad=False)
+
+ # if not self.roll_out:
+ self.blocks = nn.ModuleList([
+ vit_blk(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(depth)
+ ])
+ # else:
+ # self.blocks = nn.ModuleList([
+ # DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) if idx % 2 == 0 else
+ # DiTBlockRollOut(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ # for idx in range(depth)
+ # ])
+
+ self.final_layer = final_layer_blk(hidden_size, patch_size,
+ self.out_channels)
+ self.initialize_weights()
+
+ self.mixed_prediction = mixed_prediction # This enables mixed prediction
+ if self.mixed_prediction:
+ if self.roll_out:
+ logit_ch = in_channels * 3
+ else:
+ logit_ch = in_channels
+ init = mixing_logit_init * torch.ones(
+ size=[1, logit_ch, 1, 1]) # hard coded for now
+ self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
+
+ # def len(self):
+ # return len(self.blocks)
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5))
+ # st()
+ self.pos_embed.data.copy_(
+ torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ if self.y_embedder is not None:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+ # p = self.x_embedder.patch_size[0]
+ p = self.patch_size
+ h = w = int(x.shape[1]**0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
+ return imgs
+
+ # def forward(self, x, t, y=None, get_attr=''):
+ def forward(self,
+ x,
+ timesteps=None,
+ context=None,
+ y=None,
+ get_attr='',
+ **kwargs):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+ # t = timesteps
+
+ if get_attr != '': # not breaking the forward hooks
+ return getattr(self, get_attr)
+
+ t = self.t_embedder(timesteps) # (N, D)
+
+ if self.roll_out: # !
+ x = rearrange(x, 'b (n c) h w->(b n) c h w', n=3)
+
+ x = self.x_embedder(
+ x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+
+ if self.roll_out: # ! roll-out in the L dim, not B dim. add condition to all tokens.
+ x = rearrange(x, '(b n) l c ->b (n l) c', n=3)
+
+ if self.y_embedder is not None:
+ assert y is not None
+ y = self.y_embedder(y, self.training) # (N, D)
+ c = t + y # (N, D)
+ elif context is not None:
+ assert context.ndim == 2
+ context = self.clip_text_proj(context)
+
+ if context.shape[0] < t.shape[
+ 0]: # same caption context for different view input of the same ID
+ context = torch.repeat_interleave(context,
+ t.shape[0] //
+ context.shape[0],
+ dim=0)
+
+ # if context.ndim == 3: # compat version from SD
+ # context = context[:, 0, :]
+ c = t + context
+ else:
+ c = t # BS 1024
+
+ for blk_idx, block in enumerate(self.blocks):
+ # if self.roll_out:
+ # if blk_idx % 2 == 0: # with-in plane self attention
+ # x = rearrange(x, 'b (n l) c -> b l (n c) ', n=3)
+ # x = block(x, torch.repeat_interleave(c, 3, 0)) # (N, T, D)
+ # else: # global attention
+ # # x = rearrange(x, '(b n) l c -> b (n l) c ', n=3)
+ # x = rearrange(x, 'b l (n c) -> b (n l) c ', n=3)
+ # x = block(x, c) # (N, T, D)
+ # else:
+ x = block(x, c) # (N, T, D)
+
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+
+ if self.roll_out: # move n from L to B axis
+ x = rearrange(x, 'b (n l) c ->(b n) l c', n=3)
+
+ x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ if self.roll_out: # move n from L to B axis
+ x = rearrange(x, '(b n) c h w -> b (n c) h w', n=3)
+ # x = rearrange(x, 'b n) c h w -> b (n c) h w', n=3)
+
+ return x
+
+ def forward_with_cfg(self, x, t, y, cfg_scale):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[:len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ def forward_with_cfg_unconditional(self, x, t, y=None, cfg_scale=None):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ # half = x[:len(x) // 2]
+ # combined = torch.cat([half, half], dim=0)
+ combined = x
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
+ # cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ # half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ # eps = torch.cat([half_eps, half_eps], dim=0)
+ # return torch.cat([eps, rest], dim=1)
+ # st()
+ return model_out
+
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+
+def get_2d_sincos_pos_embed(embed_dim,
+ grid_size,
+ cls_token=False,
+ extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate(
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
+ grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
+ grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+#################################################################################
+# DiT Configs #
+#################################################################################
+
+
+def DiT_XL_2(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_XL_4(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=4,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_XL_8(**kwargs):
+ return DiT(depth=28,
+ hidden_size=1152,
+ patch_size=8,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_2(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=2,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_4(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=4,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_L_8(**kwargs):
+ return DiT(depth=24,
+ hidden_size=1024,
+ patch_size=8,
+ num_heads=16,
+ **kwargs)
+
+
+def DiT_B_2(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
+
+
+def DiT_B_4(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
+
+
+def DiT_B_8(**kwargs):
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+
+def DiT_B_16(**kwargs): # ours cfg
+ return DiT(depth=12,
+ hidden_size=768,
+ patch_size=16,
+ num_heads=12,
+ **kwargs)
+
+
+def DiT_S_2(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+
+def DiT_S_4(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+
+def DiT_S_8(**kwargs):
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+
+DiT_models = {
+ 'DiT-XL/2': DiT_XL_2,
+ 'DiT-XL/4': DiT_XL_4,
+ 'DiT-XL/8': DiT_XL_8,
+ 'DiT-L/2': DiT_L_2,
+ 'DiT-L/4': DiT_L_4,
+ 'DiT-L/8': DiT_L_8,
+ 'DiT-B/2': DiT_B_2,
+ 'DiT-B/4': DiT_B_4,
+ 'DiT-B/8': DiT_B_8,
+ 'DiT-B/16': DiT_B_16,
+ 'DiT-S/2': DiT_S_2,
+ 'DiT-S/4': DiT_S_4,
+ 'DiT-S/8': DiT_S_8,
+}
diff --git a/dit/dit_wo_embedder.py b/dit/dit_wo_embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a40a05b2b071d9b564551d46b78bafe43156121
--- /dev/null
+++ b/dit/dit_wo_embedder.py
@@ -0,0 +1,439 @@
+import torch
+import torch.nn as nn
+# import numpy as np
+# import math
+# from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
+
+from .dit_models import TimestepEmbedder, LabelEmbedder, DiTBlock, get_2d_sincos_pos_embed
+
+
+class DiTwoEmbedder(nn.Module):
+ """
+ Diffusion model with a Transformer backbone, performing directly on the ViT token latents rather than spatial latents.
+ """
+
+ def __init__(
+ self,
+ input_size=224, # raw img input size
+ # patch_size=14, # dino version
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = 14 # dino-v2 patch sized fixed in this project
+ self.num_heads = num_heads
+
+ # self.x_embedder = PatchEmbed(input_size,
+ # patch_size,
+ # in_channels,
+ # hidden_size,
+ # bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ if num_classes > 0:
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size,
+ class_dropout_prob)
+ else:
+ self.y_embedder = None
+
+ # num_patches = self.x_embedder.num_patches # 14*14*3
+ self.num_patches = (input_size // self.patch_size)**2
+
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches,
+ hidden_size),
+ requires_grad=False)
+
+ self.blocks = nn.ModuleList([
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(depth)
+ ])
+ # self.final_layer = FinalLayer(hidden_size, patch_size,
+ # self.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
+ int(self.num_patches**0.5))
+ # st()
+ self.pos_embed.data.copy_(
+ torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ # w = self.x_embedder.proj.weight.data
+ # nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ # nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ if self.y_embedder is not None:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ # nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ # nn.init.constant_(self.final_layer.linear.weight, 0)
+ # nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, y=None):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+
+ # ! no embedder operation
+ # x = self.x_embedder(
+ # x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ x = x + self.pos_embed
+
+ t = self.t_embedder(t) # (N, D)
+
+ if self.y_embedder is not None:
+ assert y is not None
+ y = self.y_embedder(y, self.training) # (N, D)
+ c = t + y # (N, D)
+ else:
+ c = t
+
+ for block in self.blocks:
+ x = block(x, c) # (N, T, D)
+
+ # x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+ # x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ return x
+
+ def forward_with_cfg(self, x, t, y, cfg_scale):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[:len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ def forward_with_cfg_unconditional(self, x, t, y=None, cfg_scale=None):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+
+ combined = x
+ model_out = self.forward(combined, t, y)
+
+ return model_out
+
+
+class DiTwoEmbedderLongSkipConnection(nn.Module):
+
+ def __init__(
+ self,
+ input_size=224, # raw img input size
+ patch_size=14, # dino version
+ in_channels=4,
+ hidden_size=1152,
+ depth=28,
+ num_heads=16,
+ mlp_ratio=4.0,
+ class_dropout_prob=0.1,
+ num_classes=1000,
+ learn_sigma=True,
+ ):
+ """DiT with long skip-connections from U-ViT, CVPR 23'
+ """
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
+ self.patch_size = patch_size
+ self.num_heads = num_heads
+
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ if num_classes > 0:
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size,
+ class_dropout_prob)
+ else:
+ self.y_embedder = None
+
+ # num_patches = self.x_embedder.num_patches # 14*14*3
+ self.num_patches = (input_size // patch_size)**2
+
+ # Will use fixed sin-cos embedding:
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches,
+ hidden_size),
+ requires_grad=False)
+
+ self.blocks = nn.ModuleList([
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(depth)
+ ])
+
+ # ! add long-skip-connections from U-ViT
+ self.in_blocks = nn.ModuleList([
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(depth // 2)
+ ])
+
+ self.mid_block = DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+
+ self.out_blocks = nn.ModuleList([
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(depth // 2)
+ ])
+
+ # ! needed or to be replaced?
+ # self.final_layer = FinalLayer(hidden_size, patch_size,
+ # self.out_channels)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
+ int(self.num_patches**0.5))
+ # st()
+ self.pos_embed.data.copy_(
+ torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ # w = self.x_embedder.proj.weight.data
+ # nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ # nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ if self.y_embedder is not None:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers in DiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ # nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ # nn.init.constant_(self.final_layer.linear.weight, 0)
+ # nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, y=None):
+ """
+ Forward pass of DiT.
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of class labels
+ """
+
+ # ! no embedder operation
+ # x = self.x_embedder(
+ # x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
+ x = x + self.pos_embed
+
+ t = self.t_embedder(t) # (N, D)
+
+ if self.y_embedder is not None:
+ assert y is not None
+ y = self.y_embedder(y, self.training) # (N, D)
+ c = t + y # (N, D)
+ else:
+ c = t
+
+ # ! add long-skip-connections here
+
+ # for block in self.blocks:
+ # x = block(x, c) # (N, T, D)
+
+ skips = []
+ for blk in self.in_blocks:
+ x = blk(x)
+ skips.append(x)
+
+ x = self.mid_block(x)
+
+ for blk in self.out_blocks:
+ x = blk(x, skips.pop())
+
+ # ! the order of unpatchify and final_linear swaps in the baseline implementation
+ # x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+ # x = self.unpatchify(x) # (N, out_channels, H, W)
+
+ return x
+
+ def forward_with_cfg(self, x, t, y, cfg_scale):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[:len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, y)
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
+ # three channels by default. The standard approach to cfg applies it to all channels.
+ # This can be done by uncommenting the following line and commenting-out the line following that.
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ def forward_with_cfg_unconditional(self, x, t, y=None, cfg_scale=None):
+ """
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+
+ combined = x
+ model_out = self.forward(combined, t, y)
+
+ return model_out
+
+
+#################################################################################
+# DiT Configs #
+#################################################################################
+
+# def DiT_XL_2(**kwargs):
+# return DiT(depth=28,
+# hidden_size=1152,
+# patch_size=2,
+# num_heads=16,
+# **kwargs)
+
+# def DiT_XL_4(**kwargs):
+# return DiT(depth=28,
+# hidden_size=1152,
+# patch_size=4,
+# num_heads=16,
+# **kwargs)
+
+# def DiT_XL_8(**kwargs):
+# return DiT(depth=28,
+# hidden_size=1152,
+# patch_size=8,
+# num_heads=16,
+# **kwargs)
+
+# def DiT_L_2(**kwargs):
+# return DiT(depth=24,
+# hidden_size=1024,
+# patch_size=2,
+# num_heads=16,
+# **kwargs)
+
+# def DiT_L_4(**kwargs):
+# return DiT(depth=24,
+# hidden_size=1024,
+# patch_size=4,
+# num_heads=16,
+# **kwargs)
+
+# def DiT_L_8(**kwargs):
+# return DiT(depth=24,
+# hidden_size=1024,
+# patch_size=8,
+# num_heads=16,
+# **kwargs)
+
+# def DiT_B_2(**kwargs):
+# return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
+
+# def DiT_B_4(**kwargs):
+# return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
+
+# def DiT_B_8(**kwargs):
+# return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
+
+# def DiT_B_16(**kwargs): # ours cfg
+# return DiT(depth=12, hidden_size=768, patch_size=16, num_heads=12, **kwargs)
+
+# def DiT_S_2(**kwargs):
+# return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
+
+# def DiT_S_4(**kwargs):
+# return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
+
+# def DiT_S_8(**kwargs):
+# return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
+
+
+def DiT_woembed_S(**kwargs):
+ return DiTwoEmbedder(depth=12, hidden_size=384, num_heads=6, **kwargs)
+
+
+def DiT_woembed_B(**kwargs):
+ return DiTwoEmbedder(depth=12, hidden_size=768, num_heads=12, **kwargs)
+
+
+def DiT_woembed_L(**kwargs):
+ return DiTwoEmbedder(
+ depth=24,
+ hidden_size=1024,
+ num_heads=16,
+ **kwargs)
+
+
+DiT_woembed_models = {
+ # 'DiT-XL/2': DiT_XL_2,
+ # 'DiT-XL/4': DiT_XL_4,
+ # 'DiT-XL/8': DiT_XL_8,
+ # 'DiT-L/2': DiT_L_2,
+ # 'DiT-L/4': DiT_L_4,
+ # 'DiT-L/8': DiT_L_8,
+ # 'DiT-B/2': DiT_B_2,
+ # 'DiT-B/4': DiT_B_4,
+ # 'DiT-B/8': DiT_B_8,
+ # 'DiT-B/16': DiT_B_16,
+ # 'DiT-S/2': DiT_S_2,
+ # 'DiT-S/4': DiT_S_4,
+ # 'DiT-S/8': DiT_S_8,
+ 'DiT-wo-S': DiT_woembed_S,
+ 'DiT-wo-B': DiT_woembed_B,
+ 'DiT-wo-L': DiT_woembed_L,
+}
diff --git a/dnnlib/__init__.py b/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd91ed142e955581e83948455fb71cd837215f61
--- /dev/null
+++ b/dnnlib/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/dnnlib/__pycache__/__init__.cpython-39.pyc b/dnnlib/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab980ad2af130f46020346228accd9b9ed84f0a3
Binary files /dev/null and b/dnnlib/__pycache__/__init__.cpython-39.pyc differ
diff --git a/dnnlib/__pycache__/util.cpython-39.pyc b/dnnlib/__pycache__/util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ca7b6081015d2944ccd9579f78b89061535801b
Binary files /dev/null and b/dnnlib/__pycache__/util.cpython-39.pyc differ
diff --git a/dnnlib/util.py b/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..2430b83469eca9217de68e9ab3fa7cdec14bc5b4
--- /dev/null
+++ b/dnnlib/util.py
@@ -0,0 +1,590 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+import torch
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max=1.0):
+ recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
+ return d_weight
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self,
+ file_name: str = None,
+ file_mode: str = "w",
+ should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(
+ text
+ ) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib',
+ *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60,
+ s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60),
+ (s // (60 * 60)) % 24,
+ (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+ else:
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60),
+ (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:]))
+ for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(
+ module_name) # may raise ImportError
+ get_obj_from_module(module,
+ local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" +
+ module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(
+ module_name) # may raise ImportError
+ get_obj_from_module(module,
+ local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[
+ obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(
+ os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+
+def list_dir_recursively_with_ignore(
+ dir_path: str,
+ ignores: List[str] = None,
+ add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [
+ os.path.join(base_name, p) for p in relative_paths
+ ]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str,
+ cache_dir: str = None,
+ num_attempts: int = 10,
+ verbose: bool = True,
+ return_filename: bool = False,
+ cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get(
+ "Set-Cookie", ""):
+ links = [
+ html.unescape(link)
+ for link in content_str.split('"')
+ if "export=download" in link
+ ]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError(
+ "Google Drive download quota exceeded -- please try again later"
+ )
+
+ match = re.search(
+ r'filename="([^"]*)"',
+ res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(
+ cache_dir,
+ "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
+
+class InfiniteSampler(torch.utils.data.Sampler):
+
+ def __init__(self,
+ dataset,
+ rank=0,
+ num_replicas=1,
+ shuffle=True,
+ seed=0,
+ window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
+def requires_grad(model, flag=True):
+ for p in model.parameters():
+ p.requires_grad = flag
diff --git a/environment_ln3diff.yml b/environment_ln3diff.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0e270edfad8c88e4ac57744721f1bef70036f599
--- /dev/null
+++ b/environment_ln3diff.yml
@@ -0,0 +1,320 @@
+name: ln3diff
+channels:
+ - xformers
+ - pytorch
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=2_kmp_llvm
+ - blas=1.0=mkl
+ - brotlipy=0.7.0=py39h27cfd23_1003
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2023.01.10=h06a4308_0
+ - certifi=2022.12.7=py39h06a4308_0
+ - cffi=1.15.1=py39h5eee18b_3
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - cryptography=39.0.1=py39h9ce1e76_0
+ - cuda-cudart=11.7.99=0
+ - cuda-cupti=11.7.101=0
+ - cuda-libraries=11.7.1=0
+ - cuda-nvrtc=11.7.99=0
+ - cuda-nvtx=11.7.91=0
+ - cuda-runtime=11.7.1=0
+ - ffmpeg=4.3=hf484d3e_0
+ - filelock=3.9.0=py39h06a4308_0
+ - flit-core=3.8.0=py39h06a4308_0
+ - freetype=2.12.1=h4a9f257_0
+ - giflib=5.2.1=h5eee18b_3
+ - gmp=6.2.1=h295c915_3
+ - gmpy2=2.1.2=py39heeb90bb_0
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.4=py39h06a4308_0
+ - ilmbase=2.5.5=h780b84a_0
+ - imath=3.1.7=h3eb15da_1
+ - intel-openmp=2021.4.0=h06a4308_3561
+ - jinja2=3.1.2=py39h06a4308_0
+ - jpeg=9e=h5eee18b_1
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - lerc=3.0=h295c915_0
+ - libcublas=11.10.3.66=0
+ - libcufft=10.7.2.124=h4fbf590_0
+ - libcufile=1.6.0.25=0
+ - libcurand=10.3.2.56=0
+ - libcusolver=11.4.0.1=0
+ - libcusparse=11.7.4.91=0
+ - libdeflate=1.17=h5eee18b_0
+ - libffi=3.4.2=h6a678d5_6
+ - libgcc-ng=12.2.0=h65d4601_19
+ - libiconv=1.16=h7f8727e_2
+ - libidn2=2.3.2=h7f8727e_0
+ - libnpp=11.7.4.75=0
+ - libnvjpeg=11.8.0.2=0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=12.2.0=h46fd767_19
+ - libtasn1=4.16.0=h27cfd23_0
+ - libtiff=4.5.0=h6a678d5_2
+ - libunistring=0.9.10=h27cfd23_0
+ - libwebp=1.2.4=h11a3e52_1
+ - libwebp-base=1.2.4=h5eee18b_1
+ - libzlib=1.2.13=h166bdaf_4
+ - llvm-openmp=16.0.1=h417c0b6_0
+ - lz4-c=1.9.4=h6a678d5_0
+ - markupsafe=2.1.1=py39h7f8727e_0
+ - mkl=2021.4.0=h06a4308_640
+ - mkl-service=2.4.0=py39h7f8727e_0
+ - mkl_fft=1.3.1=py39hd3c417c_0
+ - mkl_random=1.2.2=py39h51133e4_0
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.2=hb69a4c5_1
+ - mpmath=1.2.1=py39h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=2.8.4=py39h06a4308_1
+ - numpy=1.23.5=py39h14f4228_0
+ - numpy-base=1.23.5=py39h31eccc5_0
+ - openexr=3.1.7=h5c7bc04_0
+ - openexr-python=1.3.9=py39hbd9bb45_1
+ - openh264=2.1.1=h4ff587b_0
+ - openssl=1.1.1t=h7f8727e_0
+ - pillow=9.4.0=py39h6a678d5_0
+ - pip=23.0.1=py39h06a4308_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyopenssl=23.0.0=py39h06a4308_0
+ - pysocks=1.7.1=py39h06a4308_0
+ - python=3.9.16=h7a1cb2a_2
+ - python_abi=3.9=2_cp39
+ - pytorch=2.0.0=py3.9_cuda11.7_cudnn8.5.0_0
+ - pytorch-cuda=11.7=h778d358_3
+ - pytorch-mutex=1.0=cuda
+ - readline=8.2=h5eee18b_0
+ - requests=2.28.1=py39h06a4308_1
+ - setuptools=65.6.3=py39h06a4308_0
+ - six=1.16.0=pyhd3eb1b0_1
+ - sqlite=3.41.1=h5eee18b_0
+ - sympy=1.11.1=py39h06a4308_0
+ - tk=8.6.12=h1ccaba5_0
+ - torchaudio=2.0.0=py39_cu117
+ - torchtriton=2.0.0=py39
+ - torchvision=0.15.0=py39_cu117
+ - tzdata=2022g=h04d1e81_0
+ - urllib3=1.26.14=py39h06a4308_0
+ - wheel=0.38.4=py39h06a4308_0
+ - xformers=0.0.18=py39_cu11.8.0_pyt2.0.0
+ - xz=5.2.10=h5eee18b_1
+ - zlib=1.2.13=h166bdaf_4
+ - zstd=1.5.2=ha4553b6_0
+ - pip:
+ - absl-py==1.4.0
+ - accelerate==0.26.1
+ - aiohttp==3.9.3
+ - aiosignal==1.3.1
+ - antialiased-cnns==0.3
+ - antlr4-python3-runtime==4.9.3
+ - anyio==3.7.1
+ - appdirs==1.4.4
+ - argon2-cffi==23.1.0
+ - argon2-cffi-bindings==21.2.0
+ - arrow==1.2.3
+ - asttokens==2.2.1
+ - async-lru==2.0.4
+ - async-timeout==4.0.3
+ - attrs==23.1.0
+ - av==10.0.0
+ - babel==2.12.1
+ - backcall==0.2.0
+ - beartype==0.13.1
+ - beautifulsoup4==4.12.2
+ - bleach==6.0.0
+ - blobfile==2.0.1
+ - braceexpand==0.1.7
+ - cachetools==5.3.0
+ - click==8.1.3
+ - clip==1.0
+ - comm==0.1.4
+ - contourpy==1.1.0
+ - cycler==0.11.0
+ - debugpy==1.6.7.post1
+ - decorator==5.1.1
+ - defusedxml==0.7.1
+ - diff-gaussian-rasterization==0.0.0
+ - diffusers==0.15.0
+ - docker-pycreds==0.4.0
+ - einops==0.6.0
+ - exceptiongroup==1.1.3
+ - executing==2.0.1
+ - fastjsonschema==2.18.0
+ - filterpy==1.4.5
+ - fonttools==4.42.1
+ - fqdn==1.5.1
+ - frozenlist==1.4.1
+ - fsspec==2023.9.2
+ - ftfy==6.1.1
+ - gdown==4.4.0
+ - gitdb==4.0.10
+ - gitpython==3.1.31
+ - google-auth==2.17.3
+ - google-auth-oauthlib==1.0.0
+ - grpcio==1.53.0
+ - h11==0.14.0
+ - httpcore==1.0.4
+ - httpx==0.27.0
+ - huggingface-hub==0.20.3
+ - imageio==2.27.0
+ - imageio-ffmpeg==0.4.8
+ - importlib-metadata==6.4.1
+ - importlib-resources==6.0.1
+ - ipdb==0.13.13
+ - ipykernel==6.25.1
+ - ipython==8.13.2
+ - ipywidgets==8.1.2
+ - isoduration==20.11.0
+ - jax==0.4.14
+ - jaxlib==0.4.14+cuda11.cudnn86
+ - jedi==0.18.2
+ - joblib==1.2.0
+ - json5==0.9.14
+ - jsonpointer==2.4
+ - jsonschema==4.19.0
+ - jsonschema-specifications==2023.7.1
+ - jupyter==1.0.0
+ - jupyter-client==8.3.0
+ - jupyter-console==6.6.3
+ - jupyter-core==5.3.1
+ - jupyter-events==0.7.0
+ - jupyter-lsp==2.2.0
+ - jupyter-server==2.7.2
+ - jupyter-server-terminals==0.4.4
+ - jupyterlab==4.1.2
+ - jupyterlab-pygments==0.2.2
+ - jupyterlab-server==2.24.0
+ - jupyterlab-widgets==3.0.10
+ - kiui==0.2.7
+ - kiwisolver==1.4.4
+ - kornia==0.6.11
+ - lazy-loader==0.2
+ - lightning-utilities==0.11.2
+ - llvmlite==0.40.1
+ - lmdb==1.4.1
+ - lxml==4.9.2
+ - lz4==4.3.3
+ - markdown==3.4.3
+ - matplotlib==3.7.2
+ - matplotlib-inline==0.1.6
+ - mistune==3.0.1
+ - ml-dtypes==0.2.0
+ - mrcfile==1.5.0
+ - multidict==6.0.5
+ - nbclient==0.8.0
+ - nbconvert==7.7.4
+ - nbformat==5.9.2
+ - nest-asyncio==1.5.7
+ - ninja==1.11.1
+ - notebook==7.1.1
+ - notebook-shim==0.2.3
+ - numba==0.57.1
+ - nvidia-cublas-cu11==11.11.3.6
+ - nvidia-cuda-cupti-cu11==11.8.87
+ - nvidia-cuda-nvcc-cu11==11.8.89
+ - nvidia-cuda-nvrtc-cu11==11.8.89
+ - nvidia-cuda-runtime-cu11==11.8.89
+ - nvidia-cudnn-cu11==8.9.4.25
+ - nvidia-cufft-cu11==10.9.0.58
+ - nvidia-cusolver-cu11==11.4.1.48
+ - nvidia-cusparse-cu11==11.7.5.86
+ - oauthlib==3.2.2
+ - objprint==0.2.3
+ - omegaconf==2.3.0
+ - open-clip-torch==2.24.0
+ - opencv-python==4.7.0.72
+ - opt-einsum==3.3.0
+ - overrides==7.4.0
+ - packaging==23.1
+ - pandocfilters==1.5.0
+ - parso==0.8.3
+ - pathtools==0.1.2
+ - pexpect==4.8.0
+ - pickleshare==0.7.5
+ - platformdirs==3.10.0
+ - plyfile==1.0.3
+ - point-cloud-utils==0.30.4
+ - prometheus-client==0.17.1
+ - prompt-toolkit==3.0.38
+ - protobuf==4.22.3
+ - psutil==5.9.4
+ - ptflops==0.7
+ - ptyprocess==0.7.0
+ - pure-eval==0.2.2
+ - pyasn1==0.4.8
+ - pyasn1-modules==0.2.8
+ - pycryptodomex==3.17
+ - pygments==2.15.1
+ - pymcubes==0.1.4
+ - pyparsing==3.0.9
+ - python-dateutil==2.8.2
+ - python-json-logger==2.0.7
+ - pytorch-fid==0.3.0
+ - pytorch-lightning==2.2.1
+ - pywavelets==1.4.1
+ - pyyaml==6.0
+ - pyzmq==25.1.1
+ - qtconsole==5.5.1
+ - qtpy==2.4.1
+ - referencing==0.30.2
+ - regex==2023.3.23
+ - requests-oauthlib==1.3.1
+ - rfc3339-validator==0.1.4
+ - rfc3986-validator==0.1.1
+ - rpds-py==0.9.2
+ - rsa==4.9
+ - safetensors==0.4.0
+ - scikit-image==0.20.0
+ - scikit-learn==1.2.2
+ - scikit-video==1.1.11
+ - scipy==1.9.1
+ - send2trash==1.8.2
+ - sentencepiece==0.2.0
+ - sentry-sdk==1.24.0
+ - setproctitle==1.3.2
+ - smmap==5.0.0
+ - sniffio==1.3.0
+ - soupsieve==2.4.1
+ - stack-data==0.6.2
+ - tensorboard==2.12.2
+ - tensorboard-data-server==0.7.0
+ - tensorboard-plugin-wit==1.8.1
+ - terminado==0.17.1
+ - thop==0.1.1-2209072238
+ - threadpoolctl==3.1.0
+ - tifffile==2023.4.12
+ - timm==0.6.13
+ - tinycss2==1.2.1
+ - tokenizers==0.15.1
+ - tomli==2.0.1
+ - torchdiffeq==0.2.3
+ - torchmetrics==1.3.2
+ - torchtyping==0.1.4
+ - tornado==6.3.3
+ - tqdm==4.65.0
+ - traitlets==5.9.0
+ - transformers==4.37.1
+ - trimesh==3.21.7
+ - typeguard==4.2.1
+ - typing-extensions==4.10.0
+ - uri-template==1.3.0
+ - varname==0.13.0
+ - vision-aided-loss==0.1.0
+ - wandb==0.15.3
+ - wcwidth==0.2.6
+ - webcolors==1.13
+ - webdataset==0.2.86
+ - webencodings==0.5.1
+ - websocket-client==1.6.2
+ - werkzeug==2.2.3
+ - widgetsnbextension==4.0.10
+ - yapf==0.33.0
+ - yarl==1.9.4
+ - zipp==3.15.0
diff --git a/evaluations/README.md b/evaluations/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ad0ab6c0b3982ad60950df7ffa9af5662d31b2b
--- /dev/null
+++ b/evaluations/README.md
@@ -0,0 +1,72 @@
+# Evaluations
+
+To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.
+
+# Download batches
+
+We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.
+
+Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
+
+Here are links to download all of the sample and reference batches:
+
+ * LSUN
+ * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
+ * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
+ * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
+ * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
+ * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
+ * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
+ * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
+ * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
+ * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
+ * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)
+
+ * ImageNet
+ * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
+ * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
+ * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
+ * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
+ * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
+ * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
+ * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
+ * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
+ * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
+ * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
+ * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
+ * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
+ * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
+ * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
+ * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
+ * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
+ * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
+ * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
+ * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
+
+# Run evaluations
+
+First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the refernce batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
+
+Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
+
+The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
+
+```
+$ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
+...
+computing reference batch activations...
+computing/reading reference batch statistics...
+computing sample batch activations...
+computing/reading sample batch statistics...
+Computing evaluations...
+Inception Score: 215.8370361328125
+FID: 3.9425574129223264
+sFID: 6.140433703346162
+Precision: 0.8265
+Recall: 0.5309
+```
diff --git a/evaluations/evaluator.py b/evaluations/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9590855d564dc94b9b779027a7eae3e3659dd215
--- /dev/null
+++ b/evaluations/evaluator.py
@@ -0,0 +1,653 @@
+import argparse
+import io
+import os
+import random
+import warnings
+import zipfile
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from functools import partial
+from multiprocessing import cpu_count
+from multiprocessing.pool import ThreadPool
+from typing import Iterable, Optional, Tuple
+
+import numpy as np
+import requests
+import tensorflow.compat.v1 as tf
+from scipy import linalg
+from tqdm.auto import tqdm
+
+INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
+INCEPTION_V3_PATH = "classify_image_graph_def.pb"
+
+FID_POOL_NAME = "pool_3:0"
+FID_SPATIAL_NAME = "mixed_6/conv:0"
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("ref_batch", help="path to reference batch npz file")
+ parser.add_argument("sample_batch", help="path to sample batch npz file")
+ args = parser.parse_args()
+
+ config = tf.ConfigProto(
+ allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
+ )
+ config.gpu_options.allow_growth = True
+ evaluator = Evaluator(tf.Session(config=config))
+
+ print("warming up TensorFlow...")
+ # This will cause TF to print a bunch of verbose stuff now rather
+ # than after the next print(), to help prevent confusion.
+ evaluator.warmup()
+
+ print("computing reference batch activations...")
+ ref_acts = evaluator.read_activations(args.ref_batch)
+ print("computing/reading reference batch statistics...")
+ ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
+
+ print("computing sample batch activations...")
+ sample_acts = evaluator.read_activations(args.sample_batch)
+ print("computing/reading sample batch statistics...")
+ sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
+
+ print("Computing evaluations...")
+ print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
+ print("FID:", sample_stats.frechet_distance(ref_stats))
+ print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
+ prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
+ print("Precision:", prec)
+ print("Recall:", recall)
+
+
+class InvalidFIDException(Exception):
+ pass
+
+
+class FIDStatistics:
+ def __init__(self, mu: np.ndarray, sigma: np.ndarray):
+ self.mu = mu
+ self.sigma = sigma
+
+ def frechet_distance(self, other, eps=1e-6):
+ """
+ Compute the Frechet distance between two sets of statistics.
+ """
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
+ mu1, sigma1 = self.mu, self.sigma
+ mu2, sigma2 = other.mu, other.sigma
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert (
+ mu1.shape == mu2.shape
+ ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
+ assert (
+ sigma1.shape == sigma2.shape
+ ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
+
+ diff = mu1 - mu2
+
+ # product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = (
+ "fid calculation produces singular product; adding %s to diagonal of cov estimates"
+ % eps
+ )
+ warnings.warn(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError("Imaginary component {}".format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+class Evaluator:
+ def __init__(
+ self,
+ session,
+ batch_size=64,
+ softmax_batch_size=512,
+ ):
+ self.sess = session
+ self.batch_size = batch_size
+ self.softmax_batch_size = softmax_batch_size
+ self.manifold_estimator = ManifoldEstimator(session)
+ with self.sess.graph.as_default():
+ self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
+ self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
+ self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
+ self.softmax = _create_softmax_graph(self.softmax_input)
+
+ def warmup(self):
+ self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
+
+ def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
+ with open_npz_array(npz_path, "arr_0") as reader:
+ return self.compute_activations(reader.read_batches(self.batch_size))
+
+ def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Compute image features for downstream evals.
+
+ :param batches: a iterator over NHWC numpy arrays in [0, 255].
+ :return: a tuple of numpy arrays of shape [N x X], where X is a feature
+ dimension. The tuple is (pool_3, spatial).
+ """
+ preds = []
+ spatial_preds = []
+ for batch in tqdm(batches):
+ batch = batch.astype(np.float32)
+ pred, spatial_pred = self.sess.run(
+ [self.pool_features, self.spatial_features], {self.image_input: batch}
+ )
+ preds.append(pred.reshape([pred.shape[0], -1]))
+ spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
+ return (
+ np.concatenate(preds, axis=0),
+ np.concatenate(spatial_preds, axis=0),
+ )
+
+ def read_statistics(
+ self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
+ ) -> Tuple[FIDStatistics, FIDStatistics]:
+ obj = np.load(npz_path)
+ if "mu" in list(obj.keys()):
+ return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
+ obj["mu_s"], obj["sigma_s"]
+ )
+ return tuple(self.compute_statistics(x) for x in activations)
+
+ def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
+ mu = np.mean(activations, axis=0)
+ sigma = np.cov(activations, rowvar=False)
+ return FIDStatistics(mu, sigma)
+
+ def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
+ softmax_out = []
+ for i in range(0, len(activations), self.softmax_batch_size):
+ acts = activations[i : i + self.softmax_batch_size]
+ softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
+ preds = np.concatenate(softmax_out, axis=0)
+ # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
+ scores = []
+ for i in range(0, len(preds), split_size):
+ part = preds[i : i + split_size]
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
+ kl = np.mean(np.sum(kl, 1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores))
+
+ def compute_prec_recall(
+ self, activations_ref: np.ndarray, activations_sample: np.ndarray
+ ) -> Tuple[float, float]:
+ radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
+ radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
+ pr = self.manifold_estimator.evaluate_pr(
+ activations_ref, radii_1, activations_sample, radii_2
+ )
+ return (float(pr[0][0]), float(pr[1][0]))
+
+
+class ManifoldEstimator:
+ """
+ A helper for comparing manifolds of feature vectors.
+
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
+ """
+
+ def __init__(
+ self,
+ session,
+ row_batch_size=10000,
+ col_batch_size=10000,
+ nhood_sizes=(3,),
+ clamp_to_percentile=None,
+ eps=1e-5,
+ ):
+ """
+ Estimate the manifold of given feature vectors.
+
+ :param session: the TensorFlow session.
+ :param row_batch_size: row batch size to compute pairwise distances
+ (parameter to trade-off between memory usage and performance).
+ :param col_batch_size: column batch size to compute pairwise distances.
+ :param nhood_sizes: number of neighbors used to estimate the manifold.
+ :param clamp_to_percentile: prune hyperspheres that have radius larger than
+ the given percentile.
+ :param eps: small number for numerical stability.
+ """
+ self.distance_block = DistanceBlock(session)
+ self.row_batch_size = row_batch_size
+ self.col_batch_size = col_batch_size
+ self.nhood_sizes = nhood_sizes
+ self.num_nhoods = len(nhood_sizes)
+ self.clamp_to_percentile = clamp_to_percentile
+ self.eps = eps
+
+ def warmup(self):
+ feats, radii = (
+ np.zeros([1, 2048], dtype=np.float32),
+ np.zeros([1, 1], dtype=np.float32),
+ )
+ self.evaluate_pr(feats, radii, feats, radii)
+
+ def manifold_radii(self, features: np.ndarray) -> np.ndarray:
+ num_images = len(features)
+
+ # Estimate manifold of features by calculating distances to k-NN of each sample.
+ radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
+ distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
+ seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
+
+ for begin1 in range(0, num_images, self.row_batch_size):
+ end1 = min(begin1 + self.row_batch_size, num_images)
+ row_batch = features[begin1:end1]
+
+ for begin2 in range(0, num_images, self.col_batch_size):
+ end2 = min(begin2 + self.col_batch_size, num_images)
+ col_batch = features[begin2:end2]
+
+ # Compute distances between batches.
+ distance_batch[
+ 0 : end1 - begin1, begin2:end2
+ ] = self.distance_block.pairwise_distances(row_batch, col_batch)
+
+ # Find the k-nearest neighbor from the current batch.
+ radii[begin1:end1, :] = np.concatenate(
+ [
+ x[:, self.nhood_sizes]
+ for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
+ ],
+ axis=0,
+ )
+
+ if self.clamp_to_percentile is not None:
+ max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
+ radii[radii > max_distances] = 0
+ return radii
+
+ def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
+ """
+ Evaluate if new feature vectors are at the manifold.
+ """
+ num_eval_images = eval_features.shape[0]
+ num_ref_images = radii.shape[0]
+ distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
+ batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
+ max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
+ nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
+
+ for begin1 in range(0, num_eval_images, self.row_batch_size):
+ end1 = min(begin1 + self.row_batch_size, num_eval_images)
+ feature_batch = eval_features[begin1:end1]
+
+ for begin2 in range(0, num_ref_images, self.col_batch_size):
+ end2 = min(begin2 + self.col_batch_size, num_ref_images)
+ ref_batch = features[begin2:end2]
+
+ distance_batch[
+ 0 : end1 - begin1, begin2:end2
+ ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
+
+ # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
+ # If a feature vector is inside a hypersphere of some reference sample, then
+ # the new sample lies at the estimated manifold.
+ # The radii of the hyperspheres are determined from distances of neighborhood size k.
+ samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
+ batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
+
+ max_realism_score[begin1:end1] = np.max(
+ radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
+ )
+ nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
+
+ return {
+ "fraction": float(np.mean(batch_predictions)),
+ "batch_predictions": batch_predictions,
+ "max_realisim_score": max_realism_score,
+ "nearest_indices": nearest_indices,
+ }
+
+ def evaluate_pr(
+ self,
+ features_1: np.ndarray,
+ radii_1: np.ndarray,
+ features_2: np.ndarray,
+ radii_2: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Evaluate precision and recall efficiently.
+
+ :param features_1: [N1 x D] feature vectors for reference batch.
+ :param radii_1: [N1 x K1] radii for reference vectors.
+ :param features_2: [N2 x D] feature vectors for the other batch.
+ :param radii_2: [N x K2] radii for other vectors.
+ :return: a tuple of arrays for (precision, recall):
+ - precision: an np.ndarray of length K1
+ - recall: an np.ndarray of length K2
+ """
+ features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
+ features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
+ for begin_1 in range(0, len(features_1), self.row_batch_size):
+ end_1 = begin_1 + self.row_batch_size
+ batch_1 = features_1[begin_1:end_1]
+ for begin_2 in range(0, len(features_2), self.col_batch_size):
+ end_2 = begin_2 + self.col_batch_size
+ batch_2 = features_2[begin_2:end_2]
+ batch_1_in, batch_2_in = self.distance_block.less_thans(
+ batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
+ )
+ features_1_status[begin_1:end_1] |= batch_1_in
+ features_2_status[begin_2:end_2] |= batch_2_in
+ return (
+ np.mean(features_2_status.astype(np.float64), axis=0),
+ np.mean(features_1_status.astype(np.float64), axis=0),
+ )
+
+
+class DistanceBlock:
+ """
+ Calculate pairwise distances between vectors.
+
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
+ """
+
+ def __init__(self, session):
+ self.session = session
+
+ # Initialize TF graph to calculate pairwise distances.
+ with session.graph.as_default():
+ self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
+ self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
+ distance_block_16 = _batch_pairwise_distances(
+ tf.cast(self._features_batch1, tf.float16),
+ tf.cast(self._features_batch2, tf.float16),
+ )
+ self.distance_block = tf.cond(
+ tf.reduce_all(tf.math.is_finite(distance_block_16)),
+ lambda: tf.cast(distance_block_16, tf.float32),
+ lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
+ )
+
+ # Extra logic for less thans.
+ self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
+ self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
+ dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
+ self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
+ self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
+
+ def pairwise_distances(self, U, V):
+ """
+ Evaluate pairwise distances between two batches of feature vectors.
+ """
+ return self.session.run(
+ self.distance_block,
+ feed_dict={self._features_batch1: U, self._features_batch2: V},
+ )
+
+ def less_thans(self, batch_1, radii_1, batch_2, radii_2):
+ return self.session.run(
+ [self._batch_1_in, self._batch_2_in],
+ feed_dict={
+ self._features_batch1: batch_1,
+ self._features_batch2: batch_2,
+ self._radii1: radii_1,
+ self._radii2: radii_2,
+ },
+ )
+
+
+def _batch_pairwise_distances(U, V):
+ """
+ Compute pairwise distances between two batches of feature vectors.
+ """
+ with tf.variable_scope("pairwise_dist_block"):
+ # Squared norms of each row in U and V.
+ norm_u = tf.reduce_sum(tf.square(U), 1)
+ norm_v = tf.reduce_sum(tf.square(V), 1)
+
+ # norm_u as a column and norm_v as a row vectors.
+ norm_u = tf.reshape(norm_u, [-1, 1])
+ norm_v = tf.reshape(norm_v, [1, -1])
+
+ # Pairwise squared Euclidean distances.
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
+
+ return D
+
+
+class NpzArrayReader(ABC):
+ @abstractmethod
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+ pass
+
+ @abstractmethod
+ def remaining(self) -> int:
+ pass
+
+ def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
+ def gen_fn():
+ while True:
+ batch = self.read_batch(batch_size)
+ if batch is None:
+ break
+ yield batch
+
+ rem = self.remaining()
+ num_batches = rem // batch_size + int(rem % batch_size != 0)
+ return BatchIterator(gen_fn, num_batches)
+
+
+class BatchIterator:
+ def __init__(self, gen_fn, length):
+ self.gen_fn = gen_fn
+ self.length = length
+
+ def __len__(self):
+ return self.length
+
+ def __iter__(self):
+ return self.gen_fn()
+
+
+class StreamingNpzArrayReader(NpzArrayReader):
+ def __init__(self, arr_f, shape, dtype):
+ self.arr_f = arr_f
+ self.shape = shape
+ self.dtype = dtype
+ self.idx = 0
+
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+ if self.idx >= self.shape[0]:
+ return None
+
+ bs = min(batch_size, self.shape[0] - self.idx)
+ self.idx += bs
+
+ if self.dtype.itemsize == 0:
+ return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
+
+ read_count = bs * np.prod(self.shape[1:])
+ read_size = int(read_count * self.dtype.itemsize)
+ data = _read_bytes(self.arr_f, read_size, "array data")
+ return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
+
+ def remaining(self) -> int:
+ return max(0, self.shape[0] - self.idx)
+
+
+class MemoryNpzArrayReader(NpzArrayReader):
+ def __init__(self, arr):
+ self.arr = arr
+ self.idx = 0
+
+ @classmethod
+ def load(cls, path: str, arr_name: str):
+ with open(path, "rb") as f:
+ arr = np.load(f)[arr_name]
+ return cls(arr)
+
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+ if self.idx >= self.arr.shape[0]:
+ return None
+
+ res = self.arr[self.idx : self.idx + batch_size]
+ self.idx += batch_size
+ return res
+
+ def remaining(self) -> int:
+ return max(0, self.arr.shape[0] - self.idx)
+
+
+@contextmanager
+def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
+ with _open_npy_file(path, arr_name) as arr_f:
+ version = np.lib.format.read_magic(arr_f)
+ if version == (1, 0):
+ header = np.lib.format.read_array_header_1_0(arr_f)
+ elif version == (2, 0):
+ header = np.lib.format.read_array_header_2_0(arr_f)
+ else:
+ yield MemoryNpzArrayReader.load(path, arr_name)
+ return
+ shape, fortran, dtype = header
+ if fortran or dtype.hasobject:
+ yield MemoryNpzArrayReader.load(path, arr_name)
+ else:
+ yield StreamingNpzArrayReader(arr_f, shape, dtype)
+
+
+def _read_bytes(fp, size, error_template="ran out of data"):
+ """
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
+
+ Read from file-like object until size bytes are read.
+ Raises ValueError if not EOF is encountered before size bytes are read.
+ Non-blocking objects only supported if they derive from io objects.
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
+ requested.
+ """
+ data = bytes()
+ while True:
+ # io files (default in python3) return None or raise on
+ # would-block, python2 file will truncate, probably nothing can be
+ # done about that. note that regular files can't be non-blocking
+ try:
+ r = fp.read(size - len(data))
+ data += r
+ if len(r) == 0 or len(data) == size:
+ break
+ except io.BlockingIOError:
+ pass
+ if len(data) != size:
+ msg = "EOF: reading %s, expected %d bytes got %d"
+ raise ValueError(msg % (error_template, size, len(data)))
+ else:
+ return data
+
+
+@contextmanager
+def _open_npy_file(path: str, arr_name: str):
+ with open(path, "rb") as f:
+ with zipfile.ZipFile(f, "r") as zip_f:
+ if f"{arr_name}.npy" not in zip_f.namelist():
+ raise ValueError(f"missing {arr_name} in npz file")
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
+ yield arr_f
+
+
+def _download_inception_model():
+ if os.path.exists(INCEPTION_V3_PATH):
+ return
+ print("downloading InceptionV3 model...")
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
+ r.raise_for_status()
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
+ with open(tmp_path, "wb") as f:
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
+ f.write(chunk)
+ os.rename(tmp_path, INCEPTION_V3_PATH)
+
+
+def _create_feature_graph(input_batch):
+ _download_inception_model()
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
+ with open(INCEPTION_V3_PATH, "rb") as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ pool3, spatial = tf.import_graph_def(
+ graph_def,
+ input_map={f"ExpandDims:0": input_batch},
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
+ name=prefix,
+ )
+ _update_shapes(pool3)
+ spatial = spatial[..., :7]
+ return pool3, spatial
+
+
+def _create_softmax_graph(input_batch):
+ _download_inception_model()
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
+ with open(INCEPTION_V3_PATH, "rb") as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ (matmul,) = tf.import_graph_def(
+ graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
+ )
+ w = matmul.inputs[1]
+ logits = tf.matmul(input_batch, w)
+ return tf.nn.softmax(logits)
+
+
+def _update_shapes(pool3):
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
+ ops = pool3.graph.get_operations()
+ for op in ops:
+ for o in op.outputs:
+ shape = o.get_shape()
+ if shape._dims is not None: # pylint: disable=protected-access
+ # shape = [s.value for s in shape] TF 1.x
+ shape = [s for s in shape] # TF 2.x
+ new_shape = []
+ for j, s in enumerate(shape):
+ if s == 1 and j == 0:
+ new_shape.append(None)
+ else:
+ new_shape.append(s)
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
+ return pool3
+
+
+def _numpy_partition(arr, kth, **kwargs):
+ num_workers = min(cpu_count(), len(arr))
+ chunk_size = len(arr) // num_workers
+ extra = len(arr) % num_workers
+
+ start_idx = 0
+ batches = []
+ for i in range(num_workers):
+ size = chunk_size + (1 if i < extra else 0)
+ batches.append(arr[start_idx : start_idx + size])
+ start_idx += size
+
+ with ThreadPool(num_workers) as pool:
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/evaluations/requirements.txt b/evaluations/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fc6df305a4169b13bcfab5e238e4ff1c97b6baaa
--- /dev/null
+++ b/evaluations/requirements.txt
@@ -0,0 +1,4 @@
+tensorflow-gpu>=2.0
+scipy
+requests
+tqdm
\ No newline at end of file
diff --git a/guided_diffusion/__init__.py b/guided_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d018714d642ef967ed0152f418538c4019b2340
--- /dev/null
+++ b/guided_diffusion/__init__.py
@@ -0,0 +1,4 @@
+"""
+Codebase for "Improved Denoising Diffusion Probabilistic Models".
+Also merged continuous_diffusion.py from LSGM: https://github.com/NVlabs/LSGM
+"""
diff --git a/guided_diffusion/__pycache__/__init__.cpython-39.pyc b/guided_diffusion/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87e8e7975a10f53742f77afca6d8a1c2188df71b
Binary files /dev/null and b/guided_diffusion/__pycache__/__init__.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/continuous_diffusion.cpython-39.pyc b/guided_diffusion/__pycache__/continuous_diffusion.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..162ddc16cd3103a58e96c02fe7de8a0216849e81
Binary files /dev/null and b/guided_diffusion/__pycache__/continuous_diffusion.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/continuous_diffusion_utils.cpython-39.pyc b/guided_diffusion/__pycache__/continuous_diffusion_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..860517b1475b1ee3e4af6555fb25a846e683ae81
Binary files /dev/null and b/guided_diffusion/__pycache__/continuous_diffusion_utils.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/continuous_distributions.cpython-39.pyc b/guided_diffusion/__pycache__/continuous_distributions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01fbba49f25c70cdab37e993e8f22fbf12bba13a
Binary files /dev/null and b/guided_diffusion/__pycache__/continuous_distributions.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/dist_util.cpython-39.pyc b/guided_diffusion/__pycache__/dist_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b396955204f0f286dbc42678989d27df0264f11
Binary files /dev/null and b/guided_diffusion/__pycache__/dist_util.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/fp16_util.cpython-39.pyc b/guided_diffusion/__pycache__/fp16_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..593ab6ba433f262776869578d2c9f321213fe262
Binary files /dev/null and b/guided_diffusion/__pycache__/fp16_util.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc b/guided_diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccac4aacbeebf69dc74df586a2b3422f3f460ee6
Binary files /dev/null and b/guided_diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/logger.cpython-39.pyc b/guided_diffusion/__pycache__/logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52fbe63cac5820347e14d1d3bc3273fdbd26adf7
Binary files /dev/null and b/guided_diffusion/__pycache__/logger.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/losses.cpython-39.pyc b/guided_diffusion/__pycache__/losses.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dea9683dca3ac9b278c004a58865d3ad148d3b7e
Binary files /dev/null and b/guided_diffusion/__pycache__/losses.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/nn.cpython-39.pyc b/guided_diffusion/__pycache__/nn.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b64b6849025d3e240b8347c828114abeea692ee0
Binary files /dev/null and b/guided_diffusion/__pycache__/nn.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/resample.cpython-39.pyc b/guided_diffusion/__pycache__/resample.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1fc9dd65dd0681e1a0d1b1100edf81b8751523b
Binary files /dev/null and b/guided_diffusion/__pycache__/resample.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/respace.cpython-39.pyc b/guided_diffusion/__pycache__/respace.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc16190c16b5ce28568ad611e1a7ba3a78f89152
Binary files /dev/null and b/guided_diffusion/__pycache__/respace.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/script_util.cpython-39.pyc b/guided_diffusion/__pycache__/script_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15145cb49a7565ce7069e16e49d65316bdd06567
Binary files /dev/null and b/guided_diffusion/__pycache__/script_util.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/train_util.cpython-39.pyc b/guided_diffusion/__pycache__/train_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ab3f90f8e4916a49127c635c6d64c08b53818bc
Binary files /dev/null and b/guided_diffusion/__pycache__/train_util.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/unet.cpython-39.pyc b/guided_diffusion/__pycache__/unet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba589b51889340afd3f730cd4c6aa837a5f59c97
Binary files /dev/null and b/guided_diffusion/__pycache__/unet.cpython-39.pyc differ
diff --git a/guided_diffusion/__pycache__/unet_old.cpython-39.pyc b/guided_diffusion/__pycache__/unet_old.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f0c134b406732a7b89cd674c1030b348db7635a
Binary files /dev/null and b/guided_diffusion/__pycache__/unet_old.cpython-39.pyc differ
diff --git a/guided_diffusion/continuous_diffusion.py b/guided_diffusion/continuous_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..775733d8f9bb47c909368c17ace82d415c98ac39
--- /dev/null
+++ b/guided_diffusion/continuous_diffusion.py
@@ -0,0 +1,795 @@
+# ---------------------------------------------------------------
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# This work is licensed under the NVIDIA Source Code License
+# for LSGM. To view a copy of this license, see the LICENSE file.
+# ---------------------------------------------------------------
+
+from pdb import set_trace as st
+from abc import ABC, abstractmethod
+import numpy as np
+import torch
+import gc
+from .continuous_distributions import log_p_standard_normal, log_p_var_normal
+from .continuous_diffusion_utils import trace_df_dx_hutchinson, sample_gaussian_like, sample_rademacher_like, get_mixed_prediction
+from torchdiffeq import odeint
+from torch.cuda.amp import autocast
+from timeit import default_timer as timer
+
+from guided_diffusion import dist_util, logger
+
+
+def make_diffusion(args):
+ """ simple diffusion factory function to return diffusion instances. Only use this to create continuous diffusions """
+ if args.sde_sde_type == 'geometric_sde':
+ return DiffusionGeometric(args)
+ elif args.sde_sde_type == 'vpsde':
+ return DiffusionVPSDE(args)
+ elif args.sde_sde_type == 'sub_vpsde':
+ return DiffusionSubVPSDE(args)
+ elif args.sde_sde_type == 'vesde':
+ return DiffusionVESDE(args)
+ else:
+ raise ValueError("Unrecognized sde type: {}".format(args.sde_sde_type))
+
+
+class DiffusionBase(ABC):
+ """
+ Abstract base class for all diffusion implementations.
+ """
+
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+ self.sigma2_0 = args.sde_sigma2_0
+ self.sde_type = args.sde_sde_type
+
+ @abstractmethod
+ def f(self, t):
+ """ returns the drift coefficient at time t: f(t) """
+ pass
+
+ @abstractmethod
+ def g2(self, t):
+ """ returns the squared diffusion coefficient at time t: g^2(t) """
+ pass
+
+ @abstractmethod
+ def var(self, t):
+ """ returns variance at time t, \sigma_t^2"""
+ pass
+
+ @abstractmethod
+ def e2int_f(self, t):
+ """ returns e^{\int_0^t f(s) ds} which corresponds to the coefficient of mean at time t. """
+ pass
+
+ @abstractmethod
+ def inv_var(self, var):
+ """ inverse of the variance function at input variance var. """
+ pass
+
+ @abstractmethod
+ def mixing_component(self, x_noisy, var_t, t, enabled):
+ """ returns mixing component which is the optimal denoising model assuming that q(z_0) is N(0, 1) """
+ pass
+
+ def sample_q(self, x_init, noise, var_t, m_t):
+ """ returns a sample from diffusion process at time t """
+ return m_t * x_init + torch.sqrt(var_t) * noise
+
+ def log_snr(self, m_t, var_t):
+ return torch.log((torch.square(m_t) / var_t))
+
+ def _predict_x0_from_eps(self, z, eps, logsnr):
+ """eps = (z - alpha * x0) / sigma
+ """
+ return torch.sqrt(1 + torch.exp(-logsnr)) * (
+ z - eps * torch.rsqrt(1 + torch.exp(logsnr)))
+
+ def _predict_eps_from_x0(self, z, x0, logsnr):
+ """x = (z - sigma * eps) / alpha
+ """
+ return torch.sqrt(1 + torch.exp(logsnr)) * (
+ z - x0 * torch.rsqrt(1 + torch.exp(-logsnr)))
+
+ def _predict_eps_from_z_and_v(self, v_t, var_t, z, m_t):
+ # TODO, use logsnr here?
+ return torch.sqrt(var_t) * z + m_t * v_t
+
+ def _predict_x0_from_z_and_v(self, v_t, var_t, z, m_t):
+ return torch.sqrt(var_t) * v_t + m_t * z
+
+ def cross_entropy_const(self, ode_eps):
+ """ returns cross entropy factor with variance according to ode integration cutoff ode_eps """
+ # _, c, h, w = x_init.shape
+ return 0.5 * (1.0 + torch.log(2.0 * np.pi * self.var(
+ t=torch.tensor(ode_eps, device=dist_util.dev()))))
+
+ def compute_ode_nll(self, dae, eps, ode_eps, ode_solver_tol,
+ enable_autocast, no_autograd, num_samples, report_std):
+ """ calculates NLL based on ODE framework, assuming integration cutoff ode_eps """
+ # ODE solver starts consuming the CPU memory without this on large models
+ # https://github.com/scipy/scipy/issues/10070
+ gc.collect()
+
+ dae.eval()
+
+ def ode_func(t, state):
+ """ the ode function (including log probability integration for NLL calculation) """
+ global nfe_counter
+ nfe_counter = nfe_counter + 1
+
+ x = state[0].detach()
+ x.requires_grad_(True)
+ noise = sample_gaussian_like(
+ x) # could also use rademacher noise (sample_rademacher_like)
+ with torch.set_grad_enabled(True):
+ with autocast(enabled=enable_autocast):
+ variance = self.var(t=t)
+ mixing_component = self.mixing_component(
+ x_noisy=x,
+ var_t=variance,
+ t=t,
+ enabled=dae.mixed_prediction)
+ pred_params = dae(x=x, t=t)
+ params = get_mixed_prediction(dae.mixed_prediction,
+ pred_params,
+ dae.mixing_logit,
+ mixing_component)
+ dx_dt = self.f(t=t) * x + 0.5 * self.g2(
+ t=t) * params / torch.sqrt(variance)
+
+ with autocast(enabled=False):
+ dlogp_x_dt = -trace_df_dx_hutchinson(
+ dx_dt, x, noise, no_autograd).view(x.shape[0], 1)
+
+ return (dx_dt, dlogp_x_dt)
+
+ # NFE counter
+ global nfe_counter
+
+ nll_all, nfe_all = [], []
+ for i in range(num_samples):
+ # integrated log probability
+ logp_diff_t0 = torch.zeros(eps.shape[0], 1, device=dist_util.dev())
+
+ nfe_counter = 0
+
+ # solve the ODE
+ x_t, logp_diff_t = odeint(
+ ode_func,
+ (eps, logp_diff_t0),
+ torch.tensor([ode_eps, 1.0], device=dist_util.dev()),
+ atol=ode_solver_tol,
+ rtol=ode_solver_tol,
+ method="scipy_solver",
+ options={"solver": 'RK45'},
+ )
+ # last output values
+ x_t0, logp_diff_t0 = x_t[-1], logp_diff_t[-1]
+
+ # prior
+ if self.sde_type == 'vesde':
+ logp_prior = torch.sum(log_p_var_normal(x_t0,
+ var=self.sigma2_max),
+ dim=[1, 2, 3])
+ else:
+ logp_prior = torch.sum(log_p_standard_normal(x_t0),
+ dim=[1, 2, 3])
+
+ log_likelihood = logp_prior - logp_diff_t0.view(-1)
+
+ nll_all.append(-log_likelihood)
+ nfe_all.append(nfe_counter)
+
+ nfe_mean = np.mean(nfe_all)
+ nll_all = torch.stack(nll_all, dim=1)
+ nll_mean = torch.mean(nll_all, dim=1)
+ if num_samples > 1 and report_std:
+ nll_stddev = torch.std(nll_all, dim=1)
+ nll_stddev_batch = torch.mean(nll_stddev)
+ nll_stderror_batch = nll_stddev_batch / np.sqrt(num_samples)
+ else:
+ nll_stddev_batch = None
+ nll_stderror_batch = None
+ return nll_mean, nfe_mean, nll_stddev_batch, nll_stderror_batch
+
+ def sample_model_ode(self,
+ dae,
+ num_samples,
+ shape,
+ ode_eps,
+ ode_solver_tol,
+ enable_autocast,
+ temp,
+ noise=None):
+ """ generates samples using the ODE framework, assuming integration cutoff ode_eps """
+ # ODE solver starts consuming the CPU memory without this on large models
+ # https://github.com/scipy/scipy/issues/10070
+ gc.collect()
+
+ dae.eval()
+
+ def ode_func(t, x):
+ """ the ode function (sampling only, no NLL stuff) """
+ global nfe_counter
+ nfe_counter = nfe_counter + 1
+ with autocast(enabled=enable_autocast):
+ variance = self.var(t=t)
+ mixing_component = self.mixing_component(
+ x_noisy=x,
+ var_t=variance,
+ t=t,
+ enabled=dae.mixed_prediction)
+ pred_params = dae(x=x, t=t)
+ params = get_mixed_prediction(dae.mixed_prediction,
+ pred_params, dae.mixing_logit,
+ mixing_component)
+ dx_dt = self.f(t=t) * x + 0.5 * self.g2(
+ t=t) * params / torch.sqrt(variance)
+
+ return dx_dt
+
+ # the initial noise
+ if noise is None:
+ noise = torch.randn(size=[num_samples] + shape,
+ device=dist_util.dev())
+
+ if self.sde_type == 'vesde':
+ noise_init = temp * noise * np.sqrt(self.sigma2_max)
+ else:
+ noise_init = temp * noise
+
+ # NFE counter
+ global nfe_counter
+ nfe_counter = 0
+
+ # solve the ODE
+ start = timer()
+ samples_out = odeint(
+ ode_func,
+ noise_init,
+ torch.tensor([1.0, ode_eps], device=dist_util.dev()),
+ atol=ode_solver_tol,
+ rtol=ode_solver_tol,
+ method="scipy_solver",
+ options={"solver": 'RK45'},
+ )
+ end = timer()
+ ode_solve_time = end - start
+
+ return samples_out[-1], nfe_counter, ode_solve_time
+
+ # def iw_quantities(self, size, time_eps, iw_sample_mode, iw_subvp_like_vp_sde):
+ def iw_quantities(self, iw_sample_mode, size=None):
+
+ args = self.args
+ time_eps, iw_subvp_like_vp_sde = args.sde_time_eps, args.iw_subvp_like_vp_sde
+ if size is None:
+ size = args.batch_size
+
+ if self.sde_type in ['geometric_sde', 'vpsde']:
+ return self._iw_quantities_vpsdelike(size, time_eps,
+ iw_sample_mode)
+ elif self.sde_type in ['sub_vpsde']:
+ return self._iw_quantities_subvpsdelike(size, time_eps,
+ iw_sample_mode,
+ iw_subvp_like_vp_sde)
+ elif self.sde_type in ['vesde']:
+ return self._iw_quantities_vesde(size, time_eps, iw_sample_mode)
+ else:
+ raise NotImplementedError
+
+ def _iw_quantities_vpsdelike(self, size, time_eps, iw_sample_mode):
+ """
+ For all SDEs where the underlying SDE is of the form dz = -0.5 * beta(t) * z * dt + sqrt{beta(t)} * dw, like
+ for the VPSDE.
+ """
+ rho = torch.rand(size=[size], device=dist_util.dev())
+
+ # In the following, obj_weight_t corresponds to the weight in front of the l2 loss for the given iw_sample_mode.
+ # obj_weight_t_ll corresponds to the weight that converts the weighting scheme in iw_sample_mode to likelihood
+ # weighting.
+
+ if iw_sample_mode == 'll_uniform':
+ # uniform t sampling - likelihood obj. for both q and p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'll_iw': # ! q-obj
+ # importance sampling for likelihood obj. - likelihood obj. for both q and p
+ ones = torch.ones_like(rho, device=dist_util.dev())
+ sigma2_1, sigma2_eps = self.var(ones), self.var(time_eps * ones)
+ log_sigma2_1, log_sigma2_eps = torch.log(sigma2_1), torch.log(
+ sigma2_eps)
+ var_t = torch.exp(rho * log_sigma2_1 +
+ (1 - rho) * log_sigma2_eps) # sigma square
+ t = self.inv_var(var_t)
+ m_t, g2_t = self.e2int_f(t), self.g2(t) # m_t is alpha_bar
+ obj_weight_t = obj_weight_t_ll = 0.5 * (
+ log_sigma2_1 - log_sigma2_eps) / (1.0 - var_t)
+
+ elif iw_sample_mode == 'drop_all_uniform':
+ # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = torch.ones(1, device=dist_util.dev())
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'drop_all_iw':
+ # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
+ assert self.sde_type == 'vpsde', 'Importance sampling for fully unweighted objective is currently only ' \
+ 'implemented for the regular VPSDE.'
+ t = torch.sqrt(1.0 / self.delta_beta_half) * torch.erfinv(
+ rho * self.const_norm_2 + self.const_erf) - self.beta_frac
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = self.const_norm / (1.0 - var_t)
+ obj_weight_t_ll = obj_weight_t * g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'drop_sigma2t_iw': # ! default mode for p
+ # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
+ ones = torch.ones_like(rho, device=dist_util.dev())
+ sigma2_1, sigma2_eps = self.var(ones), self.var(time_eps * ones)
+ var_t = rho * sigma2_1 + (1 - rho) * sigma2_eps # ! sigma square
+ t = self.inv_var(var_t)
+ m_t, g2_t = self.e2int_f(t), self.g2(t) # ! m_t: alpha_bar sqrt
+ obj_weight_t = 0.5 * (sigma2_1 - sigma2_eps) / (1.0 - var_t)
+ obj_weight_t_ll = obj_weight_t / var_t
+
+ elif iw_sample_mode == 'drop_sigma2t_uniform':
+ # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = g2_t / 2.0
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'rescale_iw':
+ # importance sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = 0.5 / (1.0 - var_t)
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ else:
+ raise ValueError(
+ "Unrecognized importance sampling type: {}".format(
+ iw_sample_mode))
+
+ return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t.view(-1, 1, 1, 1), \
+ obj_weight_t_ll.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)
+
+ def _iw_quantities_subvpsdelike(self, size, time_eps, iw_sample_mode,
+ iw_subvp_like_vp_sde):
+ """
+ For all SDEs where the underlying SDE is of the form
+ dz = -0.5 * beta(t) * z * dt + sqrt{beta(t) * (1 - exp[-2 * betaintegral])} * dw, like for the Sub-VPSDE.
+ When iw_subvp_like_vp_sde is True, then we define the importance sampling distributions based on an analogous
+ VPSDE, while stile using the Sub-VPSDE. The motivation is that deriving the correct importance sampling
+ distributions for the Sub-VPSDE itself is hard, but the importance sampling distributions from analogous VPSDEs
+ probably already significantly reduce the variance also for the Sub-VPSDE.
+ """
+ rho = torch.rand(size=[size], device=dist_util.dev())
+
+ # In the following, obj_weight_t corresponds to the weight in front of the l2 loss for the given iw_sample_mode.
+ # obj_weight_t_ll corresponds to the weight that converts the weighting scheme in iw_sample_mode to likelihood
+ # weighting.
+ if iw_sample_mode == 'll_uniform':
+ # uniform t sampling - likelihood obj. for both q and p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'll_iw':
+ if iw_subvp_like_vp_sde:
+ # importance sampling for vpsde likelihood obj. - sub-vpsde likelihood obj. for both q and p
+ ones = torch.ones_like(rho, device=dist_util.dev())
+ sigma2_1, sigma2_eps = self.var_vpsde(ones), self.var_vpsde(
+ time_eps * ones)
+ log_sigma2_1, log_sigma2_eps = torch.log(sigma2_1), torch.log(
+ sigma2_eps)
+ var_t_vpsde = torch.exp(rho * log_sigma2_1 +
+ (1 - rho) * log_sigma2_eps)
+ t = self.inv_var_vpsde(var_t_vpsde)
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = obj_weight_t_ll = g2_t / (2.0 * var_t) * \
+ (log_sigma2_1 - log_sigma2_eps) * var_t_vpsde / (1 - var_t_vpsde) / self.beta(t)
+ else:
+ raise NotImplementedError
+
+ elif iw_sample_mode == 'drop_all_uniform':
+ # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = torch.ones(1, device=dist_util.dev())
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'drop_all_iw':
+ if iw_subvp_like_vp_sde:
+ # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
+ assert self.sde_type == 'sub_vpsde', 'Importance sampling for fully unweighted objective is ' \
+ 'currently only implemented for the Sub-VPSDE.'
+ t = torch.sqrt(1.0 / self.delta_beta_half) * torch.erfinv(
+ rho * self.const_norm_2 + self.const_erf) - self.beta_frac
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = self.const_norm / (1.0 - self.var_vpsde(t))
+ obj_weight_t_ll = obj_weight_t * g2_t / (2.0 * var_t)
+ else:
+ raise NotImplementedError
+
+ elif iw_sample_mode == 'drop_sigma2t_iw':
+ if iw_subvp_like_vp_sde:
+ # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
+ ones = torch.ones_like(rho, device=dist_util.dev())
+ sigma2_1, sigma2_eps = self.var_vpsde(ones), self.var_vpsde(
+ time_eps * ones)
+ var_t_vpsde = rho * sigma2_1 + (1 - rho) * sigma2_eps
+ t = self.inv_var_vpsde(var_t_vpsde)
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = 0.5 * g2_t / self.beta(t) * (
+ sigma2_1 - sigma2_eps) / (1.0 - var_t_vpsde)
+ obj_weight_t_ll = obj_weight_t / var_t
+ else:
+ raise NotImplementedError
+
+ elif iw_sample_mode == 'drop_sigma2t_uniform':
+ # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = g2_t / 2.0
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'rescale_iw':
+ # importance sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p
+ # Note that we use the sub-vpsde variance to scale the p objective! It's not clear what's optimal here!
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = 0.5 / (1.0 - var_t)
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ else:
+ raise ValueError(
+ "Unrecognized importance sampling type: {}".format(
+ iw_sample_mode))
+
+ return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t.view(-1, 1, 1, 1), \
+ obj_weight_t_ll.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)
+
+ def _iw_quantities_vesde(self, size, time_eps, iw_sample_mode):
+ """
+ For the VESDE.
+ """
+ rho = torch.rand(size=[size], device=dist_util.dev())
+
+ # In the following, obj_weight_t corresponds to the weight in front of the l2 loss for the given iw_sample_mode.
+ # obj_weight_t_ll corresponds to the weight that converts the weighting scheme in iw_sample_mode to likelihood
+ # weighting.
+ if iw_sample_mode == 'll_uniform':
+ # uniform t sampling - likelihood obj. for both q and p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'll_iw':
+ # importance sampling for likelihood obj. - likelihood obj. for both q and p
+ ones = torch.ones_like(rho, device=dist_util.dev())
+ nsigma2_1, nsigma2_eps, sigma2_eps = self.var_N(ones), self.var_N(
+ time_eps * ones), self.var(time_eps * ones)
+ log_frac_sigma2_1, log_frac_sigma2_eps = torch.log(
+ self.sigma2_max / nsigma2_1), torch.log(nsigma2_eps /
+ sigma2_eps)
+ var_N_t = (1.0 - self.sigma2_min) / (
+ 1.0 - torch.exp(rho *
+ (log_frac_sigma2_1 + log_frac_sigma2_eps) -
+ log_frac_sigma2_eps))
+ t = self.inv_var_N(var_N_t)
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = obj_weight_t_ll = 0.5 * (
+ log_frac_sigma2_1 +
+ log_frac_sigma2_eps) * self.var_N(t) / (1.0 - self.sigma2_min)
+
+ elif iw_sample_mode == 'drop_all_uniform':
+ # uniform t sampling - likelihood obj. for q, all-prefactors-dropped obj. for p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = torch.ones(1, device=dist_util.dev())
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'drop_all_iw':
+ # importance sampling for all-pref.-dropped obj. - likelihood obj. for q, all-pref.-dropped obj. for p
+ ones = torch.ones_like(rho, device=dist_util.dev())
+ nsigma2_1, nsigma2_eps, sigma2_eps = self.var_N(ones), self.var_N(
+ time_eps * ones), self.var(time_eps * ones)
+ log_frac_sigma2_1, log_frac_sigma2_eps = torch.log(
+ self.sigma2_max / nsigma2_1), torch.log(nsigma2_eps /
+ sigma2_eps)
+ var_N_t = (1.0 - self.sigma2_min) / (
+ 1.0 - torch.exp(rho *
+ (log_frac_sigma2_1 + log_frac_sigma2_eps) -
+ log_frac_sigma2_eps))
+ t = self.inv_var_N(var_N_t)
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t_ll = 0.5 * (log_frac_sigma2_1 +
+ log_frac_sigma2_eps) * self.var_N(t) / (
+ 1.0 - self.sigma2_min)
+ obj_weight_t = 2.0 * obj_weight_t_ll / np.log(
+ self.sigma2_max / self.sigma2_min)
+
+ elif iw_sample_mode == 'drop_sigma2t_iw':
+ # importance sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
+ ones = torch.ones_like(rho, device=dist_util.dev())
+ nsigma2_1, nsigma2_eps = self.var_N(ones), self.var_N(time_eps *
+ ones)
+ var_N_t = torch.exp(rho * torch.log(nsigma2_1) +
+ (1 - rho) * torch.log(nsigma2_eps))
+ t = self.inv_var_N(var_N_t)
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = 0.5 * torch.log(
+ nsigma2_1 / nsigma2_eps) * self.var_N(t)
+ obj_weight_t_ll = obj_weight_t / var_t
+
+ elif iw_sample_mode == 'drop_sigma2t_uniform':
+ # uniform sampling for inv_sigma2_t-dropped obj. - likelihood obj. for q, inv_sigma2_t-dropped obj. for p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = g2_t / 2.0
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ elif iw_sample_mode == 'rescale_iw':
+ # uniform sampling for 1/(1-sigma2_t) resc. obj. - likelihood obj. for q, 1/(1-sigma2_t) resc. obj. for p
+ t = rho * (1. - time_eps) + time_eps
+ var_t, m_t, g2_t = self.var(t), self.e2int_f(t), self.g2(t)
+ obj_weight_t = 0.5 / (1.0 - var_t)
+ obj_weight_t_ll = g2_t / (2.0 * var_t)
+
+ else:
+ raise ValueError(
+ "Unrecognized importance sampling type: {}".format(
+ iw_sample_mode))
+
+ return t, var_t.view(-1, 1, 1, 1), m_t.view(-1, 1, 1, 1), obj_weight_t.view(-1, 1, 1, 1), \
+ obj_weight_t_ll.view(-1, 1, 1, 1), g2_t.view(-1, 1, 1, 1)
+
+
+class DiffusionGeometric(DiffusionBase):
+ """
+ Diffusion implementation with dz = -0.5 * beta(t) * z * dt + sqrt(beta(t)) * dW SDE and geometric progression of
+ variance. This is our new diffusion.
+ """
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.sigma2_min = args.sde_sigma2_min
+ self.sigma2_max = args.sde_sigma2_max
+
+ def f(self, t):
+ return -0.5 * self.g2(t)
+
+ def g2(self, t):
+ sigma2_geom = self.sigma2_min * (
+ (self.sigma2_max / self.sigma2_min)**t)
+ log_term = np.log(self.sigma2_max / self.sigma2_min)
+ return sigma2_geom * log_term / (1.0 - self.sigma2_0 +
+ self.sigma2_min - sigma2_geom)
+
+ def var(self, t):
+ return self.sigma2_min * ((self.sigma2_max / self.sigma2_min)**
+ t) - self.sigma2_min + self.sigma2_0
+
+ def e2int_f(self, t):
+ return torch.sqrt(1.0 + self.sigma2_min *
+ (1.0 - (self.sigma2_max / self.sigma2_min)**t) /
+ (1.0 - self.sigma2_0))
+
+ def inv_var(self, var):
+ return torch.log(
+ (var + self.sigma2_min - self.sigma2_0) /
+ self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min)
+
+ def mixing_component(self, x_noisy, var_t, t, enabled):
+ if enabled:
+ return torch.sqrt(var_t) * x_noisy
+ else:
+ return None
+
+
+class DiffusionVPSDE(DiffusionBase):
+ """
+ Diffusion implementation of the VPSDE. This uses the same SDE like DiffusionGeometric but with linear beta(t).
+ Note that we need to scale beta_start and beta_end by 1000 relative to JH's DDPM values, since our t is in [0,1].
+ """
+
+ def __init__(self, args):
+ super().__init__(args)
+ # self.beta_start = args.sde_beta_start # 0.1
+ # self.beta_end = args.sde_beta_end # 20
+
+ # ! hard coded, in the scale of 1000.
+ # beta_start = scale * 0.0001
+ # beta_end = scale * 0.02
+
+ self.beta_start = 0.1
+ self.beta_end = 20
+
+ # auxiliary constants
+ self.time_eps = args.sde_time_eps # 0.01 by default in LSGM. Any influence?
+ self.delta_beta_half = torch.tensor(0.5 *
+ (self.beta_end - self.beta_start),
+ device=dist_util.dev())
+ self.beta_frac = torch.tensor(self.beta_start /
+ (self.beta_end - self.beta_start),
+ device=dist_util.dev())
+ self.const_aq = (1.0 - self.sigma2_0) * torch.exp(
+ 0.5 * self.beta_frac) * torch.sqrt(
+ 0.25 * np.pi / self.delta_beta_half)
+ self.const_erf = torch.erf(
+ torch.sqrt(self.delta_beta_half) *
+ (self.time_eps + self.beta_frac))
+ self.const_norm = self.const_aq * (torch.erf(
+ torch.sqrt(self.delta_beta_half) *
+ (1.0 + self.beta_frac)) - self.const_erf)
+ self.const_norm_2 = torch.erf(
+ torch.sqrt(self.delta_beta_half) *
+ (1.0 + self.beta_frac)) - self.const_erf
+
+ def f(self, t):
+ return -0.5 * self.g2(t)
+
+ def g2(self, t):
+ return self.beta_start + (self.beta_end - self.beta_start) * t
+
+ def var(self, t):
+ return 1.0 - (1.0 - self.sigma2_0
+ ) * torch.exp(-self.beta_start * t - 0.5 *
+ (self.beta_end - self.beta_start) * t * t)
+
+ def e2int_f(self, t):
+ return torch.exp(-0.5 * self.beta_start * t - 0.25 *
+ (self.beta_end - self.beta_start) * t * t)
+
+ def inv_var(self, var):
+ c = torch.log((1 - var) / (1 - self.sigma2_0))
+ a = self.beta_end - self.beta_start
+ t = (-self.beta_start +
+ torch.sqrt(np.square(self.beta_start) - 2 * a * c)) / a
+ return t
+
+ def mixing_component(self, x_noisy, var_t, t, enabled):
+ if enabled:
+ return torch.sqrt(var_t) * x_noisy
+ else:
+ return None
+
+ def mixing_component_x0(self, x_noisy, var_t, t, enabled):
+ if enabled:
+ # return torch.sqrt(var_t) * x_noisy
+ return torch.sqrt(1-var_t) * x_noisy # zt * alpha_t
+ else:
+ return None
+
+
+class DiffusionSubVPSDE(DiffusionBase):
+ """
+ Diffusion implementation of the sub-VPSDE. Note that this uses a different SDE compared to the above two diffusions.
+ """
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.beta_start = args.sde_beta_start
+ self.beta_end = args.sde_beta_end
+
+ # auxiliary constants (assumes regular VPSDE)
+ self.time_eps = args.sde_time_eps
+ self.delta_beta_half = torch.tensor(0.5 *
+ (self.beta_end - self.beta_start),
+ device=dist_util.dev())
+ self.beta_frac = torch.tensor(self.beta_start /
+ (self.beta_end - self.beta_start),
+ device=dist_util.dev())
+ self.const_aq = (1.0 - self.sigma2_0) * torch.exp(
+ 0.5 * self.beta_frac) * torch.sqrt(
+ 0.25 * np.pi / self.delta_beta_half)
+ self.const_erf = torch.erf(
+ torch.sqrt(self.delta_beta_half) *
+ (self.time_eps + self.beta_frac))
+ self.const_norm = self.const_aq * (torch.erf(
+ torch.sqrt(self.delta_beta_half) *
+ (1.0 + self.beta_frac)) - self.const_erf)
+ self.const_norm_2 = torch.erf(
+ torch.sqrt(self.delta_beta_half) *
+ (1.0 + self.beta_frac)) - self.const_erf
+
+ def f(self, t):
+ return -0.5 * self.beta(t)
+
+ def g2(self, t):
+ return self.beta(t) * (
+ 1.0 - torch.exp(-2.0 * self.beta_start * t -
+ (self.beta_end - self.beta_start) * t * t))
+
+ def var(self, t):
+ int_term = torch.exp(-self.beta_start * t - 0.5 *
+ (self.beta_end - self.beta_start) * t * t)
+ return torch.square(1.0 - int_term) + self.sigma2_0 * int_term
+
+ def e2int_f(self, t):
+ return torch.exp(-0.5 * self.beta_start * t - 0.25 *
+ (self.beta_end - self.beta_start) * t * t)
+
+ def beta(self, t):
+ """ auxiliary beta function """
+ return self.beta_start + (self.beta_end - self.beta_start) * t
+
+ def inv_var(self, var):
+ raise NotImplementedError
+
+ def mixing_component(self, x_noisy, var_t, t, enabled):
+ if enabled:
+ int_term = torch.exp(-self.beta_start * t - 0.5 *
+ (self.beta_end - self.beta_start) * t *
+ t).view(-1, 1, 1, 1)
+ return torch.sqrt(var_t) * x_noisy / (
+ torch.square(1.0 - int_term) + int_term)
+ else:
+ return None
+
+ def var_vpsde(self, t):
+ return 1.0 - (1.0 - self.sigma2_0
+ ) * torch.exp(-self.beta_start * t - 0.5 *
+ (self.beta_end - self.beta_start) * t * t)
+
+ def inv_var_vpsde(self, var):
+ c = torch.log((1 - var) / (1 - self.sigma2_0))
+ a = self.beta_end - self.beta_start
+ t = (-self.beta_start +
+ torch.sqrt(np.square(self.beta_start) - 2 * a * c)) / a
+ return t
+
+
+class DiffusionVESDE(DiffusionBase):
+ """
+ Diffusion implementation of the VESDE with dz = sqrt(beta(t)) * dW
+ """
+
+ def __init__(self, args):
+ super().__init__(args)
+ self.sigma2_min = args.sde_sigma2_min
+ self.sigma2_max = args.sde_sigma2_max
+ assert self.sigma2_min == self.sigma2_0, "VESDE was proposed implicitly assuming sigma2_min = sigma2_0!"
+
+ def f(self, t):
+ return torch.zeros_like(t, device=dist_util.dev())
+
+ def g2(self, t):
+ return self.sigma2_min * np.log(self.sigma2_max / self.sigma2_min) * (
+ (self.sigma2_max / self.sigma2_min)**t)
+
+ def var(self, t):
+ return self.sigma2_min * ((self.sigma2_max / self.sigma2_min)**
+ t) - self.sigma2_min + self.sigma2_0
+
+ def e2int_f(self, t):
+ return torch.ones_like(t, device=dist_util.dev())
+
+ def inv_var(self, var):
+ return torch.log(
+ (var + self.sigma2_min - self.sigma2_0) /
+ self.sigma2_min) / np.log(self.sigma2_max / self.sigma2_min)
+
+ def mixing_component(self, x_noisy, var_t, t, enabled):
+ if enabled:
+ return torch.sqrt(var_t) * x_noisy / (self.sigma2_min * (
+ (self.sigma2_max / self.sigma2_min)**t.view(-1, 1, 1, 1)) -
+ self.sigma2_min + 1.0)
+ else:
+ return None
+
+ def var_N(self, t):
+ return 1.0 - self.sigma2_min + self.sigma2_min * (
+ (self.sigma2_max / self.sigma2_min)**t)
+
+ def inv_var_N(self, var):
+ return torch.log(
+ (var + self.sigma2_min - 1.0) / self.sigma2_min) / np.log(
+ self.sigma2_max / self.sigma2_min)
diff --git a/guided_diffusion/continuous_diffusion_utils.py b/guided_diffusion/continuous_diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d70c43111c37cd219e545d14b6bfb5b2f95ca30f
--- /dev/null
+++ b/guided_diffusion/continuous_diffusion_utils.py
@@ -0,0 +1,815 @@
+# ---------------------------------------------------------------
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# This work is licensed under the NVIDIA Source Code License
+# for LSGM. To view a copy of this license, see the LICENSE file.
+# ---------------------------------------------------------------
+
+import logging
+import os
+import math
+import shutil
+import time
+import sys
+import types
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.distributed as dist
+# from util.distributions import PixelNormal
+from torch.cuda.amp import autocast
+
+# from tensorboardX import SummaryWriter
+
+
+class AvgrageMeter(object):
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.avg = 0
+ self.sum = 0
+ self.cnt = 0
+
+ def update(self, val, n=1):
+ self.sum += val * n
+ self.cnt += n
+ self.avg = self.sum / self.cnt
+
+
+class ExpMovingAvgrageMeter(object):
+
+ def __init__(self, momentum=0.9):
+ self.momentum = momentum
+ self.reset()
+
+ def reset(self):
+ self.avg = 0
+
+ def update(self, val):
+ self.avg = (1. - self.momentum) * self.avg + self.momentum * val
+
+
+class DummyDDP(nn.Module):
+ def __init__(self, model):
+ super(DummyDDP, self).__init__()
+ self.module = model
+
+ def forward(self, *input, **kwargs):
+ return self.module(*input, **kwargs)
+
+
+def count_parameters_in_M(model):
+ return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
+
+
+def save_checkpoint(state, is_best, save):
+ filename = os.path.join(save, 'checkpoint.pth.tar')
+ torch.save(state, filename)
+ if is_best:
+ best_filename = os.path.join(save, 'model_best.pth.tar')
+ shutil.copyfile(filename, best_filename)
+
+
+def save(model, model_path):
+ torch.save(model.state_dict(), model_path)
+
+
+def load(model, model_path):
+ model.load_state_dict(torch.load(model_path))
+
+
+def create_exp_dir(path, scripts_to_save=None):
+ if not os.path.exists(path):
+ os.makedirs(path, exist_ok=True)
+ print('Experiment dir : {}'.format(path))
+
+ if scripts_to_save is not None:
+ if not os.path.exists(os.path.join(path, 'scripts')):
+ os.mkdir(os.path.join(path, 'scripts'))
+ for script in scripts_to_save:
+ dst_file = os.path.join(path, 'scripts', os.path.basename(script))
+ shutil.copyfile(script, dst_file)
+
+
+class Logger(object):
+ def __init__(self, rank, save):
+ # other libraries may set logging before arriving at this line.
+ # by reloading logging, we can get rid of previous configs set by other libraries.
+ from importlib import reload
+ reload(logging)
+ self.rank = rank
+ if self.rank == 0:
+ log_format = '%(asctime)s %(message)s'
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+ format=log_format, datefmt='%m/%d %I:%M:%S %p')
+ fh = logging.FileHandler(os.path.join(save, 'log.txt'))
+ fh.setFormatter(logging.Formatter(log_format))
+ logging.getLogger().addHandler(fh)
+ self.start_time = time.time()
+
+ def info(self, string, *args):
+ if self.rank == 0:
+ elapsed_time = time.time() - self.start_time
+ elapsed_time = time.strftime(
+ '(Elapsed: %H:%M:%S) ', time.gmtime(elapsed_time))
+ if isinstance(string, str):
+ string = elapsed_time + string
+ else:
+ logging.info(elapsed_time)
+ logging.info(string, *args)
+
+
+class Writer(object):
+ def __init__(self, rank, save):
+ self.rank = rank
+ if self.rank == 0:
+ self.writer = SummaryWriter(log_dir=save, flush_secs=20)
+
+ def add_scalar(self, *args, **kwargs):
+ if self.rank == 0:
+ self.writer.add_scalar(*args, **kwargs)
+
+ def add_figure(self, *args, **kwargs):
+ if self.rank == 0:
+ self.writer.add_figure(*args, **kwargs)
+
+ def add_image(self, *args, **kwargs):
+ if self.rank == 0:
+ self.writer.add_image(*args, **kwargs)
+
+ def add_histogram(self, *args, **kwargs):
+ if self.rank == 0:
+ self.writer.add_histogram(*args, **kwargs)
+
+ def add_histogram_if(self, write, *args, **kwargs):
+ if write and False: # Used for debugging.
+ self.add_histogram(*args, **kwargs)
+
+ def close(self, *args, **kwargs):
+ if self.rank == 0:
+ self.writer.close()
+
+
+def common_init(rank, seed, save_dir):
+ # we use different seeds per gpu. But we sync the weights after model initialization.
+ torch.manual_seed(rank + seed)
+ np.random.seed(rank + seed)
+ torch.cuda.manual_seed(rank + seed)
+ torch.cuda.manual_seed_all(rank + seed)
+ torch.backends.cudnn.benchmark = True
+
+ # prepare logging and tensorboard summary
+ logging = Logger(rank, save_dir)
+ writer = Writer(rank, save_dir)
+
+ return logging, writer
+
+
+def reduce_tensor(tensor, world_size):
+ rt = tensor.clone()
+ dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+ rt /= world_size
+ return rt
+
+
+def get_stride_for_cell_type(cell_type):
+ if cell_type.startswith('normal') or cell_type.startswith('combiner'):
+ stride = 1
+ elif cell_type.startswith('down'):
+ stride = 2
+ elif cell_type.startswith('up'):
+ stride = -1
+ else:
+ raise NotImplementedError(cell_type)
+
+ return stride
+
+
+def get_cout(cin, stride):
+ if stride == 1:
+ cout = cin
+ elif stride == -1:
+ cout = cin // 2
+ elif stride == 2:
+ cout = 2 * cin
+
+ return cout
+
+
+def kl_balancer_coeff(num_scales, groups_per_scale, fun):
+ if fun == 'equal':
+ coeff = torch.cat([torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda()
+ elif fun == 'linear':
+ coeff = torch.cat([(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)],
+ dim=0).cuda()
+ elif fun == 'sqrt':
+ coeff = torch.cat(
+ [np.sqrt(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)],
+ dim=0).cuda()
+ elif fun == 'square':
+ coeff = torch.cat(
+ [np.square(2 ** i) / groups_per_scale[num_scales - i - 1] * torch.ones(groups_per_scale[num_scales - i - 1])
+ for i in range(num_scales)], dim=0).cuda()
+ else:
+ raise NotImplementedError
+ # convert min to 1.
+ coeff /= torch.min(coeff)
+ return coeff
+
+
+def kl_per_group(kl_all):
+ kl_vals = torch.mean(kl_all, dim=0)
+ kl_coeff_i = torch.abs(kl_all)
+ kl_coeff_i = torch.mean(kl_coeff_i, dim=0, keepdim=True) + 0.01
+
+ return kl_coeff_i, kl_vals
+
+
+def kl_balancer(kl_all, kl_coeff=1.0, kl_balance=False, alpha_i=None):
+ if kl_balance and kl_coeff < 1.0:
+ alpha_i = alpha_i.unsqueeze(0)
+
+ kl_all = torch.stack(kl_all, dim=1)
+ kl_coeff_i, kl_vals = kl_per_group(kl_all)
+ total_kl = torch.sum(kl_coeff_i)
+
+ kl_coeff_i = kl_coeff_i / alpha_i * total_kl
+ kl_coeff_i = kl_coeff_i / torch.mean(kl_coeff_i, dim=1, keepdim=True)
+ kl = torch.sum(kl_all * kl_coeff_i.detach(), dim=1)
+
+ # for reporting
+ kl_coeffs = kl_coeff_i.squeeze(0)
+ else:
+ kl_all = torch.stack(kl_all, dim=1)
+ kl_vals = torch.mean(kl_all, dim=0)
+ # kl = torch.sum(kl_all, dim=1)
+ # kl = torch.mean(kl_all, dim=1)
+ kl = torch.mean(kl_all)
+ kl_coeffs = torch.ones(size=(len(kl_vals),))
+
+ return kl_coeff * kl, kl_coeffs, kl_vals
+
+
+def kl_per_group_vada(all_log_q, all_neg_log_p):
+ assert len(all_log_q) == len(all_neg_log_p)
+
+ kl_all_list = []
+ kl_diag = []
+ for log_q, neg_log_p in zip(all_log_q, all_neg_log_p):
+ # kl_diag.append(torch.mean(torch.sum(neg_log_p + log_q, dim=[2, 3]), dim=0))
+ kl_diag.append(torch.mean(torch.mean(neg_log_p + log_q, dim=[2, 3]), dim=0))
+ # kl_all_list.append(torch.sum(neg_log_p + log_q, dim=[1, 2, 3]))
+ kl_all_list.append(torch.mean(neg_log_p + log_q, dim=[1, 2, 3]))
+
+ # kl_all = torch.stack(kl_all, dim=1) # batch x num_total_groups
+ kl_vals = torch.mean(torch.stack(kl_all_list, dim=1), dim=0) # mean per group
+
+ return kl_all_list, kl_vals, kl_diag
+
+
+def kl_coeff(step, total_step, constant_step, min_kl_coeff, max_kl_coeff):
+ # return max(min(max_kl_coeff * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff)
+ return max(min(min_kl_coeff + (max_kl_coeff - min_kl_coeff) * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff)
+
+
+def log_iw(decoder, x, log_q, log_p, crop=False):
+ recon = reconstruction_loss(decoder, x, crop)
+ return - recon - log_q + log_p
+
+
+def reconstruction_loss(decoder, x, crop=False):
+ from util.distributions import DiscMixLogistic
+
+ recon = decoder.log_p(x)
+ if crop:
+ recon = recon[:, :, 2:30, 2:30]
+
+ if isinstance(decoder, DiscMixLogistic):
+ return - torch.sum(recon, dim=[1, 2]) # summation over RGB is done.
+ else:
+ return - torch.sum(recon, dim=[1, 2, 3])
+
+
+def vae_terms(all_log_q, all_eps):
+ from util.distributions import log_p_standard_normal
+
+ # compute kl
+ kl_all = []
+ kl_diag = []
+ log_p, log_q = 0., 0.
+ for log_q_conv, eps in zip(all_log_q, all_eps):
+ log_p_conv = log_p_standard_normal(eps)
+ kl_per_var = log_q_conv - log_p_conv
+ kl_diag.append(torch.mean(torch.sum(kl_per_var, dim=[2, 3]), dim=0))
+ kl_all.append(torch.sum(kl_per_var, dim=[1, 2, 3]))
+ log_q += torch.sum(log_q_conv, dim=[1, 2, 3])
+ log_p += torch.sum(log_p_conv, dim=[1, 2, 3])
+ return log_q, log_p, kl_all, kl_diag
+
+
+def sum_log_q(all_log_q):
+ log_q = 0.
+ for log_q_conv in all_log_q:
+ log_q += torch.sum(log_q_conv, dim=[1, 2, 3])
+
+ return log_q
+
+
+def cross_entropy_normal(all_eps):
+ from util.distributions import log_p_standard_normal
+
+ cross_entropy = 0.
+ neg_log_p_per_group = []
+ for eps in all_eps:
+ neg_log_p_conv = - log_p_standard_normal(eps)
+ neg_log_p = torch.sum(neg_log_p_conv, dim=[1, 2, 3])
+ cross_entropy += neg_log_p
+ neg_log_p_per_group.append(neg_log_p_conv)
+
+ return cross_entropy, neg_log_p_per_group
+
+
+def tile_image(batch_image, n, m=None):
+ if m is None:
+ m = n
+ assert n * m == batch_image.size(0)
+ channels, height, width = batch_image.size(1), batch_image.size(2), batch_image.size(3)
+ batch_image = batch_image.view(n, m, channels, height, width)
+ batch_image = batch_image.permute(2, 0, 3, 1, 4) # n, height, n, width, c
+ batch_image = batch_image.contiguous().view(channels, n * height, m * width)
+ return batch_image
+
+
+def average_gradients_naive(params, is_distributed):
+ """ Gradient averaging. """
+ if is_distributed:
+ size = float(dist.get_world_size())
+ for param in params:
+ if param.requires_grad:
+ param.grad.data /= size
+ dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
+
+
+def average_gradients(params, is_distributed):
+ """ Gradient averaging. """
+ if is_distributed:
+ if isinstance(params, types.GeneratorType):
+ params = [p for p in params]
+
+ size = float(dist.get_world_size())
+ grad_data = []
+ grad_size = []
+ grad_shapes = []
+ # Gather all grad values
+ for param in params:
+ if param.requires_grad:
+ grad_size.append(param.grad.data.numel())
+ grad_shapes.append(list(param.grad.data.shape))
+ grad_data.append(param.grad.data.flatten())
+ grad_data = torch.cat(grad_data).contiguous()
+
+ # All-reduce grad values
+ grad_data /= size
+ dist.all_reduce(grad_data, op=dist.ReduceOp.SUM)
+
+ # Put back the reduce grad values to parameters
+ base = 0
+ for i, param in enumerate(params):
+ if param.requires_grad:
+ param.grad.data = grad_data[base:base + grad_size[i]].view(grad_shapes[i])
+ base += grad_size[i]
+
+
+def average_params(params, is_distributed):
+ """ parameter averaging. """
+ if is_distributed:
+ size = float(dist.get_world_size())
+ for param in params:
+ param.data /= size
+ dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
+
+
+def average_tensor(t, is_distributed):
+ if is_distributed:
+ size = float(dist.get_world_size())
+ dist.all_reduce(t.data, op=dist.ReduceOp.SUM)
+ t.data /= size
+
+
+def broadcast_params(params, is_distributed):
+ if is_distributed:
+ for param in params:
+ dist.broadcast(param.data, src=0)
+
+
+def num_output(dataset):
+ if dataset in {'mnist', 'omniglot'}:
+ return 28 * 28
+ elif dataset == 'cifar10':
+ return 3 * 32 * 32
+ elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'):
+ size = int(dataset.split('_')[-1])
+ return 3 * size * size
+ elif dataset == 'ffhq':
+ return 3 * 256 * 256
+ else:
+ raise NotImplementedError
+
+
+def get_input_size(dataset):
+ if dataset in {'mnist', 'omniglot'}:
+ return 32
+ elif dataset == 'cifar10':
+ return 32
+ elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'):
+ size = int(dataset.split('_')[-1])
+ return size
+ elif dataset == 'ffhq':
+ return 256
+ else:
+ raise NotImplementedError
+
+
+def get_bpd_coeff(dataset):
+ n = num_output(dataset)
+ return 1. / np.log(2.) / n
+
+
+def get_channel_multiplier(dataset, num_scales):
+ if dataset in {'cifar10', 'omniglot'}:
+ mult = (1, 1, 1)
+ elif dataset in {'celeba_256', 'ffhq', 'lsun_church_256'}:
+ if num_scales == 3:
+ mult = (1, 1, 1) # used for prior at 16
+ elif num_scales == 4:
+ mult = (1, 2, 2, 2) # used for prior at 32
+ elif num_scales == 5:
+ mult = (1, 1, 2, 2, 2) # used for prior at 64
+ elif dataset == 'mnist':
+ mult = (1, 1)
+ else:
+ raise NotImplementedError
+
+ return mult
+
+
+def get_attention_scales(dataset):
+ if dataset in {'cifar10', 'omniglot'}:
+ attn = (True, False, False)
+ elif dataset in {'celeba_256', 'ffhq', 'lsun_church_256'}:
+ # attn = (False, True, False, False) # used for 32
+ attn = (False, False, True, False, False) # used for 64
+ elif dataset == 'mnist':
+ attn = (True, False)
+ else:
+ raise NotImplementedError
+
+ return attn
+
+
+def change_bit_length(x, num_bits):
+ if num_bits != 8:
+ x = torch.floor(x * 255 / 2 ** (8 - num_bits))
+ x /= (2 ** num_bits - 1)
+ return x
+
+
+def view4D(t, size, inplace=True):
+ """
+ Equal to view(-1, 1, 1, 1).expand(size)
+ Designed because of this bug:
+ https://github.com/pytorch/pytorch/pull/48696
+ """
+ if inplace:
+ return t.unsqueeze_(-1).unsqueeze_(-1).unsqueeze_(-1).expand(size)
+ else:
+ return t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(size)
+
+
+def get_arch_cells(arch_type, use_se):
+ if arch_type == 'res_mbconv':
+ arch_cells = dict()
+ arch_cells['normal_enc'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['down_enc'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['normal_dec'] = {'conv_branch': ['mconv_e6k5g0'], 'se': use_se}
+ arch_cells['up_dec'] = {'conv_branch': ['mconv_e6k5g0'], 'se': use_se}
+ arch_cells['normal_pre'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['down_pre'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['normal_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se}
+ arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se}
+ arch_cells['ar_nn'] = ['']
+ elif arch_type == 'res_bnswish':
+ arch_cells = dict()
+ arch_cells['normal_enc'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['down_enc'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['normal_dec'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['up_dec'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['normal_pre'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['down_pre'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['normal_post'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['up_post'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['ar_nn'] = ['']
+ elif arch_type == 'res_bnswish2':
+ arch_cells = dict()
+ arch_cells['normal_enc'] = {'conv_branch': ['res_bnswish_x2'], 'se': use_se}
+ arch_cells['down_enc'] = {'conv_branch': ['res_bnswish_x2'], 'se': use_se}
+ arch_cells['normal_dec'] = {'conv_branch': ['res_bnswish_x2'], 'se': use_se}
+ arch_cells['up_dec'] = {'conv_branch': ['res_bnswish_x2'], 'se': use_se}
+ arch_cells['normal_pre'] = {'conv_branch': ['res_bnswish_x2'], 'se': use_se}
+ arch_cells['down_pre'] = {'conv_branch': ['res_bnswish_x2'], 'se': use_se}
+ arch_cells['normal_post'] = {'conv_branch': ['res_bnswish_x2'], 'se': use_se}
+ arch_cells['up_post'] = {'conv_branch': ['res_bnswish_x2'], 'se': use_se}
+ arch_cells['ar_nn'] = ['']
+ elif arch_type == 'res_mbconv_attn':
+ arch_cells = dict()
+ arch_cells['normal_enc'] = {'conv_branch': ['res_bnswish', 'res_bnswish', ], 'se': use_se, 'attn_type': 'attn'}
+ arch_cells['down_enc'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se, 'attn_type': 'attn'}
+ arch_cells['normal_dec'] = {'conv_branch': ['mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'}
+ arch_cells['up_dec'] = {'conv_branch': ['mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'}
+ arch_cells['normal_pre'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['down_pre'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['normal_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se}
+ arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se}
+ arch_cells['ar_nn'] = ['']
+ elif arch_type == 'res_mbconv_attn_half':
+ arch_cells = dict()
+ arch_cells['normal_enc'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['down_enc'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['normal_dec'] = {'conv_branch': ['mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'}
+ arch_cells['up_dec'] = {'conv_branch': ['mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'}
+ arch_cells['normal_pre'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['down_pre'] = {'conv_branch': ['res_bnswish', 'res_bnswish'], 'se': use_se}
+ arch_cells['normal_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se}
+ arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se}
+ arch_cells['ar_nn'] = ['']
+ else:
+ raise NotImplementedError
+
+ return arch_cells
+
+
+def groups_per_scale(num_scales, num_groups_per_scale):
+ g = []
+ n = num_groups_per_scale
+ for s in range(num_scales):
+ assert n >= 1
+ g.append(n)
+ return g
+
+
+class PositionalEmbedding(nn.Module):
+ def __init__(self, embedding_dim, scale):
+ super(PositionalEmbedding, self).__init__()
+ self.embedding_dim = embedding_dim
+ self.scale = scale
+
+ def forward(self, timesteps):
+ assert len(timesteps.shape) == 1
+ timesteps = timesteps * self.scale
+ half_dim = self.embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ return emb
+
+
+class RandomFourierEmbedding(nn.Module):
+ def __init__(self, embedding_dim, scale):
+ super(RandomFourierEmbedding, self).__init__()
+ self.w = nn.Parameter(torch.randn(size=(1, embedding_dim // 2)) * scale, requires_grad=False)
+
+ def forward(self, timesteps):
+ emb = torch.mm(timesteps[:, None], self.w * 2 * 3.14159265359)
+ return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+
+
+def init_temb_fun(embedding_type, embedding_scale, embedding_dim):
+ if embedding_type == 'positional':
+ temb_fun = PositionalEmbedding(embedding_dim, embedding_scale)
+ elif embedding_type == 'fourier':
+ temb_fun = RandomFourierEmbedding(embedding_dim, embedding_scale)
+ else:
+ raise NotImplementedError
+
+ return temb_fun
+
+def get_dae_model(args, num_input_channels):
+ if args.dae_arch == 'ncsnpp':
+ # we need to import NCSNpp after processes are launched on the multi gpu training.
+ from score_sde.ncsnpp import NCSNpp
+ dae = NCSNpp(args, num_input_channels)
+ else:
+ raise NotImplementedError
+
+ return dae
+
+def symmetrize_image_data(images):
+ return 2.0 * images - 1.0
+
+
+def unsymmetrize_image_data(images):
+ return (images + 1.) / 2.
+
+
+def normalize_symmetric(images):
+ """
+ Normalize images by dividing the largest intensity. Used for visualizing the intermediate steps.
+ """
+ b = images.shape[0]
+ m, _ = torch.max(torch.abs(images).view(b, -1), dim=1)
+ images /= (m.view(b, 1, 1, 1) + 1e-3)
+
+ return images
+
+
+@torch.jit.script
+def soft_clamp5(x: torch.Tensor):
+ return x.div(5.).tanh_().mul(5.) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]
+
+@torch.jit.script
+def soft_clamp(x: torch.Tensor, a: torch.Tensor):
+ return x.div(a).tanh_().mul(a)
+
+class SoftClamp5(nn.Module):
+ def __init__(self):
+ super(SoftClamp5, self).__init__()
+
+ def forward(self, x):
+ return soft_clamp5(x)
+
+
+def override_architecture_fields(args, stored_args, logging):
+ # list of architecture parameters used in NVAE:
+ architecture_fields = ['arch_instance', 'num_nf', 'num_latent_scales', 'num_groups_per_scale',
+ 'num_latent_per_group', 'num_channels_enc', 'num_preprocess_blocks',
+ 'num_preprocess_cells', 'num_cell_per_cond_enc', 'num_channels_dec',
+ 'num_postprocess_blocks', 'num_postprocess_cells', 'num_cell_per_cond_dec',
+ 'decoder_dist', 'num_x_bits', 'log_sig_q_scale',
+ 'progressive_input_vae', 'channel_mult']
+
+ # backward compatibility
+ """ We have broken backward compatibility. No need to se these manually
+ if not hasattr(stored_args, 'log_sig_q_scale'):
+ logging.info('*** Setting %s manually ****', 'log_sig_q_scale')
+ setattr(stored_args, 'log_sig_q_scale', 5.)
+
+ if not hasattr(stored_args, 'latent_grad_cutoff'):
+ logging.info('*** Setting %s manually ****', 'latent_grad_cutoff')
+ setattr(stored_args, 'latent_grad_cutoff', 0.)
+
+ if not hasattr(stored_args, 'progressive_input_vae'):
+ logging.info('*** Setting %s manually ****', 'progressive_input_vae')
+ setattr(stored_args, 'progressive_input_vae', 'none')
+
+ if not hasattr(stored_args, 'progressive_output_vae'):
+ logging.info('*** Setting %s manually ****', 'progressive_output_vae')
+ setattr(stored_args, 'progressive_output_vae', 'none')
+ """
+
+ if not hasattr(stored_args, 'num_x_bits'):
+ logging.info('*** Setting %s manually ****', 'num_x_bits')
+ setattr(stored_args, 'num_x_bits', 8)
+
+ if not hasattr(stored_args, 'channel_mult'):
+ logging.info('*** Setting %s manually ****', 'channel_mult')
+ setattr(stored_args, 'channel_mult', [1, 2])
+
+ for f in architecture_fields:
+ if not hasattr(args, f) or getattr(args, f) != getattr(stored_args, f):
+ logging.info('Setting %s from loaded checkpoint', f)
+ setattr(args, f, getattr(stored_args, f))
+
+
+def init_processes(rank, size, fn, args):
+ """ Initialize the distributed environment. """
+ os.environ['MASTER_ADDR'] = args.master_address
+ os.environ['MASTER_PORT'] = '6020'
+ torch.cuda.set_device(args.local_rank)
+ dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=size)
+ fn(args)
+ dist.barrier()
+ dist.destroy_process_group()
+
+
+def sample_rademacher_like(y):
+ return torch.randint(low=0, high=2, size=y.shape, device='cuda') * 2 - 1
+
+
+def sample_gaussian_like(y):
+ return torch.randn_like(y, device='cuda')
+
+
+def trace_df_dx_hutchinson(f, x, noise, no_autograd):
+ """
+ Hutchinson's trace estimator for Jacobian df/dx, O(1) call to autograd
+ """
+ if no_autograd:
+ # the following is compatible with checkpointing
+ torch.sum(f * noise).backward()
+ # torch.autograd.backward(tensors=[f], grad_tensors=[noise])
+ jvp = x.grad
+ trJ = torch.sum(jvp * noise, dim=[1, 2, 3])
+ x.grad = None
+ else:
+ jvp = torch.autograd.grad(f, x, noise, create_graph=False)[0]
+ trJ = torch.sum(jvp * noise, dim=[1, 2, 3])
+ # trJ = torch.einsum('bijk,bijk->b', jvp, noise) # we could test if there's a speed difference in einsum vs sum
+
+ return trJ
+
+def different_p_q_objectives(iw_sample_p, iw_sample_q):
+ assert iw_sample_p in ['ll_uniform', 'drop_all_uniform', 'll_iw', 'drop_all_iw', 'drop_sigma2t_iw', 'rescale_iw',
+ 'drop_sigma2t_uniform']
+ assert iw_sample_q in ['reweight_p_samples', 'll_uniform', 'll_iw']
+ # In these cases, we reuse the likelihood-based p-objective (either the uniform sampling version or the importance
+ # sampling version) also for q.
+ if iw_sample_p in ['ll_uniform', 'll_iw'] and iw_sample_q == 'reweight_p_samples':
+ return False
+ # In these cases, we are using a non-likelihood-based objective for p, and hence definitly need to use another q
+ # objective.
+ else:
+ return True
+
+
+# def decoder_output(dataset, logits, fixed_log_scales=None):
+# if dataset in {'cifar10', 'celeba_64', 'celeba_256', 'imagenet_32', 'imagenet_64', 'ffhq',
+# 'lsun_bedroom_128', 'lsun_bedroom_256', 'mnist', 'omniglot',
+# 'lsun_church_256'}:
+# return PixelNormal(logits, fixed_log_scales)
+# else:
+# raise NotImplementedError
+
+
+def get_mixed_prediction(mixed_prediction, param, mixing_logit, mixing_component=None):
+ if mixed_prediction:
+ assert mixing_component is not None, 'Provide mixing component when mixed_prediction is enabled.'
+ coeff = torch.sigmoid(mixing_logit)
+ param = (1 - coeff) * mixing_component + coeff * param
+
+ return param
+
+
+def set_vesde_sigma_max(args, vae, train_queue, logging, is_distributed):
+ logging.info('')
+ logging.info('Calculating max. pairwise distance in latent space to set sigma2_max for VESDE...')
+
+ eps_list = []
+ vae.eval()
+ for step, x in enumerate(train_queue):
+ x = x[0] if len(x) > 1 else x
+ x = x.cuda()
+ x = symmetrize_image_data(x)
+
+ # run vae
+ with autocast(enabled=args.autocast_train):
+ with torch.set_grad_enabled(False):
+ logits, all_log_q, all_eps = vae(x)
+ eps = torch.cat(all_eps, dim=1)
+
+ eps_list.append(eps.detach())
+
+ # concat eps tensor on each GPU and then gather all on all GPUs
+ eps_this_rank = torch.cat(eps_list, dim=0)
+ if is_distributed:
+ eps_all_gathered = [torch.zeros_like(eps_this_rank)] * dist.get_world_size()
+ dist.all_gather(eps_all_gathered, eps_this_rank)
+ eps_full = torch.cat(eps_all_gathered, dim=0)
+ else:
+ eps_full = eps_this_rank
+
+ # max pairwise distance squared between all latent encodings, is computed on CPU
+ eps_full = eps_full.cpu().float()
+ eps_full = eps_full.flatten(start_dim=1).unsqueeze(0)
+ max_pairwise_dist_sqr = torch.cdist(eps_full, eps_full).square().max()
+ max_pairwise_dist_sqr = max_pairwise_dist_sqr.cuda()
+
+ # to be safe, we broadcast to all GPUs if we are in distributed environment. Shouldn't be necessary in principle.
+ if is_distributed:
+ dist.broadcast(max_pairwise_dist_sqr, src=0)
+
+ args.sigma2_max = max_pairwise_dist_sqr.item()
+
+ logging.info('Done! Set args.sigma2_max set to {}'.format(args.sigma2_max))
+ logging.info('')
+ return args
+
+
+def mask_inactive_variables(x, is_active):
+ x = x * is_active
+ return x
+
+
+def common_x_operations(x, num_x_bits):
+ x = x[0] if len(x) > 1 else x
+ x = x.cuda()
+
+ # change bit length
+ x = change_bit_length(x, num_x_bits)
+ x = symmetrize_image_data(x)
+
+ return x
diff --git a/guided_diffusion/continuous_distributions.py b/guided_diffusion/continuous_distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..288c3d9a6ade4d2b2bef64704d666a00ddc499c0
--- /dev/null
+++ b/guided_diffusion/continuous_distributions.py
@@ -0,0 +1,284 @@
+# ---------------------------------------------------------------
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# This work is licensed under the NVIDIA Source Code License
+# for LSGM. To view a copy of this license, see the LICENSE file.
+# ---------------------------------------------------------------
+
+import torch
+import torch.nn.functional as F
+from torch.distributions.bernoulli import Bernoulli as Bern
+import numpy as np
+from pdb import set_trace as st
+# from util import utils
+from .continuous_diffusion_utils import view4D
+
+@torch.jit.script
+def sample_normal_jit(mu, sigma):
+ rho = mu.mul(0).normal_()
+ z = rho.mul_(sigma).add_(mu)
+ return z, rho
+
+
+@torch.jit.script
+def log_p_standard_normal(samples):
+ log_p = - 0.5 * torch.square(samples) - 0.9189385332 # 0.5 * np.log(2 * np.pi)
+ return log_p
+
+
+def log_p_var_normal(samples, var):
+ log_p = - 0.5 * torch.square(samples) / var - 0.5 * np.log(var) - 0.9189385332 # 0.5 * np.log(2 * np.pi)
+ return log_p
+
+
+def one_hot(indices, depth, dim):
+ indices = indices.unsqueeze(dim)
+ size = list(indices.size())
+ size[dim] = depth
+ y_onehot = torch.zeros(size).cuda()
+ y_onehot.zero_()
+ y_onehot.scatter_(dim, indices, 1)
+
+ return y_onehot
+
+# TODO: merge this with the next class
+class PixelNormal(object):
+ def __init__(self, param, fixed_log_scales=None):
+ size = param.size()
+ C = size[1]
+ if fixed_log_scales is None:
+ self.num_c = C // 2
+ self.means = param[:, :self.num_c, :, :] # B, 1 or 3, H, W
+ self.log_scales = torch.clamp(param[:, self.num_c:, :, :], min=-7.0) # B, 1 or 3, H, W
+ raise NotImplementedError
+ else:
+ self.num_c = C
+ self.means = param # B, 1 or 3, H, W
+ self.log_scales = view4D(fixed_log_scales, size) # B, 1 or 3, H, W
+
+ def get_params(self):
+ return self.means, self.log_scales, self.num_c
+
+ def log_prob(self, samples):
+ B, C, H, W = samples.size()
+ assert C == self.num_c
+
+ log_probs = -0.5 * torch.square(self.means - samples) * torch.exp(-2.0 * self.log_scales) - self.log_scales - 0.9189385332 # -0.5*log(2*pi)
+ return log_probs
+
+ def sample(self, t=1.):
+ z, rho = sample_normal_jit(self.means, torch.exp(self.log_scales)*t) # B, 3, H, W
+ return z
+
+ def log_prob_discrete(self, samples):
+ """
+ Calculates discrete pixel probabilities.
+ """
+ # samples should be in [-1, 1] already
+ B, C, H, W = samples.size()
+ assert C == self.num_c
+
+ centered = samples - self.means
+ inv_stdv = torch.exp(- self.log_scales)
+ plus_in = inv_stdv * (centered + 1. / 255.)
+ cdf_plus = torch.distributions.Normal(0, 1).cdf(plus_in)
+ min_in = inv_stdv * (centered - 1. / 255.)
+ cdf_min = torch.distributions.Normal(0, 1).cdf(min_in)
+ log_cdf_plus = torch.log(torch.clamp(cdf_plus, min=1e-12))
+ log_one_minus_cdf_min = torch.log(torch.clamp(1. - cdf_min, min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.999, log_one_minus_cdf_min,
+ torch.log(torch.clamp(cdf_delta, min=1e-12))))
+
+ assert log_probs.size() == samples.size()
+ return log_probs
+
+ def mean(self):
+ return self.means
+
+
+class Normal:
+ def __init__(self, mu, log_sigma):
+ self.mu = mu
+ self.log_sigma = log_sigma
+ self.sigma = torch.exp(log_sigma)
+
+ def sample(self, t=1.):
+ return sample_normal_jit(self.mu, self.sigma * t)
+
+ def sample_given_rho(self, rho):
+ return rho * self.sigma + self.mu
+
+ def log_p(self, samples):
+ normalized_samples = (samples - self.mu) / self.sigma
+ log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - self.log_sigma
+ return log_p
+
+ def kl(self, normal_dist):
+ term1 = (self.mu - normal_dist.mu) / normal_dist.sigma
+ term2 = self.sigma / normal_dist.sigma
+
+ return 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(self.log_sigma) + normal_dist.log_sigma
+
+ def mean(self):
+ return self.mu
+
+
+class Bernoulli:
+ def __init__(self, logits):
+ self.dist = Bern(logits=logits)
+
+ def log_p(self, samples):
+ # convert samples to {0, 1}
+ samples = (samples + 1.) / 2
+ return self.dist.log_prob(samples)
+
+ def mean(self):
+ # map the mean to [-1, 1]
+ return 2 * self.dist.mean - 1.
+
+class DiscLogistic:
+ def __init__(self, param):
+ B, C, H, W = param.size()
+ self.num_c = C // 2
+ self.means = param[:, :self.num_c, :, :] # B, 3, H, W
+ self.log_scales = torch.clamp(param[:, self.num_c:, :, :], min=-7.0) # B, 3, H, W
+
+ def log_p(self, samples):
+ assert torch.max(samples) <= 1.0 and torch.min(samples) >= -1.0
+
+ B, C, H, W = samples.size()
+ assert C == self.num_c
+
+ centered = samples - self.means # B, 3, H, W
+ inv_stdv = torch.exp(- self.log_scales)
+ plus_in = inv_stdv * (centered + 1. / 255.)
+ cdf_plus = torch.sigmoid(plus_in)
+ min_in = inv_stdv * (centered - 1. / 255.)
+ cdf_min = torch.sigmoid(min_in)
+ log_cdf_plus = plus_in - F.softplus(plus_in)
+ log_one_minus_cdf_min = - F.softplus(min_in)
+ cdf_delta = cdf_plus - cdf_min
+ mid_in = inv_stdv * centered
+ log_pdf_mid = mid_in - self.log_scales - 2. * F.softplus(mid_in)
+
+ log_prob_mid_safe = torch.where(cdf_delta > 1e-5,
+ torch.log(torch.clamp(cdf_delta, min=1e-10)),
+ log_pdf_mid - np.log(127.5))
+
+ log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.999, log_one_minus_cdf_min,
+ log_prob_mid_safe)) # B, 3, H, W
+
+ return log_probs
+
+ def sample(self):
+ u = torch.Tensor(self.means.size()).uniform_(1e-5, 1. - 1e-5).cuda() # B, 3, H, W
+ x = self.means + torch.exp(self.log_scales) * (torch.log(u) - torch.log(1. - u)) # B, 3, H, W
+ x = torch.clamp(x, -1, 1.)
+ return x
+
+ def mean(self):
+ return self.means
+
+
+class DiscMixLogistic:
+ def __init__(self, param, num_mix=10, num_bits=8):
+ B, C, H, W = param.size()
+ self.num_mix = num_mix
+ self.logit_probs = param[:, :num_mix, :, :] # B, M, H, W
+ l = param[:, num_mix:, :, :].view(B, 3, 3 * num_mix, H, W) # B, 3, 3 * M, H, W
+ self.means = l[:, :, :num_mix, :, :] # B, 3, M, H, W
+ self.log_scales = torch.clamp(l[:, :, num_mix:2 * num_mix, :, :], min=-7.0) # B, 3, M, H, W
+ self.coeffs = torch.tanh(l[:, :, 2 * num_mix:3 * num_mix, :, :]) # B, 3, M, H, W
+ self.max_val = 2. ** num_bits - 1
+
+ def log_p(self, samples):
+ assert torch.max(samples) <= 1.0 and torch.min(samples) >= -1.0
+
+ B, C, H, W = samples.size()
+ assert C == 3, 'only RGB images are considered.'
+
+ samples = samples.unsqueeze(4) # B, 3, H , W
+ samples = samples.expand(-1, -1, -1, -1, self.num_mix).permute(0, 1, 4, 2, 3) # B, 3, M, H, W
+ mean1 = self.means[:, 0, :, :, :] # B, M, H, W
+ mean2 = self.means[:, 1, :, :, :] + \
+ self.coeffs[:, 0, :, :, :] * samples[:, 0, :, :, :] # B, M, H, W
+ mean3 = self.means[:, 2, :, :, :] + \
+ self.coeffs[:, 1, :, :, :] * samples[:, 0, :, :, :] + \
+ self.coeffs[:, 2, :, :, :] * samples[:, 1, :, :, :] # B, M, H, W
+
+ mean1 = mean1.unsqueeze(1) # B, 1, M, H, W
+ mean2 = mean2.unsqueeze(1) # B, 1, M, H, W
+ mean3 = mean3.unsqueeze(1) # B, 1, M, H, W
+ means = torch.cat([mean1, mean2, mean3], dim=1) # B, 3, M, H, W
+ centered = samples - means # B, 3, M, H, W
+
+ inv_stdv = torch.exp(- self.log_scales)
+ plus_in = inv_stdv * (centered + 1. / self.max_val)
+ cdf_plus = torch.sigmoid(plus_in)
+ min_in = inv_stdv * (centered - 1. / self.max_val)
+ cdf_min = torch.sigmoid(min_in)
+ log_cdf_plus = plus_in - F.softplus(plus_in)
+ log_one_minus_cdf_min = - F.softplus(min_in)
+ cdf_delta = cdf_plus - cdf_min
+ mid_in = inv_stdv * centered
+ log_pdf_mid = mid_in - self.log_scales - 2. * F.softplus(mid_in)
+
+ log_prob_mid_safe = torch.where(cdf_delta > 1e-5,
+ torch.log(torch.clamp(cdf_delta, min=1e-10)),
+ log_pdf_mid - np.log(self.max_val / 2))
+
+ log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.999, log_one_minus_cdf_min,
+ log_prob_mid_safe)) # B, 3, M, H, W
+
+ log_probs = torch.sum(log_probs, 1) + F.log_softmax(self.logit_probs, dim=1) # B, M, H, W
+ return torch.logsumexp(log_probs, dim=1) # B, H, W
+
+ def sample(self, t=1.):
+ gumbel = -torch.log(- torch.log(torch.Tensor(self.logit_probs.size()).uniform_(1e-5, 1. - 1e-5).cuda())) # B, M, H, W
+ sel = one_hot(torch.argmax(self.logit_probs / t + gumbel, 1), self.num_mix, dim=1) # B, M, H, W
+ sel = sel.unsqueeze(1) # B, 1, M, H, W
+
+ # select logistic parameters
+ means = torch.sum(self.means * sel, dim=2) # B, 3, H, W
+ log_scales = torch.sum(self.log_scales * sel, dim=2) # B, 3, H, W
+ coeffs = torch.sum(self.coeffs * sel, dim=2) # B, 3, H, W
+
+ # cells from logistic & clip to interval
+ # we don't actually round to the nearest 8bit value when sampling
+ u = torch.Tensor(means.size()).uniform_(1e-5, 1. - 1e-5).cuda() # B, 3, H, W
+ x = means + torch.exp(log_scales) * t * (torch.log(u) - torch.log(1. - u)) # B, 3, H, W
+
+ x0 = torch.clamp(x[:, 0, :, :], -1, 1.) # B, H, W
+ x1 = torch.clamp(x[:, 1, :, :] + coeffs[:, 0, :, :] * x0, -1, 1) # B, H, W
+ x2 = torch.clamp(x[:, 2, :, :] + coeffs[:, 1, :, :] * x0 + coeffs[:, 2, :, :] * x1, -1, 1) # B, H, W
+
+ x0 = x0.unsqueeze(1)
+ x1 = x1.unsqueeze(1)
+ x2 = x2.unsqueeze(1)
+
+ x = torch.cat([x0, x1, x2], 1)
+ return x
+
+ def mean(self):
+ sel = torch.softmax(self.logit_probs, dim=1) # B, M, H, W
+ sel = sel.unsqueeze(1) # B, 1, M, H, W
+
+ # select logistic parameters
+ means = torch.sum(self.means * sel, dim=2) # B, 3, H, W
+ coeffs = torch.sum(self.coeffs * sel, dim=2) # B, 3, H, W
+
+ # we don't sample from logistic components, because of the linear dependencies, we use mean
+ x = means # B, 3, H, W
+ x0 = torch.clamp(x[:, 0, :, :], -1, 1.) # B, H, W
+ x1 = torch.clamp(x[:, 1, :, :] + coeffs[:, 0, :, :] * x0, -1, 1) # B, H, W
+ x2 = torch.clamp(x[:, 2, :, :] + coeffs[:, 1, :, :] * x0 + coeffs[:, 2, :, :] * x1, -1, 1) # B, H, W
+
+ x0 = x0.unsqueeze(1)
+ x1 = x1.unsqueeze(1)
+ x2 = x2.unsqueeze(1)
+
+ x = torch.cat([x0, x1, x2], 1)
+ return x
+
+
diff --git a/guided_diffusion/dist_util.py b/guided_diffusion/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8be748d586b4854d437df202c1b6018a56752ae
--- /dev/null
+++ b/guided_diffusion/dist_util.py
@@ -0,0 +1,170 @@
+"""
+Helpers for distributed training.
+"""
+
+import datetime
+import io
+import os
+import socket
+
+import blobfile as bf
+from pdb import set_trace as st
+# from mpi4py import MPI
+import torch as th
+import torch.distributed as dist
+
+# Change this to reflect your cluster layout.
+# The GPU for a given rank is (rank % GPUS_PER_NODE).
+GPUS_PER_NODE = 8
+SETUP_RETRY_COUNT = 3
+
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+
+ if not dist.is_initialized():
+ return 0
+
+ return dist.get_rank()
+
+
+def synchronize():
+ if not dist.is_available():
+ return
+
+ if not dist.is_initialized():
+ return
+
+ world_size = dist.get_world_size()
+
+ if world_size == 1:
+ return
+
+ dist.barrier()
+
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+
+ if not dist.is_initialized():
+ return 1
+
+ return dist.get_world_size()
+
+
+def setup_dist(args):
+ """
+ Setup a distributed process group.
+ """
+ if dist.is_initialized():
+ return
+
+ # print(f"{os.environ['MASTER_ADDR']=} {args.master_port=}")
+
+ # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count(), timeout=datetime.timedelta(seconds=5400))
+ # st() no mark
+ dist.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(seconds=54000))
+ print(f"{args.local_rank=} init complete")
+
+ # synchronize() # extra memory on rank 0, why?
+
+ th.cuda.empty_cache()
+
+def cleanup():
+ dist.destroy_process_group()
+
+def dev():
+ """
+ Get the device to use for torch.distributed.
+ """
+ if th.cuda.is_available():
+
+ if get_world_size() > 1:
+ return th.device(f"cuda:{get_rank() % GPUS_PER_NODE}")
+ return th.device(f"cuda")
+
+ return th.device("cpu")
+
+
+# def load_state_dict(path, submodule_name='', **kwargs):
+def load_state_dict(path, **kwargs):
+ """
+ Load a PyTorch file without redundant fetches across MPI ranks.
+ """
+ # chunk_size = 2 ** 30 # MPI has a relatively small size limit
+ # if get_rank() == 0:
+ # with bf.BlobFile(path, "rb") as f:
+ # data = f.read()
+ # num_chunks = len(data) // chunk_size
+ # if len(data) % chunk_size:
+ # num_chunks += 1
+ # MPI.COMM_WORLD.bcast(num_chunks)
+ # for i in range(0, len(data), chunk_size):
+ # MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
+ # else:
+ # num_chunks = MPI.COMM_WORLD.bcast(None)
+ # data = bytes()
+ # for _ in range(num_chunks):
+ # data += MPI.COMM_WORLD.bcast(None)
+
+ # return th.load(io.BytesIO(data), **kwargs)
+ # with open(path) as f:
+ ckpt = th.load(path, **kwargs)
+ # if submodule_name != '':
+ # assert submodule_name in ckpt
+ # return ckpt[submodule_name]
+ # else:
+ return ckpt
+
+
+def sync_params(params):
+ """
+ Synchronize a sequence of Tensors across ranks from rank 0.
+ """
+ # for k, p in params:
+ for p in params:
+ with th.no_grad():
+ try:
+ dist.broadcast(p, 0)
+ except Exception as e:
+ print(k, e)
+ # print(e)
+
+
+def _find_free_port():
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.bind(("", 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return s.getsockname()[1]
+ finally:
+ s.close()
+
+
+_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
+_reduce_dtype = th.float32 # Data type to use for initial per-tensor reduction.
+_counter_dtype = th.float64 # Data type to use for the internal counters.
+_rank = 0 # Rank of the current process.
+_sync_device = None # Device to use for multiprocess communication. None = single-process.
+_sync_called = False # Has _sync() been called yet?
+_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
+_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
+
+def init_multiprocessing(rank, sync_device):
+ r"""Initializes `utils.torch_utils.training_stats` for collecting statistics
+ across multiple processes.
+ This function must be called after
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
+ The call is not necessary if multi-process collection is not needed.
+ Args:
+ rank: Rank of the current process.
+ sync_device: PyTorch device to use for inter-process
+ communication, or None to disable multi-process
+ collection. Typically `torch.device('cuda', rank)`.
+ """
+ global _rank, _sync_device
+ assert not _sync_called
+ _rank = rank
+ _sync_device = sync_device
\ No newline at end of file
diff --git a/guided_diffusion/fp16_util.py b/guided_diffusion/fp16_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a61ca5bfd4e07cd77a7fb69c3fb98df25230f9
--- /dev/null
+++ b/guided_diffusion/fp16_util.py
@@ -0,0 +1,307 @@
+"""
+Helpers to train with 16-bit precision.
+"""
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from . import logger
+
+INITIAL_LOG_LOSS_SCALE = 20.0
+
+
+def convert_module_to_f16(l):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+ """
+ Convert primitive modules to float32, undoing convert_module_to_f16().
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.float()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.float()
+
+
+def make_master_params(param_groups_and_shapes):
+ """
+ Copy model parameters into a (differently-shaped) list of full-precision
+ parameters.
+ """
+ master_params = []
+ for param_group, shape in param_groups_and_shapes:
+ master_param = nn.Parameter(
+ _flatten_dense_tensors([
+ param.detach().float() for (_, param) in param_group
+ ]).view(shape))
+ master_param.requires_grad = True
+ master_params.append(master_param)
+ return master_params
+
+
+def model_grads_to_master_grads(param_groups_and_shapes, master_params):
+ """
+ Copy the gradients from the model parameters into the master parameters
+ from make_master_params().
+ """
+ for master_param, (param_group, shape) in zip(master_params,
+ param_groups_and_shapes):
+ master_param.grad = _flatten_dense_tensors([
+ param_grad_or_zeros(param) for (_, param) in param_group
+ ]).view(shape)
+
+
+def master_params_to_model_params(param_groups_and_shapes, master_params):
+ """
+ Copy the master parameter data back into the model parameters.
+ """
+ # Without copying to a list, if a generator is passed, this will
+ # silently not copy any parameters.
+ for master_param, (param_group, _) in zip(master_params,
+ param_groups_and_shapes):
+ for (_, param), unflat_master_param in zip(
+ param_group,
+ unflatten_master_params(param_group, master_param.view(-1))):
+ param.detach().copy_(unflat_master_param)
+
+
+def unflatten_master_params(param_group, master_param):
+ return _unflatten_dense_tensors(master_param,
+ [param for (_, param) in param_group])
+
+
+def get_param_groups_and_shapes(named_model_params):
+ named_model_params = list(named_model_params)
+ scalar_vector_named_params = (
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
+ (-1),
+ )
+ matrix_named_params = (
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
+ (1, -1),
+ )
+ return [scalar_vector_named_params, matrix_named_params]
+
+
+def master_params_to_state_dict(model, param_groups_and_shapes, master_params,
+ use_fp16):
+ if use_fp16:
+ state_dict = model.state_dict()
+ for master_param, (param_group, _) in zip(master_params,
+ param_groups_and_shapes):
+ for (name, _), unflat_master_param in zip(
+ param_group,
+ unflatten_master_params(param_group,
+ master_param.view(-1))):
+ assert name in state_dict
+ state_dict[name] = unflat_master_param
+ else:
+ state_dict = model.state_dict()
+ for i, (name, _value) in enumerate(model.named_parameters()):
+ assert name in state_dict
+ state_dict[name] = master_params[i]
+ return state_dict
+
+
+def state_dict_to_master_params(model, state_dict, use_fp16):
+ if use_fp16:
+ named_model_params = [(name, state_dict[name])
+ for name, _ in model.named_parameters()]
+ param_groups_and_shapes = get_param_groups_and_shapes(
+ named_model_params)
+ master_params = make_master_params(param_groups_and_shapes)
+ else:
+ master_params = [
+ state_dict[name] for name, _ in model.named_parameters()
+ ]
+ return master_params
+
+
+def zero_master_grads(master_params):
+ for param in master_params:
+ param.grad = None
+
+
+def zero_grad(model_params):
+ for param in model_params:
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
+ if param.grad is not None:
+ param.grad.detach_()
+ param.grad.zero_()
+
+
+def param_grad_or_zeros(param):
+ if param.grad is not None:
+ return param.grad.data.detach()
+ else:
+ return th.zeros_like(param)
+
+
+class MixedPrecisionTrainer:
+
+ def __init__(self,
+ *,
+ model,
+ use_fp16=False,
+ use_amp=False,
+ fp16_scale_growth=1e-3,
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
+ model_name='ddpm',
+ submodule_name='',
+ model_params=None):
+ self.model_name = model_name
+ self.model = model
+ self.use_fp16 = use_fp16
+ self.use_amp = use_amp
+ if self.use_amp:
+ # https://github.com/pytorch/pytorch/issues/40497#issuecomment-1262373602
+ # https://github.com/pytorch/pytorch/issues/111739
+ self.scaler = th.cuda.amp.GradScaler(enabled=use_amp, init_scale=2**15, growth_interval=100)
+ logger.log(model_name, 'enables AMP to accelerate training')
+ else:
+ logger.log(model_name, 'not enables AMP to accelerate training')
+
+ self.fp16_scale_growth = fp16_scale_growth
+
+ self.model_params = list(self.model.parameters(
+ )) if model_params is None else list(model_params) if not isinstance(
+ model_params, list) else model_params
+ self.master_params = self.model_params
+ self.param_groups_and_shapes = None
+ self.lg_loss_scale = initial_lg_loss_scale
+
+ if self.use_fp16:
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
+ self.model.named_parameters())
+ self.master_params = make_master_params(
+ self.param_groups_and_shapes)
+ self.model.convert_to_fp16()
+
+ def zero_grad(self):
+ zero_grad(self.model_params)
+
+ def backward(self, loss: th.Tensor, disable_amp=False, **kwargs):
+ """**kwargs: retain_graph=True
+ """
+ if self.use_fp16:
+ loss_scale = 2**self.lg_loss_scale
+ (loss * loss_scale).backward(**kwargs)
+ elif self.use_amp and not disable_amp:
+ self.scaler.scale(loss).backward(**kwargs)
+ else:
+ loss.backward(**kwargs)
+
+ # def optimize(self, opt: th.optim.Optimizer, clip_grad=False):
+ def optimize(self, opt: th.optim.Optimizer, clip_grad=True):
+ if self.use_fp16:
+ return self._optimize_fp16(opt)
+ elif self.use_amp:
+ return self._optimize_amp(opt, clip_grad)
+ else:
+ return self._optimize_normal(opt, clip_grad)
+
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
+ model_grads_to_master_grads(self.param_groups_and_shapes,
+ self.master_params)
+ grad_norm, param_norm = self._compute_norms(
+ grad_scale=2**self.lg_loss_scale)
+ if check_overflow(grad_norm):
+ self.lg_loss_scale -= 1
+ logger.log(
+ f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
+ zero_master_grads(self.master_params)
+ return False
+
+ logger.logkv_mean("grad_norm", grad_norm)
+ logger.logkv_mean("param_norm", param_norm)
+
+ for p in self.master_params:
+ p.grad.mul_(1.0 / (2**self.lg_loss_scale))
+ opt.step()
+ zero_master_grads(self.master_params)
+ master_params_to_model_params(self.param_groups_and_shapes,
+ self.master_params)
+ self.lg_loss_scale += self.fp16_scale_growth
+ return True
+
+ def _optimize_amp(self, opt: th.optim.Optimizer, clip_grad=False):
+ # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
+ assert clip_grad
+ self.scaler.unscale_(opt) # to calculate accurate gradients
+
+ if clip_grad:
+ th.nn.utils.clip_grad_norm_( # type: ignore
+ self.master_params,
+ 5.0,
+ norm_type=2,
+ error_if_nonfinite=False,
+ foreach=True,
+ ) # clip before compute_norm
+
+ grad_norm, param_norm = self._compute_norms()
+ logger.logkv_mean("grad_norm", grad_norm)
+ logger.logkv_mean("param_norm", param_norm)
+
+ self.scaler.step(opt)
+ self.scaler.update()
+ return True
+
+ def _optimize_normal(self, opt: th.optim.Optimizer, clip_grad:bool=False):
+
+ assert clip_grad
+ if clip_grad:
+ th.nn.utils.clip_grad_norm_( # type: ignore
+ self.master_params,
+ 5.0,
+ norm_type=2,
+ error_if_nonfinite=False,
+ foreach=True,
+ ) # clip before compute_norm
+
+ grad_norm, param_norm = self._compute_norms()
+ logger.logkv_mean("grad_norm", grad_norm)
+ logger.logkv_mean("param_norm", param_norm)
+ opt.step()
+ return True
+
+ def _compute_norms(self, grad_scale=1.0):
+ grad_norm = 0.0
+ param_norm = 0.0
+ for p in self.master_params:
+ with th.no_grad():
+ param_norm += th.norm(p, p=2, dtype=th.float32).item()**2
+ if p.grad is not None:
+ grad_norm += th.norm(p.grad, p=2,
+ dtype=th.float32).item()**2
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
+
+ def master_params_to_state_dict(self, master_params, model=None):
+ if model is None:
+ model = self.model
+ return master_params_to_state_dict(model, self.param_groups_and_shapes,
+ master_params, self.use_fp16)
+
+ def state_dict_to_master_params(self, state_dict, model=None):
+ if model is None:
+ model = self.model
+ return state_dict_to_master_params(model, state_dict, self.use_fp16)
+
+ def state_dict_to_master_params_given_submodule_name(
+ self, state_dict, submodule_name):
+ return state_dict_to_master_params(getattr(self.model, submodule_name),
+ state_dict, self.use_fp16)
+
+
+def check_overflow(value):
+ return (value == float("inf")) or (value == -float("inf")) or (value
+ != value)
diff --git a/guided_diffusion/gaussian_diffusion.py b/guided_diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..33ede2ac608ec99dcfe858ae5b1ab4af238621a9
--- /dev/null
+++ b/guided_diffusion/gaussian_diffusion.py
@@ -0,0 +1,1253 @@
+"""
+This code started out as a PyTorch port of Ho et al's diffusion models:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
+
+Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
+"""
+
+from pdb import set_trace as st
+import enum
+import math
+
+import numpy as np
+import torch as th
+
+from .nn import mean_flat
+from .losses import normal_kl, discretized_gaussian_log_likelihood
+from . import dist_util
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear": # * used here
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return np.linspace(beta_start,
+ beta_end,
+ num_diffusion_timesteps,
+ dtype=np.float64)
+
+ elif schedule_name == "linear_simple":
+ return betas_for_alpha_bar_linear_simple(num_diffusion_timesteps,
+ lambda t: 0.001 / (1.001 - t))
+
+ elif schedule_name == "cosine":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
+ )
+
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar_linear_simple(num_diffusion_timesteps,
+ alpha_bar,
+ max_beta=0.999):
+ """proposed by Chen Ting, on the importance of noise schedule, arXiv 2023.
+ gamma = 1-t
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t = i / num_diffusion_timesteps
+ betas.append(min(max_beta, alpha_bar(t)))
+
+ return betas
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+ V = enum.auto() # the model predicts velosity
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+
+ Ported directly from here, and then adapted over time to further experimentation.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
+ :param model_var_type: a ModelVarType determining how variance is output.
+ :param loss_type: a LossType determining the loss function to use.
+ :param rescale_timesteps: if True, pass floating point timesteps into the
+ model so that they are always scaled like in the
+ original paper (0 to 1000).
+ """
+ '''
+ defaults:
+ learn_sigma=False,
+ diffusion_steps=1000,
+ noise_schedule="linear",
+ timestep_respacing="",
+ use_kl=False,
+ predict_xstart=False,
+ rescale_timesteps=False,
+ rescale_learned_sigmas=False,
+ '''
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
+ standarization_xt=False,
+ ):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+ self.rescale_timesteps = rescale_timesteps
+ self.standarization_xt = standarization_xt
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(
+ 1.0 / self.alphas_cumprod -
+ 1) # sqrt(1/cumprod(alphas) - 1), for calculating x_0 from x_t
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) /
+ (1.0 - self.alphas_cumprod))
+ # log calculation clipped because the posterior variance is 0 at the
+ # beginning of the diffusion chain.
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:]))
+ self.posterior_mean_coef1 = (betas *
+ np.sqrt(self.alphas_cumprod_prev) /
+ (1.0 - self.alphas_cumprod))
+ self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) *
+ np.sqrt(alphas) /
+ (1.0 - self.alphas_cumprod))
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
+ x_start)
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t,
+ x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod,
+ t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None, return_detail=False):
+ """
+ Diffuse the data for a given number of diffusion steps.
+
+ In other words, sample from q(x_t | x_0).
+
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ alpha_bar = _extract_into_tensor(self.sqrt_alphas_cumprod, t,
+ x_start.shape)
+ one_minus_alpha_bar = _extract_into_tensor(
+ self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ xt = (alpha_bar * x_start + one_minus_alpha_bar * noise)
+
+ if self.standarization_xt:
+ xt = xt / (1e-5 + xt.std(dim=list(range(1, xt.ndim)), keepdim=True)
+ ) # B 1 1 1 #
+
+ if return_detail:
+ return xt, alpha_bar, one_minus_alpha_bar
+
+ return xt
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+
+ q(x_{t-1} | x_t, x_0)
+
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) *
+ x_start +
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) *
+ x_t)
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t,
+ x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape)
+ assert (posterior_mean.shape[0] == posterior_variance.shape[0] ==
+ posterior_log_variance_clipped.shape[0] == x_start.shape[0])
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self,
+ model,
+ x,
+ t,
+ c=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ mixing_normal=False,
+ direct_return_model_output=False):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ # lazy import to avoid partially initialized import
+ from guided_diffusion.continuous_diffusion_utils import get_mixed_prediction
+
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ # if mixing_normal is not None:
+ # t = t / self.num_timesteps # [0,1] for SDE diffusion
+
+ B, C = x.shape[:2]
+ assert t.shape == (B, )
+ model_output = model(x,
+ self._scale_timesteps(t),
+ c=c,
+ mixing_normal=mixing_normal,
+ **model_kwargs)
+
+ if direct_return_model_output:
+ return model_output
+
+ if self.model_mean_type == ModelMeanType.V:
+ v_transformed_to_eps_flag = False
+
+ # st()
+ if mixing_normal: # directly change the model predicted eps logits
+ if self.model_mean_type == ModelMeanType.START_X:
+ mixing_component = self.get_mixing_component_x0(x,
+ t,
+ enabled=True)
+ else:
+ assert self.model_mean_type in [
+ ModelMeanType.EPSILON, ModelMeanType.V
+ ]
+ mixing_component = self.get_mixing_component(x,
+ t,
+ enabled=True)
+
+ if self.model_mean_type == ModelMeanType.V:
+ model_output = self._predict_eps_from_z_and_v(
+ x, t, model_output)
+ v_transformed_to_eps_flag = True
+ # ! transform result to v first?
+ # model_output =
+ model_output = get_mixed_prediction(True, model_output,
+ model.mixing_logit,
+ mixing_component)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE
+ ]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = th.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ # ?
+ ModelVarType.FIXED_LARGE: ( # * used here
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(
+ np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t,
+ x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t,
+ xprev=model_output))
+ model_mean = model_output
+ elif self.model_mean_type in [
+ ModelMeanType.START_X, ModelMeanType.EPSILON, ModelMeanType.V
+ ]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else: # * used here
+ if self.model_mean_type == ModelMeanType.V:
+ assert v_transformed_to_eps_flag # type: ignore
+ pred_xstart = process_xstart( # * return the x_0 using self._predict_xstart_from_eps as the denoised_fn
+ self._predict_xstart_from_eps(x_t=x, t=t,
+ eps=model_output))
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t)
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ assert (model_mean.shape == model_log_variance.shape ==
+ pred_xstart.shape == x.shape)
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
+ x_t.shape) * x_t -
+ _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
+ x_t.shape) * eps)
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert x_t.shape == xprev.shape
+ return ( # (xprev - coef2*x_t) / coef1
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
+ * xprev - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
+ x_t.shape) * x_t)
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
+ x_t.shape) * x_t -
+ pred_xstart) / _extract_into_tensor(
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ # https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddpm.py#L288
+ def _predict_start_from_z_and_v(self, x_t, t, v):
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ return (_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) *
+ x_t - _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
+ t, x_t.shape) * v)
+
+ def _predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
+ x_t.shape) * x_t)
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.float() * (1000.0 / self.num_timesteps)
+ return t
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
+ new_mean = (p_mean_var["mean"].float() +
+ p_mean_var["variance"] * gradient.float())
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+
+ See condition_mean() for details on cond_fn.
+
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
+ x, self._scale_timesteps(t), **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(
+ x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ cond=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ mixing_normal=False,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(model,
+ x,
+ t,
+ c=cond,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ mixing_normal=mixing_normal)
+ noise = th.randn_like(x)
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn,
+ out,
+ x,
+ t,
+ model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(
+ 0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def get_mixing_component(self, x_noisy, t, enabled):
+ # alpha_bars = th.gather(self._alpha_bars, 0, timestep-1)
+ if enabled:
+ # one_minus_alpha_bars_sqrt = utils.view4D(th.sqrt(1.0 - alpha_bars), size)
+ one_minus_alpha_bars_sqrt = _extract_into_tensor(
+ self.sqrt_one_minus_alphas_cumprod, t, x_noisy.shape)
+ mixing_component = one_minus_alpha_bars_sqrt * x_noisy
+ else:
+ mixing_component = None
+
+ return mixing_component
+
+ def get_mixing_component_x0(self, x_noisy, t, enabled):
+ # alpha_bars = th.gather(self._alpha_bars, 0, timestep-1)
+ if enabled:
+ # one_minus_alpha_bars_sqrt = utils.view4D(th.sqrt(1.0 - alpha_bars), size)
+ one_minus_alpha_bars_sqrt = _extract_into_tensor(
+ self.sqrt_alphas_cumprod, t, x_noisy.shape)
+ mixing_component = one_minus_alpha_bars_sqrt * x_noisy
+ else:
+ mixing_component = None
+
+ return mixing_component
+
+ def p_sample_mixing_component(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+
+ assert self.model_mean_type == ModelMeanType.EPSILON, 'currently LSGM only implemented for EPSILON prediction'
+
+ out = self.p_mean_variance(
+ model,
+ x,
+ t / self.
+ num_timesteps, # trained on SDE diffusion, normalize steps to (0,1]
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # mixing_component = self.get_mixing_component(x, t, enabled=True)
+ # out['mean'] = get_mixed_prediction(model.mixed_prediction, out['mean'], model.mixing_logit, mixing_component)
+
+ noise = th.randn_like(x)
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn,
+ out,
+ x,
+ t,
+ model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(
+ 0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ cond=None,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ mixing_normal=False,
+ ):
+ """
+ Generate samples from the model.
+
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ cond=cond,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ mixing_normal=mixing_normal):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ cond=None,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ mixing_normal=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = dist_util.dev()
+ # device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(model,
+ img,
+ t,
+ cond=cond,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ mixing_normal=mixing_normal)
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ cond=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ mixing_normal=False,
+ objv_inference=False,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ Same usage as p_sample().
+ """
+
+ if unconditional_guidance_scale != 1.0:
+ assert cond is not None
+ if unconditional_conditioning is None:
+ unconditional_conditioning = th.zeros_like(
+ cond['c_crossattn']
+ ) # ImageEmbedding adopts zero as the null embedding
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ # e_t = self.model.apply_model(x, t, c)
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ c=cond,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ mixing_normal=mixing_normal,
+ )
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ elif objv_inference:
+ assert cond is not None
+ x_in = th.cat([x] * 2)
+ t_in = th.cat([t] * 2)
+ c_in = {}
+ for k in cond:
+ c_in[k] = th.cat([
+ unconditional_conditioning[k].repeat_interleave(
+ cond[k].shape[0], 0), cond[k]
+ ])
+
+ model_uncond, model_t = self.p_mean_variance(
+ model,
+ x_in,
+ t_in,
+ c=c_in,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ mixing_normal=mixing_normal,
+ direct_return_model_output=True, # ! compat with _wrapper
+ ).chunk(2)
+ # Usually our model outputs epsilon, but we re-derive it
+ # model_uncond, model_t = model(x_in, self._scale_timesteps(t_in), c=c_in, mixing_normal=mixing_normal, **model_kwargs).chunk(2)
+
+ # in case we used x_start or x_prev prediction.
+ # st()
+
+ # ! guidance
+ # e_t_uncond, e_t = eps.chunk(2)
+ model_out = model_uncond + unconditional_guidance_scale * (
+ model_t - model_uncond)
+
+ if self.model_mean_type == ModelMeanType.V:
+ eps = self._predict_eps_from_z_and_v(x, t, model_out)
+
+ # eps = self._predict_eps_from_xstart(x_in, t_in, out["pred_xstart"])
+
+ else:
+ assert cond is not None
+ x_in = th.cat([x] * 2)
+ t_in = th.cat([t] * 2)
+ c_in = {
+ 'c_crossattn':
+ th.cat([
+ unconditional_conditioning.repeat_interleave(
+ cond['c_crossattn'].shape[0], dim=0),
+ cond['c_crossattn']
+ ])
+ }
+
+ # c_in = {}
+ # for k in cond:
+ # c_in[k] = th.cat([unconditional_conditioning[k], cond[k]])
+
+ out = self.p_mean_variance(
+ model,
+ x_in,
+ t_in,
+ c=c_in,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ mixing_normal=mixing_normal,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x_in, t_in, out["pred_xstart"])
+
+ # ! guidance
+ e_t_uncond, e_t = eps.chunk(2)
+ # st()
+ eps = e_t_uncond + unconditional_guidance_scale * (e_t -
+ e_t_uncond)
+
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn,
+ out,
+ x,
+ t,
+ model_kwargs=model_kwargs)
+
+ # eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+ # ! re-derive xstart
+ pred_x0 = self._predict_xstart_from_eps(x, t, eps)
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
+ x.shape)
+ sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) *
+ th.sqrt(1 - alpha_bar / alpha_bar_prev))
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (pred_x0 * th.sqrt(alpha_bar_prev) +
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": pred_x0}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape)
+ * x - out["pred_xstart"]) / _extract_into_tensor(
+ self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t,
+ x.shape)
+
+ # Equation 12. reversed
+ mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_next) +
+ th.sqrt(1 - alpha_bar_next) * eps)
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ cond=None,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ mixing_normal=False,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ objv_inference=False,
+ ):
+ """
+ Generate samples from the model using DDIM.
+
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ cond=cond,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ mixing_normal=mixing_normal,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ objv_inference=objv_inference,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ cond=None,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ mixing_normal=False,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ objv_inference=False,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ cond=cond,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ mixing_normal=mixing_normal,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ objv_inference=objv_inference,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(self,
+ model,
+ x_start,
+ x_t,
+ t,
+ clip_denoised=True,
+ model_kwargs=None):
+ """
+ Get a term for the variational lower-bound.
+
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t)
+ out = self.p_mean_variance(model,
+ x_t,
+ t,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs)
+ kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"],
+ out["log_variance"])
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"])
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self,
+ model,
+ x_start,
+ t,
+ model_kwargs=None,
+ noise=None,
+ return_detail=False):
+ """
+ Compute training losses for a single timestep.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None: # * micro_cond
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start) # x_start is the x0 image
+ x_t = self.q_sample(x_start,
+ t,
+ noise=noise,
+ return_detail=return_detail
+ ) # * add noise according to predefined schedule
+ if return_detail:
+ x_t, alpha_bar, _ = x_t
+
+ # terms = {}
+ terms = {"x_t": x_t}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_output = model(
+ x_t, self._scale_timesteps(t), **model_kwargs
+ ) # directly predict epsilon or x_0; no learned sigma
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+ model_output, model_var_values = th.split(model_output,
+ C,
+ dim=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values],
+ dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X:
+ self.q_posterior_mean_variance(x_start=x_start, x_t=x_t,
+ t=t)[0],
+ ModelMeanType.START_X:
+ x_start,
+ ModelMeanType.EPSILON:
+ noise,
+ }[self.model_mean_type] # ModelMeanType.EPSILON
+ # st()
+ assert model_output.shape == target.shape == x_start.shape
+ terms["mse"] = mean_flat((target - model_output)**2)
+
+ terms['model_output'] = model_output
+ # terms['target'] = target # TODO, flag.
+ if return_detail:
+ terms.update({
+ 'diffusion_target': target,
+ 'alpha_bar': alpha_bar,
+ # 'one_minus_alpha':one_minus_alpha
+ # 'noise': noise
+ })
+
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+
+ This term can't be optimized, as it only depends on the encoder.
+
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size,
+ device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean,
+ logvar1=qt_log_variance,
+ mean2=0.0,
+ logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self,
+ model,
+ x_start,
+ clip_denoised=True,
+ model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch,
+ out["pred_xstart"])
+ mse.append(mean_flat((eps - noise)**2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
diff --git a/guided_diffusion/image_datasets.py b/guided_diffusion/image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..93022ae208a01e72eb162d7b63c07bf94a6afe3b
--- /dev/null
+++ b/guided_diffusion/image_datasets.py
@@ -0,0 +1,167 @@
+import math
+import random
+
+from PIL import Image
+import blobfile as bf
+from mpi4py import MPI
+import numpy as np
+from torch.utils.data import DataLoader, Dataset
+
+
+def load_data(
+ *,
+ data_dir,
+ batch_size,
+ image_size,
+ class_cond=False,
+ deterministic=False,
+ random_crop=False,
+ random_flip=True,
+):
+ """
+ For a dataset, create a generator over (images, kwargs) pairs.
+
+ Each images is an NCHW float tensor, and the kwargs dict contains zero or
+ more keys, each of which map to a batched Tensor of their own.
+ The kwargs dict can be used for class labels, in which case the key is "y"
+ and the values are integer tensors of class labels.
+
+ :param data_dir: a dataset directory.
+ :param batch_size: the batch size of each returned pair.
+ :param image_size: the size to which images are resized.
+ :param class_cond: if True, include a "y" key in returned dicts for class
+ label. If classes are not available and this is true, an
+ exception will be raised.
+ :param deterministic: if True, yield results in a deterministic order.
+ :param random_crop: if True, randomly crop the images for augmentation.
+ :param random_flip: if True, randomly flip the images for augmentation.
+ """
+ if not data_dir:
+ raise ValueError("unspecified data directory")
+ all_files = _list_image_files_recursively(data_dir)
+ classes = None
+ if class_cond:
+ # Assume classes are the first part of the filename,
+ # before an underscore.
+ class_names = [bf.basename(path).split("_")[0] for path in all_files]
+ sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
+ classes = [sorted_classes[x] for x in class_names]
+ dataset = ImageDataset(
+ image_size,
+ all_files,
+ classes=classes,
+ shard=MPI.COMM_WORLD.Get_rank(),
+ num_shards=MPI.COMM_WORLD.Get_size(),
+ random_crop=random_crop,
+ random_flip=random_flip,
+ )
+ if deterministic:
+ loader = DataLoader(
+ dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
+ )
+ else:
+ loader = DataLoader(
+ dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
+ )
+ while True:
+ yield from loader
+
+
+def _list_image_files_recursively(data_dir):
+ results = []
+ for entry in sorted(bf.listdir(data_dir)):
+ full_path = bf.join(data_dir, entry)
+ ext = entry.split(".")[-1]
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
+ results.append(full_path)
+ elif bf.isdir(full_path):
+ results.extend(_list_image_files_recursively(full_path))
+ return results
+
+
+class ImageDataset(Dataset):
+ def __init__(
+ self,
+ resolution,
+ image_paths,
+ classes=None,
+ shard=0,
+ num_shards=1,
+ random_crop=False,
+ random_flip=True,
+ ):
+ super().__init__()
+ self.resolution = resolution
+ self.local_images = image_paths[shard:][::num_shards]
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
+ self.random_crop = random_crop
+ self.random_flip = random_flip
+
+ def __len__(self):
+ return len(self.local_images)
+
+ def __getitem__(self, idx):
+ path = self.local_images[idx]
+ with bf.BlobFile(path, "rb") as f:
+ pil_image = Image.open(f)
+ pil_image.load()
+ pil_image = pil_image.convert("RGB")
+
+ if self.random_crop:
+ arr = random_crop_arr(pil_image, self.resolution)
+ else:
+ arr = center_crop_arr(pil_image, self.resolution)
+
+ if self.random_flip and random.random() < 0.5:
+ arr = arr[:, ::-1]
+
+ arr = arr.astype(np.float32) / 127.5 - 1
+
+ out_dict = {}
+ if self.local_classes is not None:
+ out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
+ return np.transpose(arr, [2, 0, 1]), out_dict
+
+
+def center_crop_arr(pil_image, image_size):
+ # We are not on a new enough PIL to support the `reducing_gap`
+ # argument, which uses BOX downsampling at powers of two first.
+ # Thus, we do it by hand to improve downsample quality.
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
+
+
+def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
+
+ # We are not on a new enough PIL to support the `reducing_gap`
+ # argument, which uses BOX downsampling at powers of two first.
+ # Thus, we do it by hand to improve downsample quality.
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = smaller_dim_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
diff --git a/guided_diffusion/logger.py b/guided_diffusion/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..77ea89ab2f7cd5a897175ff13ad01135bef010db
--- /dev/null
+++ b/guided_diffusion/logger.py
@@ -0,0 +1,500 @@
+"""
+Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
+https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
+"""
+
+import os
+import sys
+import shutil
+import os.path as osp
+import json
+import time
+import datetime
+import tempfile
+import warnings
+from collections import defaultdict
+from contextlib import contextmanager
+
+DEBUG = 10
+INFO = 20
+WARN = 30
+ERROR = 40
+
+DISABLED = 50
+
+
+class KVWriter(object):
+ def writekvs(self, kvs):
+ raise NotImplementedError
+
+
+class SeqWriter(object):
+ def writeseq(self, seq):
+ raise NotImplementedError
+
+
+class HumanOutputFormat(KVWriter, SeqWriter):
+ def __init__(self, filename_or_file):
+ if isinstance(filename_or_file, str):
+ self.file = open(filename_or_file, "wt")
+ self.own_file = True
+ else:
+ assert hasattr(filename_or_file, "read"), (
+ "expected file or str, got %s" % filename_or_file
+ )
+ self.file = filename_or_file
+ self.own_file = False
+
+ def writekvs(self, kvs):
+ # Create strings for printing
+ key2str = {}
+ for (key, val) in sorted(kvs.items()):
+ if hasattr(val, "__float__"):
+ valstr = "%-8.3g" % val
+ else:
+ valstr = str(val)
+ key2str[self._truncate(key)] = self._truncate(valstr)
+
+ # Find max widths
+ if len(key2str) == 0:
+ print("WARNING: tried to write empty key-value dict")
+ return
+ else:
+ keywidth = max(map(len, key2str.keys()))
+ valwidth = max(map(len, key2str.values()))
+
+ # Write out the data
+ dashes = "-" * (keywidth + valwidth + 7)
+ lines = [dashes]
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
+ lines.append(
+ "| %s%s | %s%s |"
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
+ )
+ lines.append(dashes)
+ self.file.write("\n".join(lines) + "\n")
+
+ # Flush the output to the file
+ self.file.flush()
+
+ def _truncate(self, s):
+ maxlen = 30
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
+
+ def writeseq(self, seq):
+ seq = list(seq)
+ for (i, elem) in enumerate(seq):
+ self.file.write(elem)
+ if i < len(seq) - 1: # add space unless this is the last one
+ self.file.write(" ")
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self):
+ if self.own_file:
+ self.file.close()
+
+
+class JSONOutputFormat(KVWriter):
+ def __init__(self, filename):
+ self.file = open(filename, "wt")
+
+ def writekvs(self, kvs):
+ for k, v in sorted(kvs.items()):
+ if hasattr(v, "dtype"):
+ kvs[k] = float(v)
+ self.file.write(json.dumps(kvs) + "\n")
+ self.file.flush()
+
+ def close(self):
+ self.file.close()
+
+
+class CSVOutputFormat(KVWriter):
+ def __init__(self, filename):
+ self.file = open(filename, "w+t")
+ self.keys = []
+ self.sep = ","
+
+ def writekvs(self, kvs):
+ # Add our current row to the history
+ extra_keys = list(kvs.keys() - self.keys)
+ extra_keys.sort()
+ if extra_keys:
+ self.keys.extend(extra_keys)
+ self.file.seek(0)
+ lines = self.file.readlines()
+ self.file.seek(0)
+ for (i, k) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ self.file.write(k)
+ self.file.write("\n")
+ for line in lines[1:]:
+ self.file.write(line[:-1])
+ self.file.write(self.sep * len(extra_keys))
+ self.file.write("\n")
+ for (i, k) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ v = kvs.get(k)
+ if v is not None:
+ self.file.write(str(v))
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self):
+ self.file.close()
+
+
+class TensorBoardOutputFormat(KVWriter):
+ """
+ Dumps key/value pairs into TensorBoard's numeric format.
+ """
+
+ def __init__(self, dir):
+ os.makedirs(dir, exist_ok=True)
+ self.dir = dir
+ self.step = 1
+ prefix = "events"
+ path = osp.join(osp.abspath(dir), prefix)
+ import tensorflow as tf
+ from tensorflow.python import pywrap_tensorflow
+ from tensorflow.core.util import event_pb2
+ from tensorflow.python.util import compat
+
+ self.tf = tf
+ self.event_pb2 = event_pb2
+ self.pywrap_tensorflow = pywrap_tensorflow
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
+
+ def writekvs(self, kvs):
+ def summary_val(k, v):
+ kwargs = {"tag": k, "simple_value": float(v)}
+ return self.tf.Summary.Value(**kwargs)
+
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
+ event.step = (
+ self.step
+ ) # is there any reason why you'd want to specify the step?
+ self.writer.WriteEvent(event)
+ self.writer.Flush()
+ self.step += 1
+
+ def close(self):
+ if self.writer:
+ self.writer.Close()
+ self.writer = None
+
+
+def make_output_format(format, ev_dir, log_suffix=""):
+ os.makedirs(ev_dir, exist_ok=True)
+ if format == "stdout":
+ return HumanOutputFormat(sys.stdout)
+ elif format == "log":
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
+ elif format == "json":
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
+ elif format == "csv":
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
+ elif format == "tensorboard":
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
+ else:
+ raise ValueError("Unknown format specified: %s" % (format,))
+
+
+# ================================================================
+# API
+# ================================================================
+
+
+def logkv(key, val):
+ """
+ Log a value of some diagnostic
+ Call this once for each diagnostic quantity, each iteration
+ If called many times, last value will be used.
+ """
+ get_current().logkv(key, val)
+
+
+def logkv_mean(key, val):
+ """
+ The same as logkv(), but if called many times, values averaged.
+ """
+ get_current().logkv_mean(key, val)
+
+
+def logkvs(d):
+ """
+ Log a dictionary of key-value pairs
+ """
+ for (k, v) in d.items():
+ logkv(k, v)
+
+
+def dumpkvs():
+ """
+ Write all of the diagnostics from the current iteration
+ """
+ return get_current().dumpkvs()
+
+
+def getkvs():
+ return get_current().name2val
+
+
+def log(*args, level=INFO):
+ """
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
+ """
+ get_current().log(*args, level=level)
+
+
+def debug(*args):
+ log(*args, level=DEBUG)
+
+
+def info(*args):
+ log(*args, level=INFO)
+
+
+def warn(*args):
+ log(*args, level=WARN)
+
+
+def error(*args):
+ log(*args, level=ERROR)
+
+
+def set_level(level):
+ """
+ Set logging threshold on current logger.
+ """
+ get_current().set_level(level)
+
+
+def set_comm(comm):
+ get_current().set_comm(comm)
+
+
+def get_dir():
+ """
+ Get directory that log files are being written to.
+ will be None if there is no output directory (i.e., if you didn't call start)
+ """
+ return get_current().get_dir()
+
+def get_tensorboard_writer():
+ """get the tensorboard writer
+ """
+ pass
+
+
+record_tabular = logkv
+dump_tabular = dumpkvs
+
+
+@contextmanager
+def profile_kv(scopename):
+ logkey = "wait_" + scopename
+ tstart = time.time()
+ try:
+ yield
+ finally:
+ get_current().name2val[logkey] += time.time() - tstart
+
+
+def profile(n):
+ """
+ Usage:
+ @profile("my_func")
+ def my_func(): code
+ """
+
+ def decorator_with_name(func):
+ def func_wrapper(*args, **kwargs):
+ with profile_kv(n):
+ return func(*args, **kwargs)
+
+ return func_wrapper
+
+ return decorator_with_name
+
+
+# ================================================================
+# Backend
+# ================================================================
+
+
+def get_current():
+ if Logger.CURRENT is None:
+ _configure_default_logger()
+
+ return Logger.CURRENT
+
+
+class Logger(object):
+ DEFAULT = None # A logger with no output files. (See right below class definition)
+ # So that you can still log to the terminal without setting up any output files
+ CURRENT = None # Current logger being used by the free functions above
+
+ def __init__(self, dir, output_formats, comm=None):
+ self.name2val = defaultdict(float) # values this iteration
+ self.name2cnt = defaultdict(int)
+ self.level = INFO
+ self.dir = dir
+ self.output_formats = output_formats
+ self.comm = comm
+
+ # Logging API, forwarded
+ # ----------------------------------------
+ def logkv(self, key, val):
+ self.name2val[key] = val
+
+ def logkv_mean(self, key, val):
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
+ self.name2cnt[key] = cnt + 1
+
+ def dumpkvs(self):
+ if self.comm is None:
+ d = self.name2val
+ else:
+ d = mpi_weighted_mean(
+ self.comm,
+ {
+ name: (val, self.name2cnt.get(name, 1))
+ for (name, val) in self.name2val.items()
+ },
+ )
+ if self.comm.rank != 0:
+ d["dummy"] = 1 # so we don't get a warning about empty dict
+ out = d.copy() # Return the dict for unit testing purposes
+ for fmt in self.output_formats:
+ if isinstance(fmt, KVWriter):
+ fmt.writekvs(d)
+ self.name2val.clear()
+ self.name2cnt.clear()
+ return out
+
+ def log(self, *args, level=INFO):
+ if self.level <= level:
+ self._do_log(args)
+
+ # Configuration
+ # ----------------------------------------
+ def set_level(self, level):
+ self.level = level
+
+ def set_comm(self, comm):
+ self.comm = comm
+
+ def get_dir(self):
+ return self.dir
+
+ def close(self):
+ for fmt in self.output_formats:
+ fmt.close()
+
+ # Misc
+ # ----------------------------------------
+ def _do_log(self, args):
+ for fmt in self.output_formats:
+ if isinstance(fmt, SeqWriter):
+ fmt.writeseq(map(str, args))
+
+
+def get_rank_without_mpi_import():
+ # check environment variables here instead of importing mpi4py
+ # to avoid calling MPI_Init() when this module is imported
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
+ if varname in os.environ:
+ return int(os.environ[varname])
+ return 0
+
+
+def mpi_weighted_mean(comm, local_name2valcount):
+ """
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
+ Perform a weighted average over dicts that are each on a different node
+ Input: local_name2valcount: dict mapping key -> (value, count)
+ Returns: key -> mean
+ """
+ all_name2valcount = comm.gather(local_name2valcount)
+ if comm.rank == 0:
+ name2sum = defaultdict(float)
+ name2count = defaultdict(float)
+ for n2vc in all_name2valcount:
+ for (name, (val, count)) in n2vc.items():
+ try:
+ val = float(val)
+ except ValueError:
+ if comm.rank == 0:
+ warnings.warn(
+ "WARNING: tried to compute mean on non-float {}={}".format(
+ name, val
+ )
+ )
+ else:
+ name2sum[name] += val * count
+ name2count[name] += count
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
+ else:
+ return {}
+
+
+def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
+ """
+ If comm is provided, average all numerical stats across that comm
+ """
+ if dir is None:
+ dir = os.getenv("OPENAI_LOGDIR")
+ if dir is None:
+ dir = osp.join(
+ tempfile.gettempdir(),
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
+ )
+ assert isinstance(dir, str)
+ dir = os.path.expanduser(dir)
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
+
+ rank = get_rank_without_mpi_import()
+ if rank > 0:
+ log_suffix = log_suffix + "-rank%03i" % rank
+
+ if format_strs is None:
+ if rank == 0:
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
+ else:
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
+ format_strs = filter(None, format_strs)
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
+
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
+ if output_formats:
+ log("Logging to %s" % dir)
+
+
+def _configure_default_logger():
+ configure()
+ Logger.DEFAULT = Logger.CURRENT
+
+
+def reset():
+ if Logger.CURRENT is not Logger.DEFAULT:
+ Logger.CURRENT.close()
+ Logger.CURRENT = Logger.DEFAULT
+ log("Reset logger")
+
+
+@contextmanager
+def scoped_configure(dir=None, format_strs=None, comm=None):
+ prevlogger = Logger.CURRENT
+ configure(dir=dir, format_strs=format_strs, comm=comm)
+ try:
+ yield
+ finally:
+ Logger.CURRENT.close()
+ Logger.CURRENT = prevlogger
+
diff --git a/guided_diffusion/losses.py b/guided_diffusion/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..251e42e4f36a31bb5e1aeda874b3a45d722000a2
--- /dev/null
+++ b/guided_diffusion/losses.py
@@ -0,0 +1,77 @@
+"""
+Helpers for various likelihood-based losses. These are ported from the original
+Ho et al. diffusion models codebase:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
+"""
+
+import numpy as np
+
+import torch as th
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/guided_diffusion/nn.py b/guided_diffusion/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fc05ff451c095b98e54a2098223d10816abc32d
--- /dev/null
+++ b/guided_diffusion/nn.py
@@ -0,0 +1,171 @@
+"""
+Various utilities for neural networks.
+"""
+
+import math
+
+import torch as th
+import torch.nn as nn
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * th.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = th.exp(
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(th.autograd.Function):
+ @staticmethod
+ @th.autocast(device_type='cuda')
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ with th.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with th.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = th.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
diff --git a/guided_diffusion/resample.py b/guided_diffusion/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..c82eccdcd47c468d41e7cbe02de6a731f2c9bf81
--- /dev/null
+++ b/guided_diffusion/resample.py
@@ -0,0 +1,154 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
diff --git a/guided_diffusion/respace.py b/guided_diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..c72a0905aed132bdc03e6251a44b39f5ca9a329c
--- /dev/null
+++ b/guided_diffusion/respace.py
@@ -0,0 +1,136 @@
+import numpy as np
+import torch as th
+from pdb import set_trace as st
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, c=None, mixing_normal=False, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ # st()
+ # assert mixing_normal
+ new_ts = new_ts / self.original_num_steps # already respaced to 1000 steps
+ if mixing_normal:
+ self.mixing_logit = self.model.ddp_model(x=None, # will be queried in gaussian_diffusion.py
+ timesteps=None,
+ get_attr='mixing_logit')
+ return self.model.apply_model_inference(x,new_ts, c, **kwargs) # send in "self" not "Unet", to use cldm
diff --git a/guided_diffusion/script_util.py b/guided_diffusion/script_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e1e3a0a2fc26959af5e627d157d1e406effe7b
--- /dev/null
+++ b/guided_diffusion/script_util.py
@@ -0,0 +1,701 @@
+import argparse
+import inspect
+
+from pdb import set_trace as st
+
+from cldm.cldm import ControlledUnetModel, ControlNet
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+# from .unet_old import SuperResModel, UNetModel, EncoderUNetModel # , UNetModelWithHint
+from .unet import SuperResModel, UNetModel, EncoderUNetModel # , UNetModelWithHint
+import torch as th
+from dit.dit_models_xformers import DiT_models
+if th.cuda.is_available():
+ from xformers.triton import FusedLayerNorm as LayerNorm
+
+NUM_CLASSES = 1000
+
+
+def diffusion_defaults():
+ """
+ Defaults for image and classifier training.
+ """
+ return dict(
+ learn_sigma=False,
+ diffusion_steps=1000,
+ noise_schedule="linear",
+ standarization_xt=False,
+ timestep_respacing="",
+ use_kl=False,
+ predict_xstart=False,
+ predict_v=False,
+ rescale_timesteps=False,
+ rescale_learned_sigmas=False,
+ mixed_prediction=False, # ! to assign later
+ )
+
+
+def classifier_defaults():
+ """
+ Defaults for classifier models.
+ """
+ return dict(
+ image_size=64,
+ classifier_use_fp16=False,
+ classifier_width=128,
+ classifier_depth=2,
+ classifier_attention_resolutions="32,16,8", # 16
+ classifier_use_scale_shift_norm=True, # False
+ classifier_resblock_updown=True, # False
+ classifier_pool="attention",
+ )
+
+
+def control_net_defaults():
+ res = dict(
+ only_mid_control=False, # TODO
+ control_key='img',
+ normalize_clip_encoding=False, # zero-shot text inference
+ scale_clip_encoding=1.0,
+ cfg_dropout_prob=0.0, # dropout condition for CFG training
+ # cond_key='caption',
+ )
+ return res
+
+
+def continuous_diffusion_defaults():
+ # NVlabs/LSGM/train_vada.py
+ res = dict(
+ sde_time_eps=1e-2,
+ sde_beta_start=0.1,
+ sde_beta_end=20.0,
+ sde_sde_type='vpsde',
+ sde_sigma2_0=0.0, # ?
+ iw_sample_p='drop_sigma2t_iw',
+ iw_sample_q='ll_iw',
+ iw_subvp_like_vp_sde=False,
+ train_vae=True,
+ pred_type='eps', # [x0, eps]
+ # joint_train=False,
+ p_rendering_loss=False,
+ unfix_logit=False,
+ loss_type='eps',
+ loss_weight='simple', # snr snr_sqrt sigmoid_snr
+ # train_vae_denoise_rendering=False,
+ diffusion_ce_anneal=True,
+ enable_mixing_normal=True,
+ )
+
+ return res
+
+
+def model_and_diffusion_defaults():
+ """
+ Defaults for image training.
+ """
+ res = dict(
+ # image_size=64,
+ diffusion_input_size=224,
+ num_channels=128,
+ num_res_blocks=2,
+ num_heads=4,
+ num_heads_upsample=-1,
+ num_head_channels=-1,
+ attention_resolutions="16,8",
+ channel_mult="",
+ dropout=0.0,
+ class_cond=False,
+ use_checkpoint=False,
+ use_scale_shift_norm=True,
+ resblock_updown=False,
+ use_fp16=False,
+ use_new_attention_order=False,
+ denoise_in_channels=3,
+ denoise_out_channels=3,
+ # ! controlnet args
+ create_controlnet=False,
+ create_dit=False,
+ create_unet_with_hint=False,
+ dit_model_arch='DiT-L/2',
+ # ! ldm unet support
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=-1, # custom transformer support
+ roll_out=False, # whether concat in batch, not channel
+ n_embed=
+ None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ mixing_logit_init=-6,
+ hint_channels=3,
+ # unconditional_guidance_scale=1.0,
+ # normalize_clip_encoding=False, # for zero-shot conditioning
+ )
+ res.update(diffusion_defaults())
+ # res.update(continuous_diffusion_defaults())
+ return res
+
+
+def classifier_and_diffusion_defaults():
+ res = classifier_defaults()
+ res.update(diffusion_defaults())
+ return res
+
+
+def create_model_and_diffusion(
+ # image_size,
+ diffusion_input_size,
+ class_cond,
+ learn_sigma,
+ num_channels,
+ num_res_blocks,
+ channel_mult,
+ num_heads,
+ num_head_channels,
+ num_heads_upsample,
+ attention_resolutions,
+ dropout,
+ diffusion_steps,
+ noise_schedule,
+ timestep_respacing,
+ use_kl,
+ predict_xstart,
+ predict_v,
+ rescale_timesteps,
+ rescale_learned_sigmas,
+ use_checkpoint,
+ use_scale_shift_norm,
+ resblock_updown,
+ use_fp16,
+ use_new_attention_order,
+ denoise_in_channels,
+ denoise_out_channels,
+ standarization_xt,
+ mixed_prediction,
+ # controlnet
+ create_controlnet,
+ # only_mid_control,
+ # control_key,
+ use_spatial_transformer,
+ transformer_depth,
+ context_dim,
+ n_embed,
+ legacy,
+ mixing_logit_init,
+ create_dit,
+ create_unet_with_hint,
+ dit_model_arch,
+ roll_out,
+ hint_channels,
+ # unconditional_guidance_scale,
+ # normalize_clip_encoding,
+):
+ model = create_model(
+ diffusion_input_size,
+ num_channels,
+ num_res_blocks,
+ channel_mult=channel_mult,
+ learn_sigma=learn_sigma,
+ class_cond=class_cond,
+ use_checkpoint=use_checkpoint,
+ attention_resolutions=attention_resolutions,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dropout=dropout,
+ resblock_updown=resblock_updown,
+ use_fp16=use_fp16,
+ use_new_attention_order=use_new_attention_order,
+ denoise_in_channels=denoise_in_channels,
+ denoise_out_channels=denoise_out_channels,
+ mixed_prediction=mixed_prediction,
+ create_controlnet=create_controlnet,
+ # only_mid_control=only_mid_control,
+ # control_key=control_key,
+ use_spatial_transformer=use_spatial_transformer,
+ transformer_depth=transformer_depth,
+ context_dim=context_dim,
+ n_embed=n_embed,
+ legacy=legacy,
+ mixing_logit_init=mixing_logit_init,
+ create_dit=create_dit,
+ create_unet_with_hint=create_unet_with_hint,
+ dit_model_arch=dit_model_arch,
+ roll_out=roll_out,
+ hint_channels=hint_channels,
+ # normalize_clip_encoding=normalize_clip_encoding,
+ )
+ diffusion = create_gaussian_diffusion(
+ diffusion_steps=diffusion_steps,
+ learn_sigma=learn_sigma,
+ noise_schedule=noise_schedule,
+ use_kl=use_kl,
+ predict_xstart=predict_xstart,
+ predict_v=predict_v,
+ rescale_timesteps=rescale_timesteps,
+ rescale_learned_sigmas=rescale_learned_sigmas,
+ timestep_respacing=timestep_respacing,
+ standarization_xt=standarization_xt,
+ )
+ return model, diffusion
+
+
+def create_model(
+ image_size,
+ num_channels,
+ num_res_blocks,
+ channel_mult="",
+ learn_sigma=False,
+ class_cond=False,
+ use_checkpoint=False,
+ attention_resolutions="16",
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ dropout=0,
+ resblock_updown=False,
+ use_fp16=False,
+ use_new_attention_order=False,
+ # denoise_in_channels=3,
+ denoise_in_channels=-1,
+ denoise_out_channels=3,
+ mixed_prediction=False,
+ create_controlnet=False,
+ create_dit=False,
+ create_unet_with_hint=False,
+ dit_model_arch='DiT-L/2',
+ hint_channels=3,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ mixing_logit_init=-6,
+ roll_out=False,
+ # normalize_clip_encoding=False,
+):
+ if channel_mult == "":
+ if image_size == 512:
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 448:
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 320: # ffhq
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 224 and denoise_in_channels == 144: # ffhq
+ channel_mult = (1, 1, 2, 3, 4, 4)
+ elif image_size == 224:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+
+ elif image_size == 32: # https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml#L37
+ channel_mult = (1, 2, 4, 4)
+
+ elif image_size == 16: # B,12,16,16. just for baseline check. not good performance.
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+ else:
+ channel_mult = tuple(
+ int(ch_mult) for ch_mult in channel_mult.split(","))
+
+ attention_ds = []
+ for res in attention_resolutions.split(","):
+ attention_ds.append(image_size // int(res))
+
+ if create_controlnet:
+
+ controlledUnetModel = ControlledUnetModel(
+ image_size=image_size,
+ in_channels=denoise_in_channels,
+ model_channels=num_channels,
+ # out_channels=(3 if not learn_sigma else 6),
+ out_channels=(denoise_out_channels
+ if not learn_sigma else denoise_out_channels * 2),
+ num_res_blocks=num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=dropout,
+ channel_mult=channel_mult,
+ num_classes=(NUM_CLASSES if class_cond else None),
+ use_checkpoint=use_checkpoint,
+ use_fp16=use_fp16,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ resblock_updown=resblock_updown,
+ use_new_attention_order=use_new_attention_order,
+ mixed_prediction=mixed_prediction,
+ # ldm support
+ use_spatial_transformer=use_spatial_transformer,
+ transformer_depth=transformer_depth,
+ context_dim=context_dim,
+ n_embed=n_embed,
+ legacy=legacy,
+ mixing_logit_init=mixing_logit_init,
+ roll_out=roll_out
+ )
+
+ controlNet = ControlNet(
+ image_size=image_size,
+ in_channels=denoise_in_channels,
+ model_channels=num_channels,
+ # ! condition channels
+ hint_channels=hint_channels,
+ # out_channels=(3 if not learn_sigma else 6),
+ # out_channels=(denoise_out_channels
+ # if not learn_sigma else denoise_out_channels * 2),
+ num_res_blocks=num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=dropout,
+ channel_mult=channel_mult,
+ # num_classes=(NUM_CLASSES if class_cond else None),
+ use_checkpoint=use_checkpoint,
+ use_fp16=use_fp16,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ resblock_updown=resblock_updown,
+ use_new_attention_order=use_new_attention_order,
+ roll_out=roll_out
+ )
+ # mixed_prediction=mixed_prediction)
+
+ return controlledUnetModel, controlNet
+
+ elif create_dit:
+ return DiT_models[dit_model_arch](
+ input_size=image_size,
+ num_classes=0,
+ learn_sigma=learn_sigma,
+ in_channels=denoise_in_channels,
+ context_dim=context_dim, # add CLIP text embedding
+ roll_out=roll_out)
+ else:
+
+ # if create_unet_with_hint:
+ # unet_cls = UNetModelWithHint
+ # else:
+ unet_cls = UNetModel
+
+ # st()
+ return unet_cls(
+ image_size=image_size,
+ in_channels=denoise_in_channels,
+ model_channels=num_channels,
+ # out_channels=(3 if not learn_sigma else 6),
+ out_channels=(denoise_out_channels
+ if not learn_sigma else denoise_out_channels * 2),
+ num_res_blocks=num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=dropout,
+ channel_mult=channel_mult,
+ num_classes=(NUM_CLASSES if class_cond else None),
+ use_checkpoint=use_checkpoint,
+ use_fp16=use_fp16,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ resblock_updown=resblock_updown,
+ use_new_attention_order=use_new_attention_order,
+ mixed_prediction=mixed_prediction,
+ # ldm support
+ use_spatial_transformer=use_spatial_transformer,
+ transformer_depth=transformer_depth,
+ context_dim=context_dim,
+ n_embed=n_embed,
+ legacy=legacy,
+ mixing_logit_init=mixing_logit_init,
+ roll_out=roll_out,
+ hint_channels=hint_channels,
+ # normalize_clip_encoding=normalize_clip_encoding,
+ )
+
+
+def create_classifier_and_diffusion(
+ image_size,
+ classifier_use_fp16,
+ classifier_width,
+ classifier_depth,
+ classifier_attention_resolutions,
+ classifier_use_scale_shift_norm,
+ classifier_resblock_updown,
+ classifier_pool,
+ learn_sigma,
+ diffusion_steps,
+ noise_schedule,
+ timestep_respacing,
+ use_kl,
+ predict_xstart,
+ rescale_timesteps,
+ rescale_learned_sigmas,
+):
+ classifier = create_classifier(
+ image_size,
+ classifier_use_fp16,
+ classifier_width,
+ classifier_depth,
+ classifier_attention_resolutions,
+ classifier_use_scale_shift_norm,
+ classifier_resblock_updown,
+ classifier_pool,
+ )
+ diffusion = create_gaussian_diffusion(
+ steps=diffusion_steps,
+ learn_sigma=learn_sigma,
+ noise_schedule=noise_schedule,
+ use_kl=use_kl,
+ predict_xstart=predict_xstart,
+ rescale_timesteps=rescale_timesteps,
+ rescale_learned_sigmas=rescale_learned_sigmas,
+ timestep_respacing=timestep_respacing,
+ )
+ return classifier, diffusion
+
+
+def create_classifier(
+ image_size,
+ classifier_use_fp16,
+ classifier_width,
+ classifier_depth,
+ classifier_attention_resolutions,
+ classifier_use_scale_shift_norm,
+ classifier_resblock_updown,
+ classifier_pool,
+):
+ if image_size == 512:
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+
+ attention_ds = []
+ for res in classifier_attention_resolutions.split(","):
+ attention_ds.append(image_size // int(res))
+
+ return EncoderUNetModel(
+ image_size=image_size,
+ in_channels=3,
+ model_channels=classifier_width,
+ out_channels=1000,
+ num_res_blocks=classifier_depth,
+ attention_resolutions=tuple(attention_ds),
+ channel_mult=channel_mult,
+ use_fp16=classifier_use_fp16,
+ num_head_channels=64,
+ use_scale_shift_norm=classifier_use_scale_shift_norm,
+ resblock_updown=classifier_resblock_updown,
+ pool=classifier_pool,
+ )
+
+
+def sr_model_and_diffusion_defaults():
+ res = model_and_diffusion_defaults()
+ res["large_size"] = 256
+ res["small_size"] = 64
+ arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
+ for k in res.copy().keys():
+ if k not in arg_names:
+ del res[k]
+ return res
+
+
+def sr_create_model_and_diffusion(
+ large_size,
+ small_size,
+ class_cond,
+ learn_sigma,
+ num_channels,
+ num_res_blocks,
+ num_heads,
+ num_head_channels,
+ num_heads_upsample,
+ attention_resolutions,
+ dropout,
+ diffusion_steps,
+ noise_schedule,
+ timestep_respacing,
+ use_kl,
+ predict_xstart,
+ rescale_timesteps,
+ rescale_learned_sigmas,
+ use_checkpoint,
+ use_scale_shift_norm,
+ resblock_updown,
+ use_fp16,
+):
+ model = sr_create_model(
+ large_size,
+ small_size,
+ num_channels,
+ num_res_blocks,
+ learn_sigma=learn_sigma,
+ class_cond=class_cond,
+ use_checkpoint=use_checkpoint,
+ attention_resolutions=attention_resolutions,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dropout=dropout,
+ resblock_updown=resblock_updown,
+ use_fp16=use_fp16,
+ )
+ diffusion = create_gaussian_diffusion(
+ steps=diffusion_steps,
+ learn_sigma=learn_sigma,
+ noise_schedule=noise_schedule,
+ use_kl=use_kl,
+ predict_xstart=predict_xstart,
+ rescale_timesteps=rescale_timesteps,
+ rescale_learned_sigmas=rescale_learned_sigmas,
+ timestep_respacing=timestep_respacing,
+ )
+ return model, diffusion
+
+
+def sr_create_model(
+ large_size,
+ small_size,
+ num_channels,
+ num_res_blocks,
+ learn_sigma,
+ class_cond,
+ use_checkpoint,
+ attention_resolutions,
+ num_heads,
+ num_head_channels,
+ num_heads_upsample,
+ use_scale_shift_norm,
+ dropout,
+ resblock_updown,
+ use_fp16,
+):
+ _ = small_size # hack to prevent unused variable
+
+ if large_size == 512:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif large_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif large_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported large size: {large_size}")
+
+ attention_ds = []
+ for res in attention_resolutions.split(","):
+ attention_ds.append(large_size // int(res))
+
+ return SuperResModel(
+ image_size=large_size,
+ in_channels=3,
+ model_channels=num_channels,
+ out_channels=(3 if not learn_sigma else 6),
+ num_res_blocks=num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=dropout,
+ channel_mult=channel_mult,
+ num_classes=(NUM_CLASSES if class_cond else None),
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ resblock_updown=resblock_updown,
+ use_fp16=use_fp16,
+ )
+
+
+def create_gaussian_diffusion(
+ *,
+ diffusion_steps=1000,
+ learn_sigma=False,
+ sigma_small=False,
+ noise_schedule="linear",
+ use_kl=False,
+ predict_xstart=False,
+ predict_v=False,
+ rescale_timesteps=False,
+ rescale_learned_sigmas=False,
+ timestep_respacing="",
+ standarization_xt=False,
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE # * used here.
+ if not timestep_respacing:
+ timestep_respacing = [diffusion_steps]
+
+ if predict_xstart:
+ model_mean_type = gd.ModelMeanType.START_X
+ elif predict_v:
+ model_mean_type = gd.ModelMeanType.V
+ else:
+ model_mean_type = gd.ModelMeanType.EPSILON
+
+ # model_mean_type=(
+ # gd.ModelMeanType.EPSILON if not predict_xstart else
+ # gd.ModelMeanType.START_X # * used gd.ModelMeanType.EPSILON
+ # ),
+
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=model_mean_type,
+ # (
+ # gd.ModelMeanType.EPSILON if not predict_xstart else
+ # gd.ModelMeanType.START_X # * used gd.ModelMeanType.EPSILON
+ # ),
+ model_var_type=((
+ gd.ModelVarType.FIXED_LARGE # * used here
+ if not sigma_small else gd.ModelVarType.FIXED_SMALL)
+ if not learn_sigma else gd.ModelVarType.LEARNED_RANGE),
+ loss_type=loss_type,
+ rescale_timesteps=rescale_timesteps,
+ standarization_xt=standarization_xt,
+ )
+
+
+def add_dict_to_argparser(parser, default_dict):
+ for k, v in default_dict.items():
+ v_type = type(v)
+ if v is None:
+ v_type = str
+ elif isinstance(v, bool):
+ v_type = str2bool
+ parser.add_argument(f"--{k}", default=v, type=v_type)
+
+
+def args_to_dict(args, keys):
+ return {k: getattr(args, k) for k in keys}
+
+
+def str2bool(v):
+ """
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("boolean value expected")
diff --git a/guided_diffusion/train_util.py b/guided_diffusion/train_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9e5677c6fec0c7e0390fcecd5a0c3fa85051653
--- /dev/null
+++ b/guided_diffusion/train_util.py
@@ -0,0 +1,537 @@
+import copy
+from pdb import set_trace as st
+import functools
+import os
+import numpy as np
+
+import blobfile as bf
+import torch as th
+import torch.distributed as dist
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+
+from . import dist_util, logger
+from .fp16_util import MixedPrecisionTrainer
+from .nn import update_ema
+from .resample import LossAwareSampler, UniformSampler
+
+from pathlib import Path
+
+# For ImageNet experiments, this was a good default value.
+# We found that the lg_loss_scale quickly climbed to
+# 20-21 within the first ~1K steps of training.
+INITIAL_LOG_LOSS_SCALE = 20.0
+
+# use_amp = True
+# use_amp = False
+# if use_amp:
+# logger.log('ddpm use AMP to accelerate training')
+
+
+class TrainLoop:
+
+ def __init__(
+ self,
+ *,
+ model,
+ diffusion,
+ data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ schedule_sampler=None,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ use_amp=False,
+ model_name='ddpm',
+ **kwargs
+ ):
+
+ self.kwargs = kwargs
+ self.pool_512 = th.nn.AdaptiveAvgPool2d((512, 512))
+ self.pool_256 = th.nn.AdaptiveAvgPool2d((256, 256))
+ self.pool_128 = th.nn.AdaptiveAvgPool2d((128, 128))
+ self.pool_64 = th.nn.AdaptiveAvgPool2d((64, 64))
+
+ self.use_amp = use_amp
+ self.model_name = model_name
+ self.model = model
+ self.diffusion = diffusion
+ self.data = data
+ self.batch_size = batch_size
+ self.microbatch = microbatch if microbatch > 0 else batch_size
+ self.lr = lr
+ self.ema_rate = ([ema_rate] if isinstance(ema_rate, float) else
+ [float(x) for x in ema_rate.split(",")])
+ self.log_interval = log_interval
+ self.save_interval = save_interval
+ self.resume_checkpoint = resume_checkpoint
+ self.use_fp16 = use_fp16
+ self.fp16_scale_growth = fp16_scale_growth
+ self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
+ self.weight_decay = weight_decay
+ self.lr_anneal_steps = lr_anneal_steps
+
+ self.step = 0
+ self.resume_step = 0
+ self.global_batch = self.batch_size * dist.get_world_size()
+
+ self.sync_cuda = th.cuda.is_available()
+ self._setup_model()
+ self._load_model()
+ self._setup_opt()
+
+ def _load_model(self):
+ self._load_and_sync_parameters()
+
+ def _setup_opt(self):
+ self.opt = AdamW(self.mp_trainer.master_params,
+ lr=self.lr,
+ weight_decay=self.weight_decay)
+
+ def _setup_model(self):
+
+ self.mp_trainer = MixedPrecisionTrainer(
+ model=self.model,
+ use_fp16=self.use_fp16,
+ fp16_scale_growth=self.fp16_scale_growth,
+ use_amp=self.use_amp,
+ model_name=self.model_name
+ )
+
+ if self.resume_step:
+ self._load_optimizer_state()
+ # Model was resumed, either due to a restart or a checkpoint
+ # being specified at the command line.
+ self.ema_params = [
+ self._load_ema_parameters(rate) for rate in self.ema_rate
+ ]
+ else:
+ self.ema_params = [
+ copy.deepcopy(self.mp_trainer.master_params)
+ for _ in range(len(self.ema_rate))
+ ]
+
+ # for compatability
+
+ # print('creating DDP')
+ if th.cuda.is_available():
+ self.use_ddp = True
+ self.ddpm_model = self.model
+ self.ddp_model = DDP(
+ self.model,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ )
+ else:
+ if dist.get_world_size() > 1:
+ logger.warn("Distributed training requires CUDA. "
+ "Gradients will not be synchronized properly!")
+ self.use_ddp = False
+ self.ddp_model = self.model
+ # print('creating DDP done')
+
+
+ def _load_and_sync_parameters(self):
+ resume_checkpoint, resume_step = find_resume_checkpoint(
+ ) or self.resume_checkpoint
+
+ if resume_checkpoint:
+ if not Path(resume_checkpoint).exists():
+ logger.log(
+ f"failed to load model from checkpoint: {resume_checkpoint}, not exist"
+ )
+ return
+
+ # self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
+ self.resume_step = resume_step # TODO, EMA part
+ if dist.get_rank() == 0:
+ logger.log(
+ f"loading model from checkpoint: {resume_checkpoint}...")
+ # if model is None:
+ # model = self.model
+ self.model.load_state_dict(
+ dist_util.load_state_dict(
+ resume_checkpoint,
+ map_location=dist_util.dev(),
+ ))
+
+ # ! debugging, remove to check which key fails.
+ dist_util.sync_params(self.model.parameters())
+ # dist_util.sync_params(self.model.named_parameters())
+
+ def _load_ema_parameters(self,
+ rate,
+ model=None,
+ mp_trainer=None,
+ model_name='ddpm'):
+
+ if mp_trainer is None:
+ mp_trainer = self.mp_trainer
+ if model is None:
+ model = self.model
+
+ ema_params = copy.deepcopy(mp_trainer.master_params)
+
+ main_checkpoint, _ = find_resume_checkpoint(
+ self.resume_checkpoint, model_name) or self.resume_checkpoint
+ ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step,
+ rate, model_name)
+ if ema_checkpoint:
+
+ if dist_util.get_rank() == 0:
+
+ if not Path(ema_checkpoint).exists():
+ logger.log(
+ f"failed to load EMA from checkpoint: {ema_checkpoint}, not exist"
+ )
+ return
+
+ logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
+
+ map_location = {
+ 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
+ } # configure map_location properly
+
+ state_dict = dist_util.load_state_dict(
+ ema_checkpoint, map_location=map_location)
+
+ model_ema_state_dict = model.state_dict()
+
+ for k, v in state_dict.items():
+ if k in model_ema_state_dict.keys() and v.size(
+ ) == model_ema_state_dict[k].size():
+ model_ema_state_dict[k] = v
+
+ # elif 'IN' in k and model_name == 'rec' and getattr(model.decoder, 'decomposed_IN', False):
+ # model_ema_state_dict[k.replace('IN', 'superresolution.norm.norm_layer')] = v # decomposed IN
+
+ else:
+ print('ignore key: ', k, ": ", v.size())
+
+ ema_params = mp_trainer.state_dict_to_master_params(
+ model_ema_state_dict)
+
+ del state_dict
+
+ # print('ema mark 3, ', model_name, flush=True)
+ if dist_util.get_world_size() > 1:
+ dist_util.sync_params(ema_params)
+ # print('ema mark 4, ', model_name, flush=True)
+ # del ema_params
+ return ema_params
+
+ def _load_ema_parameters_freezeAE(
+ self,
+ rate,
+ model,
+ # mp_trainer=None,
+ model_name='rec'):
+
+ # if mp_trainer is None:
+ # mp_trainer = self.mp_trainer
+ # if model is None:
+ # model = self.model_rec
+
+ # ema_params = copy.deepcopy(mp_trainer.master_params)
+
+ main_checkpoint, _ = find_resume_checkpoint(
+ self.resume_checkpoint, model_name) or self.resume_checkpoint
+ ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step,
+ rate, model_name)
+ if ema_checkpoint:
+
+ if dist_util.get_rank() == 0:
+
+ if not Path(ema_checkpoint).exists():
+ logger.log(
+ f"failed to load EMA from checkpoint: {ema_checkpoint}, not exist"
+ )
+ return
+
+ logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
+
+ map_location = {
+ 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
+ } # configure map_location properly
+
+ state_dict = dist_util.load_state_dict(
+ ema_checkpoint, map_location=map_location)
+
+ model_ema_state_dict = model.state_dict()
+
+ for k, v in state_dict.items():
+ if k in model_ema_state_dict.keys() and v.size(
+ ) == model_ema_state_dict[k].size():
+ model_ema_state_dict[k] = v
+ else:
+ print('ignore key: ', k, ": ", v.size())
+
+ ema_params = mp_trainer.state_dict_to_master_params(
+ model_ema_state_dict)
+
+ del state_dict
+
+ # print('ema mark 3, ', model_name, flush=True)
+ if dist_util.get_world_size() > 1:
+ dist_util.sync_params(ema_params)
+ # print('ema mark 4, ', model_name, flush=True)
+ # del ema_params
+ return ema_params
+
+ # def _load_ema_parameters(self, rate):
+ # ema_params = copy.deepcopy(self.mp_trainer.master_params)
+
+ # main_checkpoint, _ = find_resume_checkpoint() or self.resume_checkpoint
+ # ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
+ # if ema_checkpoint:
+ # if dist.get_rank() == 0:
+ # logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
+ # state_dict = dist_util.load_state_dict(
+ # ema_checkpoint, map_location=dist_util.dev()
+ # )
+ # ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
+
+ # dist_util.sync_params(ema_params)
+ # return ema_params
+
+ def _load_optimizer_state(self):
+ main_checkpoint, _ = find_resume_checkpoint() or self.resume_checkpoint
+ opt_checkpoint = bf.join(bf.dirname(main_checkpoint),
+ f"opt{self.resume_step:06}.pt")
+ if bf.exists(opt_checkpoint):
+ logger.log(
+ f"loading optimizer state from checkpoint: {opt_checkpoint}")
+ state_dict = dist_util.load_state_dict(
+ opt_checkpoint, map_location=dist_util.dev())
+ self.opt.load_state_dict(state_dict)
+
+ def run_loop(self):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+ batch, cond = next(self.data)
+ self.run_step(batch, cond)
+ if self.step % self.log_interval == 0:
+ logger.dumpkvs()
+ if self.step % self.save_interval == 0:
+ self.save()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+ self.step += 1
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+
+ def run_step(self, batch, cond):
+ self.forward_backward(batch, cond)
+ took_step = self.mp_trainer.optimize(self.opt)
+ if took_step:
+ self._update_ema()
+ self._anneal_lr()
+ self.log_step()
+
+ def forward_backward(self, batch, cond):
+ self.mp_trainer.zero_grad()
+ for i in range(0, batch.shape[0], self.microbatch):
+
+ # st()
+ with th.autocast(device_type=dist_util.dev(),
+ dtype=th.float16,
+ enabled=self.mp_trainer.use_amp):
+
+ micro = batch[i:i + self.microbatch].to(dist_util.dev())
+ micro_cond = {
+ k: v[i:i + self.microbatch].to(dist_util.dev())
+ for k, v in cond.items()
+ }
+ last_batch = (i + self.microbatch) >= batch.shape[0]
+ t, weights = self.schedule_sampler.sample(
+ micro.shape[0], dist_util.dev())
+
+ compute_losses = functools.partial(
+ self.diffusion.training_losses,
+ self.ddp_model,
+ micro,
+ t,
+ model_kwargs=micro_cond,
+ )
+
+ if last_batch or not self.use_ddp:
+ losses = compute_losses()
+ else:
+ with self.ddp_model.no_sync():
+ losses = compute_losses()
+
+ if isinstance(self.schedule_sampler, LossAwareSampler):
+ self.schedule_sampler.update_with_local_losses(
+ t, losses["loss"].detach())
+
+ loss = (losses["loss"] * weights).mean()
+ log_loss_dict(self.diffusion, t,
+ {k: v * weights
+ for k, v in losses.items()})
+
+ self.mp_trainer.backward(loss)
+
+ def _update_ema(self):
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ update_ema(params, self.mp_trainer.master_params, rate=rate)
+
+ def _anneal_lr(self):
+ if not self.lr_anneal_steps:
+ return
+ frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
+ lr = self.lr * (1 - frac_done)
+ for param_group in self.opt.param_groups:
+ param_group["lr"] = lr
+
+ def log_step(self):
+ logger.logkv("step", self.step + self.resume_step)
+ logger.logkv("samples",
+ (self.step + self.resume_step + 1) * self.global_batch)
+
+ def save(self):
+
+ def save_checkpoint(rate, params):
+ state_dict = self.mp_trainer.master_params_to_state_dict(params)
+ if dist.get_rank() == 0:
+ logger.log(f"saving model {rate}...")
+ if not rate:
+ filename = f"model{(self.step+self.resume_step):07d}.pt"
+ else:
+ filename = f"ema_{rate}_{(self.step+self.resume_step):07d}.pt"
+ with bf.BlobFile(bf.join(get_blob_logdir(), filename),
+ "wb") as f:
+ th.save(state_dict, f)
+
+ save_checkpoint(0, self.mp_trainer.master_params)
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ save_checkpoint(rate, params)
+
+ if dist.get_rank() == 0:
+ with bf.BlobFile(
+ bf.join(get_blob_logdir(),
+ f"opt{(self.step+self.resume_step):07d}.pt"),
+ "wb",
+ ) as f:
+ th.save(self.opt.state_dict(), f)
+
+ dist.barrier()
+
+
+def parse_resume_step_from_filename(filename):
+ """
+ Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
+ checkpoint's number of steps.
+ """
+ # split1 = Path(filename).stem[-6:]
+ split1 = Path(filename).stem[-7:]
+ # split = filename.split("model")
+ # if len(split) < 2:
+ # return 0
+ # split1 = split[-1].split(".")[0]
+ try:
+ return int(split1)
+ except ValueError:
+ print('fail to load model step', split1)
+ return 0
+
+
+def get_blob_logdir():
+ # You can change this to be a separate path to save checkpoints to
+ # a blobstore or some external drive.
+ return logger.get_dir()
+
+
+def find_resume_checkpoint(resume_checkpoint='', model_name='ddpm'):
+ # On your infrastructure, you may want to override this to automatically
+ # discover the latest checkpoint on your blob storage, etc.
+
+ if resume_checkpoint != '':
+ step = parse_resume_step_from_filename(resume_checkpoint)
+ split = resume_checkpoint.split("model")
+ resume_ckpt_path = str(
+ Path(split[0]) / f'model_{model_name}{step:07d}.pt')
+ else:
+ resume_ckpt_path = ''
+ step = 0
+
+ return resume_ckpt_path, step
+
+
+def find_ema_checkpoint(main_checkpoint, step, rate, model_name=''):
+ if main_checkpoint is None:
+ return None
+ if model_name == '':
+ filename = f"ema_{rate}_{(step):07d}.pt"
+ else:
+ filename = f"ema_{model_name}_{rate}_{(step):07d}.pt"
+ path = bf.join(bf.dirname(main_checkpoint), filename)
+ # print(path)
+ # st()
+ if bf.exists(path):
+ print('fine ema model', path)
+ return path
+ else:
+ print('fail to find ema model', path)
+ return None
+
+
+def log_loss_dict(diffusion, ts, losses):
+ for key, values in losses.items():
+ logger.logkv_mean(key, values.mean().item())
+ # Log the quantiles (four quartiles, in particular).
+ for sub_t, sub_loss in zip(ts.cpu().numpy(),
+ values.detach().cpu().numpy()):
+ quartile = int(4 * sub_t / diffusion.num_timesteps)
+ logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
+
+
+def log_rec3d_loss_dict(loss_dict):
+ for key, values in loss_dict.items():
+ try:
+ logger.logkv_mean(key, values.mean().item())
+ except:
+ print('type error:', key)
+
+
+
+def calc_average_loss(all_loss_dicts, verbose=True):
+ all_scores = {} # todo, defaultdict
+ mean_all_scores = {}
+
+ for loss_dict in all_loss_dicts:
+ for k, v in loss_dict.items():
+ v = v.item()
+ if k not in all_scores:
+ # all_scores[f'{k}_val'] = [v]
+ all_scores[k] = [v]
+ else:
+ all_scores[k].append(v)
+
+ for k, v in all_scores.items():
+ mean = np.mean(v)
+ std = np.std(v)
+ if k in ['loss_lpis', 'loss_ssim']:
+ mean = 1 - mean
+ result_str = '{} average loss is {:.4f} +- {:.4f}'.format(k, mean, std)
+ mean_all_scores[k] = mean
+ if verbose:
+ print(result_str)
+
+ val_scores_for_logging = {
+ f'{k}_val': v
+ for k, v in mean_all_scores.items()
+ }
+ return val_scores_for_logging
\ No newline at end of file
diff --git a/guided_diffusion/unet.py b/guided_diffusion/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..da07c1b7c33e896ecf8bbd219e1b39c5187338cf
--- /dev/null
+++ b/guided_diffusion/unet.py
@@ -0,0 +1,1109 @@
+from abc import abstractmethod
+
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from pdb import set_trace as st
+from einops import rearrange, repeat
+
+from .fp16_util import convert_module_to_f16, convert_module_to_f32
+from .nn import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+
+from ldm.modules.attention_compat import SpatialTransformer
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
+ )
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+# class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+# """
+# A sequential module that passes timestep embeddings to the children that
+# support it as an extra input.
+# """
+
+# def forward(self, x, emb):
+# for layer in self:
+# if isinstance(layer, TimestepBlock):
+# x = layer(x, emb)
+# else:
+# x = layer(x)
+# return x
+
+# from LDM openaimodel.py
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ # return checkpoint(
+ # self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ # )
+ return self._forward(x, emb)
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ # def forward(self, x):
+ # return checkpoint(self._forward, (x,), self.parameters(), True)
+
+ # def _forward(self, x):
+ # b, c, *spatial = x.shape
+ # x = x.reshape(b, c, -1)
+ # qkv = self.qkv(self.norm(x))
+ # h = self.attention(qkv)
+ # h = self.proj_out(h)
+ # return (x + h).reshape(b, c, *spatial)
+
+ # ! disable checkpoint here since it is incompatible with torch.amp
+ def forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ mixed_prediction=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=-1, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ mixing_logit_init=-6,
+ roll_out=False,**kwargs
+ ):
+ super().__init__()
+ self.roll_out = roll_out
+ if context_dim == -1:
+ context_dim = None
+
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ # from omegaconf.listconfig import ListConfig
+ # if type(context_dim) == ListConfig:
+ # context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ # follow LSGM
+ self.mixed_prediction = mixed_prediction # This enables mixed prediction
+ if self.mixed_prediction:
+ if self.roll_out:
+ init = mixing_logit_init * th.ones(size=[1, in_channels*3, 1, 1]) # hard coded for now
+ else:
+ init = mixing_logit_init * th.ones(size=[1, in_channels, 1, 1]) # hard coded for now
+ self.mixing_logit = th.nn.Parameter(init, requires_grad=True)
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+
+ if isinstance(context, dict):
+ context = context['crossattn'] # sgm conditioner compat
+
+ if get_attr != '': # not breaking the forward hooks
+ return getattr(self, get_attr)
+
+
+ # if forward
+ # assert context is not None
+
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ # st()
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.roll_out:
+ x = rearrange(x, 'b (n c) h w->b c h (n w)', n=3) # torch.Size([84, 4, 32, 96])
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ h = self.out(h)
+ if self.roll_out:
+ return rearrange(h, 'b c h (n w) -> b (n c) h w', n=3)
+ return h
+
+
+class SuperResModel(UNetModel):
+ """
+ A UNetModel that performs super-resolution.
+
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ """
+
+ def __init__(self, image_size, in_channels, *args, **kwargs):
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
+
+ def forward(self, x, timesteps, low_res=None, **kwargs):
+ _, _, new_height, new_width = x.shape
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+ x = th.cat([x, upsampled], dim=1)
+ return super().forward(x, timesteps, **kwargs)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+ )
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+class UNetModelWithHint(UNetModel):
+ def __init__(self, image_size, in_channels, model_channels, hint_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, use_checkpoint=False, use_fp16=False, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, mixed_prediction=False, use_spatial_transformer=False, transformer_depth=1, context_dim=-1, n_embed=None, legacy=True, mixing_logit_init=-6, roll_out=False):
+ super().__init__(image_size, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout, channel_mult, conv_resample, dims, num_classes, use_checkpoint, use_fp16, num_heads, num_head_channels, num_heads_upsample, use_scale_shift_norm, resblock_updown, use_new_attention_order, mixed_prediction, use_spatial_transformer, transformer_depth, context_dim, n_embed, legacy, mixing_logit_init, roll_out)
+
+ # lite encoder, borrowed from ControlNet
+
+ self.input_hint_block = TimestepEmbedSequential( # f=8
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 32, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 96, 96, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+ )
+
+ def forward(self, x, hint, timesteps=None, context=None, y=None, get_attr='', **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+
+ # st()
+
+ # if forward
+ # assert context is not None
+
+ assert context is not None
+ # assert (y is not None) == (
+ # self.num_classes is not None
+ # ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.roll_out:
+ x = rearrange(x, 'b (n c) h w->b c h (n w)', n=3) # torch.Size([84, 4, 32, 96])
+
+ # if self.num_classes is not None:
+ # assert y.shape == (x.shape[0],)
+ # emb = emb + self.label_emb(y)
+
+ guided_hint = self.input_hint_block(hint, emb, context)
+
+ if self.roll_out:
+ guided_hint = repeat(guided_hint, 'b c h w -> b c h (n w)', n=3) # torch.Size([84, 4, 32, 96])
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ if guided_hint is not None:
+ h = module(h, emb, context) # B, 320, 32, 96
+
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context)
+
+ hs.append(h)
+
+ h = self.middle_block(h, emb, context)
+
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ h = self.out(h)
+ if self.roll_out:
+ return rearrange(h, 'b c h (n w) -> b (n c) h w', n=3)
+ return h
\ No newline at end of file
diff --git a/ldm/__init__.py b/ldm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/__pycache__/__init__.cpython-39.pyc b/ldm/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed45ecb91e765536cc7fcd93a973043a1169db7d
Binary files /dev/null and b/ldm/__pycache__/__init__.cpython-39.pyc differ
diff --git a/ldm/__pycache__/util.cpython-39.pyc b/ldm/__pycache__/util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b752ff30a851eec70dee28e2dc4bb6634fb41e9e
Binary files /dev/null and b/ldm/__pycache__/util.cpython-39.pyc differ
diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/data/util.py b/ldm/data/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b60ceb2349e3bd7900ff325740e2022d2903b1c
--- /dev/null
+++ b/ldm/data/util.py
@@ -0,0 +1,24 @@
+import torch
+
+from ldm.modules.midas.api import load_midas_transform
+
+
+class AddMiDaS(object):
+ def __init__(self, model_type):
+ super().__init__()
+ self.transform = load_midas_transform(model_type)
+
+ def pt2np(self, x):
+ x = ((x + 1.0) * .5).detach().cpu().numpy()
+ return x
+
+ def np2pt(self, x):
+ x = torch.from_numpy(x) * 2 - 1.
+ return x
+
+ def __call__(self, sample):
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
+ x = self.pt2np(sample['jpg'])
+ x = self.transform({"image": x})["image"]
+ sample['midas_in'] = x
+ return sample
\ No newline at end of file
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd
--- /dev/null
+++ b/ldm/models/autoencoder.py
@@ -0,0 +1,219 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ema_decay=None,
+ learn_logvar=False
+ ):
+ super().__init__()
+ self.learn_logvar = learn_logvar
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.use_ema = ema_decay is not None
+ if self.use_ema:
+ self.ema_decay = ema_decay
+ assert 0. < ema_decay < 1.
+ self.model_ema = LitEma(self, decay=ema_decay)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, postfix=""):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+ if self.learn_logvar:
+ print(f"{self.__class__.__name__}: Learning logvar")
+ ae_params_list.append(self.loss.logvar)
+ opt_ae = torch.optim.Adam(ae_params_list,
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ if log_ema or self.use_ema:
+ with self.ema_scope():
+ xrec_ema, posterior_ema = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec_ema.shape[1] > 3
+ xrec_ema = self.to_rgb(xrec_ema)
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+ log["reconstructions_ema"] = xrec_ema
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144
--- /dev/null
+++ b/ldm/models/diffusion/ddim.py
@@ -0,0 +1,336 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+ ucg_schedule=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_conditioning[k][i],
+ c[k][i]]) for i in range(len(c[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_conditioning[k],
+ c[k]])
+ elif isinstance(c, list):
+ c_in = list()
+ assert isinstance(unconditional_conditioning, list)
+ for i in range(len(c)):
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False, callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
\ No newline at end of file
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f71a44af48c8cba8e97849b7e6813b3e6f9fe83c
--- /dev/null
+++ b/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1797 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ make_it_fit=False,
+ ucg_training=None,
+ reset_ema=False,
+ reset_num_ema_updates=False,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ self.make_it_fit = make_it_fit
+ if reset_ema: assert exists(ckpt_path)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ if reset_ema:
+ assert self.use_ema
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+ else:
+ self.register_buffer('logvar', logvar)
+
+ self.ucg_training = ucg_training or dict()
+ if self.ucg_training:
+ self.ucg_prng = np.random.RandomState()
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @torch.no_grad()
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ if self.make_it_fit:
+ n_params = len([name for name, _ in
+ itertools.chain(self.named_parameters(),
+ self.named_buffers())])
+ for name, param in tqdm(
+ itertools.chain(self.named_parameters(),
+ self.named_buffers()),
+ desc="Fitting old weights to new weights",
+ total=n_params
+ ):
+ if not name in sd:
+ continue
+ old_shape = sd[name].shape
+ new_shape = param.shape
+ assert len(old_shape) == len(new_shape)
+ if len(new_shape) > 2:
+ # we only modify first two axes
+ assert new_shape[2:] == old_shape[2:]
+ # assumes first axis corresponds to output dim
+ if not new_shape == old_shape:
+ new_param = param.clone()
+ old_param = sd[name]
+ if len(new_shape) == 1:
+ for i in range(new_param.shape[0]):
+ new_param[i] = old_param[i % old_shape[0]]
+ elif len(new_shape) >= 2:
+ for i in range(new_param.shape[0]):
+ for j in range(new_param.shape[1]):
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+ n_used_old = torch.ones(old_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_old[j % old_shape[1]] += 1
+ n_used_new = torch.zeros(new_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_new[j] = n_used_old[j % old_shape[1]]
+
+ n_used_new = n_used_new[None, :]
+ while len(n_used_new.shape) < len(new_shape):
+ n_used_new = n_used_new.unsqueeze(-1)
+ new_param /= n_used_new
+
+ sd[name] = new_param
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys:\n {missing}")
+ if len(unexpected) > 0:
+ print(f"\nUnexpected Keys:\n {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ for k in self.ucg_training:
+ p = self.ucg_training[k]["p"]
+ val = self.ucg_training[k]["val"]
+ if val is None:
+ val = ""
+ for i in range(len(batch[k])):
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[k][i] = val
+
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ force_null_conditioning=False,
+ *args, **kwargs):
+ self.force_null_conditioning = force_null_conditioning
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ reset_ema = kwargs.pop("reset_ema", False)
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+ if reset_ema:
+ assert self.use_ema
+ print(
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+ wtith min distance = 0 at border and max dist = 0.5 at image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None, return_x=False):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+ xc = batch[cond_key]
+ elif cond_key in ['class_label', 'cls']:
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_x:
+ out.extend([x])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+ if isinstance(cond, dict):
+ # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None, **kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+ shape, cond, verbose=False, **kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True, **kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ if self.cond_stage_key in ["class_label", "cls"]:
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+ return self.get_learned_conditioning(xc)
+ else:
+ raise NotImplementedError("todo")
+ if isinstance(c, list): # in case the encoder gives us a list
+ for i in range(len(c)):
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
+ else:
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+ return c
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', "cls"]:
+ try:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ except KeyError:
+ # probably no "human_label" in batch
+ pass
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ if self.model.conditioning_key == "crossattn-adm":
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with ema_scope("Plotting Inpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ mask = 1. - mask
+ with ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ if not self.sequential_cross_attn:
+ cc = torch.cat(c_crossattn, 1)
+ else:
+ cc = c_crossattn
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'hybrid-adm':
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'crossattn-adm':
+ assert c_adm is not None
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+ self.noise_level_key = noise_level_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+ x_low = batch[self.low_scale_key][:bs]
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ zx, noise_level = self.low_scale_model(x_low)
+ if self.noise_level_key is not None:
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+ raise NotImplementedError('TODO')
+
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ x_low_rec = self.low_scale_model.decode(zx)
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+ log_mode=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+ log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ # TODO explore better "unconditional" choices for the other keys
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif k == "c_adm": # todo: only run with text-based guidance?
+ assert isinstance(c[k], torch.Tensor)
+ #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+ uc[k] = c[k]
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+ """
+ Basis for different finetunas, such as inpainting or depth2image
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(self,
+ concat_keys: tuple,
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight"
+ ),
+ keep_finetune_dims=4,
+ # if model was trained without concat mode before and we would like to keep these channels
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
+ c_concat_log_end=None,
+ *args, **kwargs
+ ):
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", list())
+ super().__init__(*args, **kwargs)
+ self.finetune_keys = finetune_keys
+ self.concat_keys = concat_keys
+ self.keep_dims = keep_finetune_dims
+ self.c_concat_log_start = c_concat_log_start
+ self.c_concat_log_end = c_concat_log_end
+ if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+ if exists(ckpt_path):
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+
+ # make it explicit, finetune by including extra input channels
+ if exists(self.finetune_keys) and k in self.finetune_keys:
+ new_entry = None
+ for name, param in self.named_parameters():
+ if name in self.finetune_keys:
+ print(
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+ new_entry = torch.zeros_like(param) # zero init
+ assert exists(new_entry), 'did not find matching parameter to modify'
+ new_entry[:, :self.keep_dims, ...] = sd[k]
+ sd[k] = new_entry
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+ log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ uc_cat = c_cat
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+ """
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+ e.g. mask as concat and text via cross-attn.
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args, **kwargs
+ ):
+ super().__init__(concat_keys, *args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ bchw = z.shape
+ if ck != self.masked_image_key:
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+ else:
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+ log["masked_image"] = rearrange(args[0]["masked_image"],
+ 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ return log
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on monocular depth estimation
+ """
+
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.depth_model = instantiate_from_config(depth_stage_config)
+ self.depth_stage_key = concat_keys[0]
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ cc = self.depth_model(cc)
+ cc = torch.nn.functional.interpolate(
+ cc,
+ size=z.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ keepdim=True)
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ depth = self.depth_model(args[0][self.depth_stage_key])
+ depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
+ torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+ log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
+ return log
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on low-res image (and optionally on some spatial noise augmentation)
+ """
+ def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
+ low_scale_config=None, low_scale_key=None, *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.reshuffle_patch_size = reshuffle_patch_size
+ self.low_scale_model = None
+ if low_scale_config is not None:
+ print("Initializing a low-scale model")
+ assert exists(low_scale_key)
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ # optionally make spatial noise_level here
+ c_cat = list()
+ noise_level = None
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ cc = rearrange(cc, 'b h w c -> b c h w')
+ if exists(self.reshuffle_patch_size):
+ assert isinstance(self.reshuffle_patch_size, int)
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+ p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
+ cc, noise_level = self.low_scale_model(cc)
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ if exists(noise_level):
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+ else:
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
+ return log
diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1154 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+ ***
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+ ***
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+ t = self.inverse_lambda(lambda_t)
+ ===============================================================
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+ 1. For discrete-time DPMs:
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+ 2. For continuous-time DPMs:
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+ ===============================================================
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+ Example:
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError(
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+ schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0 ** 2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+ We support four types of the diffusion model by setting `model_type`:
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follows a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+ ===============================================================
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+ The noise predicition model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ N: A `int`. The total number of the spacing of the time steps.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+ ============================================
+ Args:
+ order: A `int`. The max order for the solver (2 or 3).
+ steps: A `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3, ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3, ] * (K - 1) + [1]
+ else:
+ orders = [3, ] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2, ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2, ] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1, ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+ torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+ solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+ model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+ return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+ t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+ r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1, r2=r2)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+ solver_type='dpm_solver'):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0],)).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ solver_type=solver_type,
+ **kwargs)
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ return_intermediate=True,
+ solver_type=solver_type)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+ solver_type=solver_type,
+ **kwargs)
+ else:
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+ =====================================================
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
+ (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+ =====================================================
+ Some advices for choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
+ - 'time_quadratic': quadratic time for the time steps.
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: A `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: A `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
+ (such as CIFAR-10). However, we observed that such trick does not matter for
+ high-resolutional images. As it needs an additional NFE, we do not recommend
+ it for high-resolutional images.
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+ (especially for steps <= 10). So we recommend to set it to be `True`.
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+ solver_type=solver_type)
+ elif method == 'multistep':
+ assert steps >= order
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+ solver_type=solver_type)
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final and steps < 15:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+ solver_type=solver_type)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+ skip_type=skip_type,
+ t_T=t_T, t_0=t_0,
+ device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [order, ] * K
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+ for i, order in enumerate(orders):
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+ N=order, device=device)
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+ h = lambda_inner[-1] - lambda_inner[0]
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to the dim `dims`.
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dim`: a `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
+ return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {
+ "eps": "noise",
+ "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+
+ print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=MODEL_TYPES[self.model.parameterization],
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+ return x.to(device), None
\ No newline at end of file
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae
--- /dev/null
+++ b/ldm/models/diffusion/plms.py
@@ -0,0 +1,244 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ if dynamic_threshold is not None:
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33
--- /dev/null
+++ b/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+ From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def norm_thresholding(x0, value):
+ s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+ return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+ # b c h w
+ s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+ return x0 * (value / s)
\ No newline at end of file
diff --git a/ldm/modules/__pycache__/attention.cpython-39.pyc b/ldm/modules/__pycache__/attention.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e56306114d12cc5f74dde35037177682021d87a7
Binary files /dev/null and b/ldm/modules/__pycache__/attention.cpython-39.pyc differ
diff --git a/ldm/modules/__pycache__/attention_compat.cpython-39.pyc b/ldm/modules/__pycache__/attention_compat.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4a5fc1bc07a43cf393eacff92aa0b3335949f94
Binary files /dev/null and b/ldm/modules/__pycache__/attention_compat.cpython-39.pyc differ
diff --git a/ldm/modules/__pycache__/attention_new.cpython-39.pyc b/ldm/modules/__pycache__/attention_new.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..167c546abe9d3fa1d6be0fd5f3ef3ba8480f0a8d
Binary files /dev/null and b/ldm/modules/__pycache__/attention_new.cpython-39.pyc differ
diff --git a/ldm/modules/__pycache__/x_transformer.cpython-39.pyc b/ldm/modules/__pycache__/x_transformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6dba4314e751c51ca5ac9aa14ca7ed4aa104607e
Binary files /dev/null and b/ldm/modules/__pycache__/x_transformer.cpython-39.pyc differ
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc657ecce72daf65e2927a3ade2acaa703fecf07
--- /dev/null
+++ b/ldm/modules/attention.py
@@ -0,0 +1,436 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from pdb import set_trace as st
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return{el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+# class BasicTransformerBlock(nn.Module):
+# def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+# super().__init__()
+# self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+# self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+# self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+# heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+# self.norm1 = nn.LayerNorm(dim)
+# self.norm2 = nn.LayerNorm(dim)
+# self.norm3 = nn.LayerNorm(dim)
+# self.checkpoint = checkpoint
+
+# def forward(self, x, context=None):
+# return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+# def _forward(self, x, context=None):
+# x = self.attn1(self.norm1(x)) + x
+# x = self.attn2(self.norm2(x), context=context) + x
+# x = self.ff(self.norm3(x)) + x
+# return x
+
+
+try:
+ from xformers.triton import FusedLayerNorm as LayerNorm
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+from typing import Optional, Any
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention
+ }
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ # return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+ return self._forward(x, context)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in
+
+
+
+class BasicTransformerBlock3D(BasicTransformerBlock):
+
+ def forward(self, x, context=None, num_frames=1):
+ # return checkpoint(self._forward, (x, context, num_frames), self.parameters(), self.checkpoint)
+ return self._forward(x, context, num_frames) # , self.parameters(), self.checkpoint
+
+ def _forward(self, x, context=None, num_frames=1):
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer3D(nn.Module):
+ ''' 3D self-attention '''
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False, use_linear=False,
+ use_checkpoint=True):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ elif context_dim is None:
+ context_dim = [None] * depth
+
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None, num_frames=1):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i], num_frames=num_frames)
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
\ No newline at end of file
diff --git a/ldm/modules/attention_compat.py b/ldm/modules/attention_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..18e38b0bc223cdb07053dfcd527ffb0dece76ed8
--- /dev/null
+++ b/ldm/modules/attention_compat.py
@@ -0,0 +1,347 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from pdb import set_trace as st
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+try:
+ from xformers.triton import FusedLayerNorm as LayerNorm
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return{el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+ super().__init__()
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in
+
+# new stuffs
+
+class BasicTransformerBlock3D(BasicTransformerBlock):
+
+ def forward(self, x, context=None, num_frames=1):
+ # return checkpoint(self._forward, (x, context, num_frames), self.parameters(), self.checkpoint)
+ return self._forward(x, context, num_frames) # , self.parameters(), self.checkpoint
+
+ def _forward(self, x, context=None, num_frames=1):
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer3D(nn.Module):
+ ''' 3D self-attention '''
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False, use_linear=False,
+ use_checkpoint=True):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ elif context_dim is None:
+ context_dim = [None] * depth
+
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None, num_frames=1):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i], num_frames=num_frames)
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
\ No newline at end of file
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9226986d9919d35e03e321f3ca54c2fa26900da7
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc b/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7469645bf39185131e9b69c0499b0b01a4de92be
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/mv_unet.cpython-39.pyc b/ldm/modules/diffusionmodules/__pycache__/mv_unet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be3badfab9096e66138a4161c2a406c92bd2ffd0
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/mv_unet.cpython-39.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc b/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5bcb0420ef0dbb0d92b40e20d30f35ea76da4ab
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc differ
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0041cc6d1f77b01768d482a1565aaa8463cd8c7
--- /dev/null
+++ b/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,913 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+# from ldm.modules.attention import MemoryEfficientCrossAttention
+# from .modules.attention import MemoryEfficientCrossAttention
+# from ldm.modules.attention import SpatialTransformer3D
+from ldm.modules.attention import SpatialTransformer3D
+from pdb import set_trace as st
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+ print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+ out = self.proj_out(out)
+ return x+out
+
+
+# class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+# def forward(self, x, context=None, mask=None):
+# b, c, h, w = x.shape
+# x = rearrange(x, 'b c h w -> b (h w) c')
+# out = super().forward(x, context=context, mask=mask)
+# out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+# return x + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none", "mv-vanilla"], f'attn_type {attn_type} unknown'
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+ attn_type = "vanilla-xformers"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "mv-vanilla":
+ assert attn_kwargs is not None
+ return SpatialTransformer3D(in_channels, **attn_kwargs) # TODO
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+ elif type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ raise NotImplementedError()
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla", attn_kwargs={}):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type, attn_kwargs=attn_kwargs))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, attn_kwargs=attn_kwargs)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type, attn_kwargs=attn_kwargs))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution,
+ z_channels, double_z=True,
+ use_linear_attn=False, attn_type="vanilla",
+ attn_kwargs={},
+ add_fusion_layer=False,
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type, attn_kwargs=attn_kwargs))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, attn_kwargs=attn_kwargs)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # TODO: use attention-based? Later.
+ if add_fusion_layer: # fusion 4 frames
+ self.fusion_layer = torch.nn.Conv2d(2*z_channels*4 if double_z else z_channels*4,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, **kwargs):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h, **kwargs)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+class MVEncoder(Encoder):
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0, resamp_with_conv=True, in_channels, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="mv-vanilla", **ignore_kwargs):
+ super().__init__(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, in_channels=in_channels, resolution=resolution, z_channels=z_channels, double_z=double_z, use_linear_attn=use_linear_attn, attn_type=attn_type,
+ add_fusion_layer=True,
+ **ignore_kwargs)
+ self.num_frames = 4
+
+
+ def forward(self, x):
+ h = super().forward(x, num_frames=self.num_frames)
+ # multi-view aggregation, as in pixel-nerf
+ h = h.chunk(x.shape[0] // self.num_frames) # features from the same single instance aggregated here
+ # h = [feat.max(keepdim=True, dim=0)[0] for feat in h] # max pooling
+ h = [self.fusion_layer(torch.cat(feat.chunk(feat.shape[0]), dim=1)) for feat in h] # conv pooling
+ return torch.cat(h, dim=0)
+
+
+class MVEncoderGS(Encoder):
+ # support pixle-aligned rendering
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0, resamp_with_conv=True, in_channels, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="mv-vanilla", **ignore_kwargs):
+ super().__init__(ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, in_channels=in_channels, resolution=resolution, z_channels=z_channels, double_z=double_z, use_linear_attn=use_linear_attn, attn_type=attn_type,
+ add_fusion_layer=False,
+ **ignore_kwargs)
+ self.num_frames = 4
+
+
+ def forward(self, x):
+ h = super().forward(x, num_frames=self.num_frames)
+
+ # multi-view aggregation, as in pixel-nerf
+ h = h.chunk(x.shape[0] // self.num_frames) # features from the same single instance aggregated here
+ # st()
+
+ # concat
+ # torch.cat(latent, 1)
+ h = [rearrange(latent, 'B C H W -> 1 (B C) H W') for latent in h]
+ h = torch.cat(h, dim=0)
+
+ return h # B 16 H W when V=4, z_channels=2
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+
+# ! lgm unet
\ No newline at end of file
diff --git a/ldm/modules/diffusionmodules/mv_unet.py b/ldm/modules/diffusionmodules/mv_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..309f370995d782362656dccb5645068f365b9965
--- /dev/null
+++ b/ldm/modules/diffusionmodules/mv_unet.py
@@ -0,0 +1,456 @@
+from numpy import sqrt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+from typing import Tuple, Literal
+from functools import partial
+
+from pdb import set_trace as st
+
+# from core.attention import MemEffAttention
+from vit.vision_transformer import MemEffAttention
+
+
+class MVAttention(nn.Module):
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ groups: int = 32,
+ eps: float = 1e-5,
+ residual: bool = True,
+ skip_scale: float = 1,
+ num_frames: int = 4, # WARN: hardcoded!
+ ):
+ super().__init__()
+
+ self.residual = residual
+ self.skip_scale = skip_scale
+ self.num_frames = num_frames
+
+ self.norm = nn.GroupNorm(num_groups=groups,
+ num_channels=dim,
+ eps=eps,
+ affine=True)
+ self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias,
+ attn_drop, proj_drop)
+
+ def forward(self, x):
+ # x: [B*V, C, H, W]
+ BV, C, H, W = x.shape
+ B = BV // self.num_frames # assert BV % self.num_frames == 0
+
+ res = x
+ x = self.norm(x)
+
+ x = x.reshape(B, self.num_frames, C, H,
+ W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)
+ x = self.attn(x)
+ x = x.reshape(B, self.num_frames, H, W,
+ C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)
+
+ if self.residual:
+ x = (x + res) * self.skip_scale
+ return x
+
+
+class ResnetBlock(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resample: Literal['default', 'up', 'down'] = 'default',
+ groups: int = 32,
+ eps: float = 1e-5,
+ skip_scale: float = 1, # multiplied to output
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.skip_scale = skip_scale
+
+ self.norm1 = nn.GroupNorm(num_groups=groups,
+ num_channels=in_channels,
+ eps=eps,
+ affine=True)
+ self.conv1 = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ self.norm2 = nn.GroupNorm(num_groups=groups,
+ num_channels=out_channels,
+ eps=eps,
+ affine=True)
+ self.conv2 = nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ self.act = F.silu
+
+ self.resample = None
+ if resample == 'up':
+ self.resample = partial(F.interpolate,
+ scale_factor=2.0,
+ mode="nearest")
+ elif resample == 'down':
+ self.resample = nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.shortcut = nn.Identity()
+ if self.in_channels != self.out_channels:
+ self.shortcut = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ bias=True)
+
+ def forward(self, x):
+ res = x
+
+ x = self.norm1(x)
+ x = self.act(x)
+
+ if self.resample:
+ res = self.resample(res)
+ x = self.resample(x)
+
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = self.act(x)
+ x = self.conv2(x)
+
+ x = (x + self.shortcut(res)) * self.skip_scale
+
+ return x
+
+
+class DownBlock(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ downsample: bool = True,
+ attention: bool = True,
+ attention_heads: int = 16,
+ skip_scale: float = 1,
+ ):
+ super().__init__()
+
+ nets = []
+ attns = []
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ nets.append(
+ ResnetBlock(in_channels, out_channels, skip_scale=skip_scale))
+ if attention:
+ attns.append(
+ MVAttention(out_channels,
+ attention_heads,
+ skip_scale=skip_scale))
+ else:
+ attns.append(None)
+ self.nets = nn.ModuleList(nets)
+ self.attns = nn.ModuleList(attns)
+
+ self.downsample = None
+ if downsample:
+ self.downsample = nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1)
+
+ def forward(self, x):
+ xs = []
+
+ for attn, net in zip(self.attns, self.nets):
+ x = net(x)
+ if attn:
+ x = attn(x)
+ xs.append(x)
+
+ if self.downsample:
+ x = self.downsample(x)
+ xs.append(x)
+
+ return x, xs
+
+
+class MidBlock(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ attention: bool = True,
+ attention_heads: int = 16,
+ skip_scale: float = 1,
+ ):
+ super().__init__()
+
+ nets = []
+ attns = []
+ # first layer
+ nets.append(
+ ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
+ # more layers
+ for i in range(num_layers):
+ nets.append(
+ ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
+ if attention:
+ attns.append(
+ MVAttention(in_channels,
+ attention_heads,
+ skip_scale=skip_scale))
+ else:
+ attns.append(None)
+ self.nets = nn.ModuleList(nets)
+ self.attns = nn.ModuleList(attns)
+
+ def forward(self, x):
+ x = self.nets[0](x)
+ for attn, net in zip(self.attns, self.nets[1:]):
+ if attn:
+ x = attn(x)
+ x = net(x)
+ return x
+
+
+class UpBlock(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int,
+ prev_out_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ upsample: bool = True,
+ attention: bool = True,
+ attention_heads: int = 16,
+ skip_scale: float = 1,
+ ):
+ super().__init__()
+
+ nets = []
+ attns = []
+ for i in range(num_layers):
+ cin = in_channels if i == 0 else out_channels
+ cskip = prev_out_channels if (i == num_layers -
+ 1) else out_channels
+
+ nets.append(
+ ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
+ if attention:
+ attns.append(
+ MVAttention(out_channels,
+ attention_heads,
+ skip_scale=skip_scale))
+ else:
+ attns.append(None)
+ self.nets = nn.ModuleList(nets)
+ self.attns = nn.ModuleList(attns)
+
+ self.upsample = None
+ if upsample:
+ self.upsample = nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, xs):
+
+ for attn, net in zip(self.attns, self.nets):
+ res_x = xs[-1]
+ xs = xs[:-1]
+ x = torch.cat([x, res_x], dim=1)
+ x = net(x)
+ if attn:
+ x = attn(x)
+
+ if self.upsample:
+ x = F.interpolate(x, scale_factor=2.0, mode='nearest')
+ x = self.upsample(x)
+
+ return x
+
+
+# it could be asymmetric!
+class MVUNet(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
+ down_attention: Tuple[bool,
+ ...] = (False, False, False, True, True),
+ mid_attention: bool = True,
+ up_channels: Tuple[int, ...] = (1024, 512, 256),
+ up_attention: Tuple[bool, ...] = (True, True, False),
+ layers_per_block: int = 2,
+ skip_scale: float = np.sqrt(0.5),
+ ):
+ super().__init__()
+
+ # first
+ self.conv_in = nn.Conv2d(in_channels,
+ down_channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # down
+ down_blocks = []
+ cout = down_channels[0]
+ for i in range(len(down_channels)):
+ cin = cout
+ cout = down_channels[i]
+
+ down_blocks.append(
+ DownBlock(
+ cin,
+ cout,
+ num_layers=layers_per_block,
+ downsample=(i
+ != len(down_channels) - 1), # not final layer
+ attention=down_attention[i],
+ skip_scale=skip_scale,
+ ))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ # mid
+ self.mid_block = MidBlock(down_channels[-1],
+ attention=mid_attention,
+ skip_scale=skip_scale)
+
+ # up
+ up_blocks = []
+ cout = up_channels[0]
+ for i in range(len(up_channels)):
+ cin = cout
+ cout = up_channels[i]
+ cskip = down_channels[max(-2 - i,
+ -len(down_channels))] # for assymetric
+
+ up_blocks.append(
+ UpBlock(
+ cin,
+ cskip,
+ cout,
+ num_layers=layers_per_block + 1, # one more layer for up
+ upsample=(i != len(up_channels) - 1), # not final layer
+ attention=up_attention[i],
+ skip_scale=skip_scale,
+ ))
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ # last
+ self.norm_out = nn.GroupNorm(num_channels=up_channels[-1],
+ num_groups=32,
+ eps=1e-5)
+ self.conv_out = nn.Conv2d(up_channels[-1],
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # x: [B, Cin, H, W]
+
+ # first
+ x = self.conv_in(x)
+
+ # down
+ xss = [x]
+ for block in self.down_blocks:
+ x, xs = block(x)
+ xss.extend(xs)
+
+ # mid
+ x = self.mid_block(x) # 32 (B V) 1024 16 16
+
+ # up
+ for block in self.up_blocks:
+ xs = xss[-len(block.nets):]
+ xss = xss[:-len(block.nets)]
+ x = block(x, xs)
+
+ # last
+ x = self.norm_out(x)
+ x = F.silu(x)
+ x = self.conv_out(x) # [B, Cout, H', W']
+
+ return x
+
+
+class LGM_MVEncoder(MVUNet):
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_channels: Tuple[int] = (64, 128, 256, 512, 1024),
+ down_attention: Tuple[bool] = (False, False, False, True, True),
+ mid_attention: bool = True,
+ up_channels: Tuple[int] = (1024, 512, 256),
+ up_attention: Tuple[bool] = (True, True, False),
+ layers_per_block: int = 2,
+ skip_scale: float = np.sqrt(0.5),
+ z_channels=4,
+ double_z=True,
+ add_fusion_layer=True,
+ ):
+ super().__init__(in_channels, out_channels, down_channels,
+ down_attention, mid_attention, up_channels,
+ up_attention, layers_per_block, skip_scale)
+ del self.up_blocks
+
+ self.conv_out = torch.nn.Conv2d(up_channels[0],
+ 2 *
+ z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if add_fusion_layer: # fusion 4 frames
+ self.fusion_layer = torch.nn.Conv2d(
+ 2 * z_channels * 4 if double_z else z_channels * 4,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ self.num_frames = 4 # !hard coded
+
+ def forward(self, x):
+ # first
+ x = self.conv_in(x)
+
+ # down
+ xss = [x]
+ for block in self.down_blocks:
+ x, xs = block(x)
+ xss.extend(xs)
+
+ # mid
+ x = self.mid_block(x) # 32 (B V) 1024 16 16
+
+ # multi-view aggregation, as in pixel-nerf
+ x = x.chunk(x.shape[0] // self.num_frames) # features from the same single instance aggregated here
+ # h = [feat.max(keepdim=True, dim=0)[0] for feat in h] # max pooling
+ x = [self.fusion_layer(torch.cat(feat.chunk(feat.shape[0]), dim=1)) for feat in x] # conv pooling
+ st()
+ return torch.cat(x, dim=0)
\ No newline at end of file
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..29ca2ad3939a21b8711ce0023de86865af3496db
--- /dev/null
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,786 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention_compat import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
diff --git a/ldm/modules/diffusionmodules/upscaling.py b/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988
--- /dev/null
+++ b/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+ # for concatenating a downsampled image to the latent representation
+ def __init__(self, noise_schedule_config=None):
+ super(AbstractLowScaleModel, self).__init__()
+ if noise_schedule_config is not None:
+ self.register_schedule(**noise_schedule_config)
+
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def forward(self, x):
+ return x, None
+
+ def decode(self, x):
+ return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+ # no noise level conditioning
+ def __init__(self):
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+ self.max_noise_level = 0
+
+ def forward(self, x):
+ # fix to constant noise level
+ return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+ super().__init__(noise_schedule_config=noise_schedule_config)
+ self.max_noise_level = max_noise_level
+
+ def forward(self, x, noise_level=None):
+ if noise_level is None:
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+ else:
+ assert isinstance(noise_level, torch.Tensor)
+ z = self.q_sample(x, noise_level)
+ return z, noise_level
+
+
+
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..2831ffa57cc964de2fb480a7bae2962f51560382
--- /dev/null
+++ b/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,271 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+@torch.autocast(device_type="cuda")
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled()}
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), \
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4
--- /dev/null
+++ b/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+ else torch.tensor(-1, dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc b/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac786ad5d73df0cfde861ce2f1d8718ba0d3758f
Binary files /dev/null and b/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc differ
diff --git a/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc b/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..922ac3347899b8bf3a18be39ffa3a6424692ff6b
Binary files /dev/null and b/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc differ
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6157cfcb0b23e686eafb13b9eedced4767e69a0
--- /dev/null
+++ b/ldm/modules/encoders/modules.py
@@ -0,0 +1,569 @@
+import torch
+from typing import Dict, List, Optional, Tuple, Union
+import functools
+import fsspec
+import os
+import open_clip
+import torch.nn as nn
+from functools import partial
+import clip
+from einops import rearrange, repeat
+import kornia
+import numpy as np
+from inspect import isfunction
+
+from pdb import set_trace as st
+# from transformers import CLIPTokenizer, CLIPTextModel
+
+from ...util import (append_dims, autocast, count_params, default,
+ disabled_train, expand_dims_like, instantiate_from_config)
+
+from ..x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+
+ def forward(self, batch, key=None):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ c = self.embedding(c)
+ return c
+
+
+class TransformerEmbedder(AbstractEncoder):
+ """Some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+ super().__init__()
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
+
+ def forward(self, tokens):
+ tokens = tokens.to(self.device) # meh
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, x):
+ return self(x)
+
+
+class BERTTokenizer(AbstractEncoder):
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
+ super().__init__()
+ from transformers import BertTokenizerFast # TODO: add to reuquirements
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+ self.device = device
+ self.vq_interface = vq_interface
+ self.max_length = max_length
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ return tokens
+
+ @torch.no_grad()
+ def encode(self, text):
+ tokens = self(text)
+ if not self.vq_interface:
+ return tokens
+ return None, None, [None, None, tokens]
+
+ def decode(self, text):
+ return text
+
+
+class BERTEmbedder(AbstractEncoder):
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+ super().__init__()
+ self.use_tknz_fn = use_tokenizer
+ if self.use_tknz_fn:
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
+ emb_dropout=embedding_dropout)
+
+ def forward(self, text):
+ if self.use_tknz_fn:
+ tokens = self.tknz_fn(text)#.to(self.device)
+ else:
+ tokens = text
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, text):
+ # output of length 77
+ return self(text)
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(self,
+ n_stages=1,
+ method='bilinear',
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None
+ if self.remap_output:
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
+
+ def forward(self,x):
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_eos_feature=False):
+ super().__init__()
+ from transformers import CLIPTokenizer, CLIPTextModel
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version).to(device)
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+ self.use_eos_feature = use_eos_feature
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ if self.use_eos_feature: # for DiT
+ z = outputs.pooler_output # N 77 C
+ else:
+ z = outputs.last_hidden_state # N 77 C
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+class TextEmbedder(nn.Module):
+ """
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
+ """
+ def __init__(self, dropout_prob=0.1, use_eos_feature=False):
+ super().__init__()
+ self.text_encodder = FrozenCLIPEmbedder(use_eos_feature=use_eos_feature) # no normalization projection
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, text_prompts, force_drop_ids=None):
+ """
+ Drops text to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = np.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = list(np.where(drop_ids, "None", text_prompts))
+ # print(labels)
+ return labels
+
+ def forward(self, text_prompts, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
+ embeddings = self.text_encodder(text_prompts)
+ return embeddings
+
+class FrozenCLIPTextEmbedder(nn.Module):
+ """
+ Uses the CLIP transformer encoder for text.
+ """
+ def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True, dropout_prob=0., scale_clip_encoding=None):
+ super().__init__()
+ self.model, _ = clip.load(version, jit=False, device=device)
+ self.device = device
+ self.max_length = max_length
+ self.n_repeat = n_repeat
+ self.normalize = normalize
+ self.dropout_prob = dropout_prob
+ self.scale_clip_encoding = scale_clip_encoding
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = clip.tokenize(text).to(self.device)
+ z = self.model.encode_text(tokens)
+ if self.normalize:
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+
+ if self.scale_clip_encoding is not None:
+ z = z * self.scale_clip_encoding
+
+ return z
+
+ def token_drop(self, text_prompts, force_drop_ids=None):
+ """
+ Drops text to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = np.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = list(np.where(drop_ids, "None", text_prompts))
+ # print(labels)
+ return labels
+
+
+ def encode(self, text):
+ z = self(text)
+
+ if z.ndim==2: # match cross attention shape
+ z = z[:, None, :]
+ z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+
+ return z
+
+
+class FrozenClipImageEmbedder(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ n_repeat=1,
+ dropout_prob=0.2, # follow Rodin
+ normalize_encoding=False,
+ scale_clip_encoding=1.0,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+ self.n_repeat = n_repeat
+ self.normalize_encoding = normalize_encoding
+ self.scale_clip_encoding = torch.tensor(scale_clip_encoding, dtype=torch.float32, device=device)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ self.dropout_prob = dropout_prob
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std) # type: ignore
+ return x
+
+ def token_drop(self, z):
+ """
+ zero the image encoding to enable classifier-free guidance.
+ """
+ drop_ids = np.random.uniform(0, 1, z.shape[0]) < self.dropout_prob # idx token to drop
+ drop_ids = torch.from_numpy(drop_ids).unsqueeze(1).expand_as(z).bool().to(z.device)
+ z = torch.where(drop_ids, torch.zeros_like(z), z)
+ return z
+
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ # return self.model.encode_image(self.preprocess(x))
+ z = self.model.encode_image(self.preprocess(x))
+
+ # ? normalized features, seems not working?
+ if self.normalize_encoding:
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+ if self.scale_clip_encoding:
+ # st()
+ z = z * self.scale_clip_encoding
+
+ if self.dropout_prob>0: # for cfg
+ z = self.token_drop(z)
+
+ if z.ndim==2:
+ # repeat 1 dim, for context shape compatability.
+ z = z[:, None, :]
+ z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+ return z
+
+
+class AbstractEmbModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._is_trainable = None
+ self._ucg_rate = None
+ self._input_key = None
+
+ @property
+ def is_trainable(self) -> bool:
+ return self._is_trainable
+
+ @property
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
+ return self._ucg_rate
+
+ @property
+ def input_key(self) -> str:
+ return self._input_key
+
+ @is_trainable.setter
+ def is_trainable(self, value: bool):
+ self._is_trainable = value
+
+ @ucg_rate.setter
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
+ self._ucg_rate = value
+
+ @input_key.setter
+ def input_key(self, value: str):
+ self._input_key = value
+
+ @is_trainable.deleter
+ def is_trainable(self):
+ del self._is_trainable
+
+ @ucg_rate.deleter
+ def ucg_rate(self):
+ del self._ucg_rate
+
+ @input_key.deleter
+ def input_key(self):
+ del self._input_key
+
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ antialias=True,
+ ucg_rate=0.0,
+ unsqueeze_dim=False,
+ repeat_to_max_len=False,
+ num_image_crops=0,
+ output_tokens=False,
+ init_device=None,
+ ):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device(default(init_device, "cpu")),
+ pretrained=version,
+ )
+ del model.transformer
+ self.model = model
+ self.max_crops = num_image_crops
+ self.pad_to_max_len = self.max_crops > 0
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ self.antialias = antialias
+
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+ self.ucg_rate = ucg_rate
+ self.unsqueeze_dim = unsqueeze_dim
+ self.stored_batch = None
+ self.model.visual.output_tokens = output_tokens
+ self.output_tokens = output_tokens
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ tokens = None
+ if self.output_tokens:
+ z, tokens = z[0], z[1]
+ z = z.to(image.dtype)
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+ z = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+ )[:, None]
+ * z
+ )
+ if tokens is not None:
+ tokens = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(tokens.shape[0], device=tokens.device)
+ ),
+ tokens,
+ )
+ * tokens
+ )
+ if self.unsqueeze_dim:
+ z = z[:, None, :]
+ if self.output_tokens:
+ assert not self.repeat_to_max_len
+ assert not self.pad_to_max_len
+ return tokens, z
+ if self.repeat_to_max_len:
+ if z.dim() == 2:
+ z_ = z[:, None, :]
+ else:
+ z_ = z
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+ elif self.pad_to_max_len:
+ assert z.dim() == 3
+ z_pad = torch.cat(
+ (
+ z,
+ torch.zeros(
+ z.shape[0],
+ self.max_length - z.shape[1],
+ z.shape[2],
+ device=z.device,
+ ),
+ ),
+ 1,
+ )
+ return z_pad, z_pad[:, 0, ...]
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ # if self.max_crops > 0:
+ # img = self.preprocess_by_cropping(img)
+ if img.dim() == 5:
+ assert self.max_crops == img.shape[1]
+ img = rearrange(img, "b n c h w -> (b n) c h w")
+ img = self.preprocess(img)
+ if not self.output_tokens:
+ assert not self.model.visual.output_tokens
+ x = self.model.visual(img)
+ tokens = None
+ else:
+ assert self.model.visual.output_tokens
+ x, tokens = self.model.visual(img)
+ if self.max_crops > 0:
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+ # drop out between 0 and all along the sequence axis
+ x = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+ )
+ * x
+ )
+ if tokens is not None:
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+ print(
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
+ f"Check what you are doing, and then remove this message."
+ )
+ if self.output_tokens:
+ return x, tokens
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
+ def __init__(
+ self,
+ # open_clip_embedding_config: Dict,
+ n_cond_frames: int,
+ n_copies: int,
+ open_clip_module,
+ ):
+ super().__init__()
+
+ self.n_cond_frames = n_cond_frames
+ self.n_copies = n_copies
+ # self.open_clip = instantiate_from_config(open_clip_embedding_config)
+ self.open_clip = open_clip_module
+
+ def forward(self, vid):
+ vid = self.open_clip(vid)
+ vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
+ vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
+
+ return vid
+
+
+if __name__ == "__main__":
+ from ldm.util import count_params
+ model = FrozenCLIPEmbedder()
+ count_params(model, verbose=True)
diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+
+
+# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ use_shuffle: the degradation shuffle
+ use_sharp: sharpening the img
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ print(img)
+ img = util.uint2single(img)
+ print(img)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_lq = deg_fn(img)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+ wd2 = wd2/4
+ wd = wd/4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ # elif i == 1:
+ # image = add_blur(image, sf=sf)
+
+ if i == 0:
+ pass
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ if up:
+ image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ
diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image pathes
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
+ w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
+ h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+ """
+ split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
+ and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
+ will be splitted.
+ Args:
+ original_dataroot:
+ taget_dataroot:
+ p_size: size of small images
+ p_overlap: patch size in training is a good choice
+ p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
+ """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channles (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channles (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(unit)
+# numpy(single) <---> tensor
+# numpy(unit) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(unit)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(unit) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+ # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flipe and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augmet_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+ # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/ldm/modules/midas/__init__.py b/ldm/modules/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/midas/api.py b/ldm/modules/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c
--- /dev/null
+++ b/ldm/modules/midas/api.py
@@ -0,0 +1,170 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
+from ldm.modules.midas.midas.midas_net import MidasNet
+from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
+from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+
+ISL_PATHS = {
+ "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def load_midas_transform(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load transform only
+ if model_type == "dpt_large": # DPT-Large
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ elif model_type == "midas_v21_small":
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ else:
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return transform
+
+
+def load_model(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load network
+ model_path = ISL_PATHS[model_type]
+ if model_type == "dpt_large": # DPT-Large
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+ MODEL_TYPES_TORCH_HUB = [
+ "DPT_Large",
+ "DPT_Hybrid",
+ "MiDaS_small"
+ ]
+ MODEL_TYPES_ISL = [
+ "dpt_large",
+ "dpt_hybrid",
+ "midas_v21",
+ "midas_v21_small",
+ ]
+
+ def __init__(self, model_type):
+ super().__init__()
+ assert (model_type in self.MODEL_TYPES_ISL)
+ model, _ = load_model(model_type)
+ self.model = model
+ self.model.train = disabled_train
+
+ def forward(self, x):
+ # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
+ # NOTE: we expect that the correct transform has been called during dataloading.
+ with torch.no_grad():
+ prediction = self.model(x)
+ prediction = torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=x.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
+ return prediction
+
diff --git a/ldm/modules/midas/midas/__init__.py b/ldm/modules/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/midas/midas/base_model.py b/ldm/modules/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/ldm/modules/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
diff --git a/ldm/modules/midas/midas/blocks.py b/ldm/modules/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/ldm/modules/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
diff --git a/ldm/modules/midas/midas/dpt_depth.py b/ldm/modules/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/ldm/modules/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+ if self.channels_last == True:
+ x.contiguous(memory_format=torch.channels_last)
+
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
+
diff --git a/ldm/modules/midas/midas/midas_net.py b/ldm/modules/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/ldm/modules/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+ use_pretrained = False if path is None else True
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
diff --git a/ldm/modules/midas/midas/midas_net_custom.py b/ldm/modules/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/ldm/modules/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+ use_pretrained = False if path else True
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+ features1=features
+ features2=features
+ features3=features
+ features4=features
+ self.expand = False
+ if "expand" in self.blocks and self.blocks['expand'] == True:
+ self.expand = True
+ features1=features
+ features2=features*2
+ features3=features*4
+ features4=features*8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+ if self.channels_last==True:
+ print("self.channels_last = ", self.channels_last)
+ x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/ldm/modules/midas/midas/transforms.py b/ldm/modules/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/ldm/modules/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
diff --git a/ldm/modules/midas/midas/vit.py b/ldm/modules/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/ldm/modules/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/ldm/modules/midas/utils.py b/ldm/modules/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/ldm/modules/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+ path (str): pathto file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ out = np.zeros(depth.shape, dtype=depth.type)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576
--- /dev/null
+++ b/ldm/modules/x_transformer.py
@@ -0,0 +1,641 @@
+"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from functools import partial
+from inspect import isfunction
+from collections import namedtuple
+from einops import rearrange, repeat, reduce
+
+# constants
+
+DEFAULT_DIM_HEAD = 64
+
+Intermediates = namedtuple('Intermediates', [
+ 'pre_softmax_attn',
+ 'post_softmax_attn'
+])
+
+LayerIntermediates = namedtuple('Intermediates', [
+ 'hiddens',
+ 'attn_intermediates'
+])
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self.init_()
+
+ def init_(self):
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return emb[None, :, :]
+
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def always(val):
+ def inner(*args, **kwargs):
+ return val
+ return inner
+
+
+def not_equals(val):
+ def inner(x):
+ return x != val
+ return inner
+
+
+def equals(val):
+ def inner(x):
+ return x == val
+ return inner
+
+
+def max_neg_value(tensor):
+ return -torch.finfo(tensor.dtype).max
+
+
+# keyword argument helpers
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(), dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+
+def string_begins_with(prefix, str):
+ return str.startswith(prefix)
+
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+ return kwargs_without_prefix, kwargs
+
+
+# classes
+class Scale(nn.Module):
+ def __init__(self, value, fn):
+ super().__init__()
+ self.value = value
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.value, *rest)
+
+
+class Rezero(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+ self.g = nn.Parameter(torch.zeros(1))
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.g, *rest)
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-8):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class Residual(nn.Module):
+ def forward(self, x, residual):
+ return x + residual
+
+
+class GRUGating(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gru = nn.GRUCell(dim, dim)
+
+ def forward(self, x, residual):
+ gated_output = self.gru(
+ rearrange(x, 'b n d -> (b n) d'),
+ rearrange(residual, 'b n d -> (b n) d')
+ )
+
+ return gated_output.reshape_as(x)
+
+
+# feedforward
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# attention.
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head=DEFAULT_DIM_HEAD,
+ heads=8,
+ causal=False,
+ mask=None,
+ talking_heads=False,
+ sparse_topk=None,
+ use_entmax15=False,
+ num_mem_kv=0,
+ dropout=0.,
+ on_attn=False
+ ):
+ super().__init__()
+ if use_entmax15:
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.causal = causal
+ self.mask = mask
+
+ inner_dim = dim_head * heads
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ # talking heads
+ self.talking_heads = talking_heads
+ if talking_heads:
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+
+ # explicit topk sparse attention
+ self.sparse_topk = sparse_topk
+
+ # entmax
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
+ self.attn_fn = F.softmax
+
+ # add memory key / values
+ self.num_mem_kv = num_mem_kv
+ if num_mem_kv > 0:
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+
+ # attention on attention
+ self.attn_on_attn = on_attn
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ rel_pos=None,
+ sinusoidal_emb=None,
+ prev_attn=None,
+ mem=None
+ ):
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
+ kv_input = default(context, x)
+
+ q_input = x
+ k_input = kv_input
+ v_input = kv_input
+
+ if exists(mem):
+ k_input = torch.cat((mem, k_input), dim=-2)
+ v_input = torch.cat((mem, v_input), dim=-2)
+
+ if exists(sinusoidal_emb):
+ # in shortformer, the query would start at a position offset depending on the past cached memory
+ offset = k_input.shape[-2] - q_input.shape[-2]
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
+ k_input = k_input + sinusoidal_emb(k_input)
+
+ q = self.to_q(q_input)
+ k = self.to_k(k_input)
+ v = self.to_v(v_input)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+
+ input_mask = None
+ if any(map(exists, (mask, context_mask))):
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
+ k_mask = q_mask if not exists(context) else context_mask
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
+ input_mask = q_mask * k_mask
+
+ if self.num_mem_kv > 0:
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
+ k = torch.cat((mem_k, k), dim=-2)
+ v = torch.cat((mem_v, v), dim=-2)
+ if exists(input_mask):
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+ mask_value = max_neg_value(dots)
+
+ if exists(prev_attn):
+ dots = dots + prev_attn
+
+ pre_softmax_attn = dots
+
+ if talking_heads:
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
+
+ if exists(rel_pos):
+ dots = rel_pos(dots)
+
+ if exists(input_mask):
+ dots.masked_fill_(~input_mask, mask_value)
+ del input_mask
+
+ if self.causal:
+ i, j = dots.shape[-2:]
+ r = torch.arange(i, device=device)
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
+ mask = F.pad(mask, (j - i, 0), value=False)
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
+ mask = dots < vk
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ attn = self.attn_fn(dots, dim=-1)
+ post_softmax_attn = attn
+
+ attn = self.dropout(attn)
+
+ if talking_heads:
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ intermediates = Intermediates(
+ pre_softmax_attn=pre_softmax_attn,
+ post_softmax_attn=post_softmax_attn
+ )
+
+ return self.to_out(out), intermediates
+
+
+class AttentionLayers(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ heads=8,
+ causal=False,
+ cross_attend=False,
+ only_cross=False,
+ use_scalenorm=False,
+ use_rmsnorm=False,
+ use_rezero=False,
+ rel_pos_num_buckets=32,
+ rel_pos_max_distance=128,
+ position_infused_attn=False,
+ custom_layers=None,
+ sandwich_coef=None,
+ par_ratio=None,
+ residual_attn=False,
+ cross_residual_attn=False,
+ macaron=False,
+ pre_norm=True,
+ gate_residual=False,
+ **kwargs
+ ):
+ super().__init__()
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
+
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+
+ self.dim = dim
+ self.depth = depth
+ self.layers = nn.ModuleList([])
+
+ self.has_pos_emb = position_infused_attn
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
+ self.rotary_pos_emb = always(None)
+
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
+ self.rel_pos = None
+
+ self.pre_norm = pre_norm
+
+ self.residual_attn = residual_attn
+ self.cross_residual_attn = cross_residual_attn
+
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+ norm_class = RMSNorm if use_rmsnorm else norm_class
+ norm_fn = partial(norm_class, dim)
+
+ norm_fn = nn.Identity if use_rezero else norm_fn
+ branch_fn = Rezero if use_rezero else None
+
+ if cross_attend and not only_cross:
+ default_block = ('a', 'c', 'f')
+ elif cross_attend and only_cross:
+ default_block = ('c', 'f')
+ else:
+ default_block = ('a', 'f')
+
+ if macaron:
+ default_block = ('f',) + default_block
+
+ if exists(custom_layers):
+ layer_types = custom_layers
+ elif exists(par_ratio):
+ par_depth = depth * len(default_block)
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
+ default_block = tuple(filter(not_equals('f'), default_block))
+ par_attn = par_depth // par_ratio
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
+ par_block = default_block + ('f',) * (par_width - len(default_block))
+ par_head = par_block * par_attn
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
+ elif exists(sandwich_coef):
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
+ else:
+ layer_types = default_block * depth
+
+ self.layer_types = layer_types
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
+
+ for layer_type in self.layer_types:
+ if layer_type == 'a':
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
+ elif layer_type == 'c':
+ layer = Attention(dim, heads=heads, **attn_kwargs)
+ elif layer_type == 'f':
+ layer = FeedForward(dim, **ff_kwargs)
+ layer = layer if not macaron else Scale(0.5, layer)
+ else:
+ raise Exception(f'invalid layer type {layer_type}')
+
+ if isinstance(layer, Attention) and exists(branch_fn):
+ layer = branch_fn(layer)
+
+ if gate_residual:
+ residual_fn = GRUGating(dim)
+ else:
+ residual_fn = Residual()
+
+ self.layers.append(nn.ModuleList([
+ norm_fn(),
+ layer,
+ residual_fn
+ ]))
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ mems=None,
+ return_hiddens=False
+ ):
+ hiddens = []
+ intermediates = []
+ prev_attn = None
+ prev_cross_attn = None
+
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ is_last = ind == (len(self.layers) - 1)
+
+ if layer_type == 'a':
+ hiddens.append(x)
+ layer_mem = mems.pop(0)
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == 'a':
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
+ prev_attn=prev_attn, mem=layer_mem)
+ elif layer_type == 'c':
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
+ elif layer_type == 'f':
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if layer_type in ('a', 'c'):
+ intermediates.append(inter)
+
+ if layer_type == 'a' and self.residual_attn:
+ prev_attn = inter.pre_softmax_attn
+ elif layer_type == 'c' and self.cross_residual_attn:
+ prev_cross_attn = inter.pre_softmax_attn
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ if return_hiddens:
+ intermediates = LayerIntermediates(
+ hiddens=hiddens,
+ attn_intermediates=intermediates
+ )
+
+ return x, intermediates
+
+ return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs):
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
+ super().__init__(causal=False, **kwargs)
+
+
+
+class TransformerWrapper(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_tokens,
+ max_seq_len,
+ attn_layers,
+ emb_dim=None,
+ max_mem_len=0.,
+ emb_dropout=0.,
+ num_memory_tokens=None,
+ tie_embedding=False,
+ use_pos_emb=True
+ ):
+ super().__init__()
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
+
+ dim = attn_layers.dim
+ emb_dim = default(emb_dim, dim)
+
+ self.max_seq_len = max_seq_len
+ self.max_mem_len = max_mem_len
+ self.num_tokens = num_tokens
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.emb_dropout = nn.Dropout(emb_dropout)
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self.init_()
+
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+
+ # memory tokens (like [cls]) from Memory Transformers paper
+ num_memory_tokens = default(num_memory_tokens, 0)
+ self.num_memory_tokens = num_memory_tokens
+ if num_memory_tokens > 0:
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+ # let funnel encoder know number of memory tokens, if specified
+ if hasattr(attn_layers, 'num_memory_tokens'):
+ attn_layers.num_memory_tokens = num_memory_tokens
+
+ def init_(self):
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_mems=False,
+ return_attn=False,
+ mems=None,
+ **kwargs
+ ):
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
+ x = self.token_emb(x)
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+
+ if num_mem > 0:
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
+ x = torch.cat((mem, x), dim=1)
+
+ # auto-handle masking after appending memory tokens
+ if exists(mask):
+ mask = F.pad(mask, (num_mem, 0), value=True)
+
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
+ x = self.norm(x)
+
+ mem, x = x[:, :num_mem], x[:, num_mem:]
+
+ out = self.to_logits(x) if not return_embeddings else x
+
+ if return_mems:
+ hiddens = intermediates.hiddens
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+ return out, new_mems
+
+ if return_attn:
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+ return out, attn_maps
+
+ return out
+
diff --git a/ldm/util.py b/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..66d9b2a69db2898323cbf2ad26a09ac8b2facd11
--- /dev/null
+++ b/ldm/util.py
@@ -0,0 +1,275 @@
+import functools
+import importlib
+import os
+from functools import partial
+from inspect import isfunction
+
+import fsspec
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from safetensors.torch import load_file as load_safetensors
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def get_string_from_tuple(s):
+ try:
+ # Check if the string starts and ends with parentheses
+ if s[0] == "(" and s[-1] == ")":
+ # Convert the string to a tuple
+ t = eval(s)
+ # Check if the type of t is tuple
+ if type(t) == tuple:
+ return t[0]
+ else:
+ pass
+ except:
+ pass
+ return s
+
+
+def is_power_of_two(n):
+ """
+ chat.openai.com/chat
+ Return True if n is a power of 2, otherwise return False.
+
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
+
+ """
+ if n <= 0:
+ return False
+ return (n & (n - 1)) == 0
+
+
+def autocast(f, enabled=True):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=enabled,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled(),
+ ):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
+def load_partial_from_config(config):
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ if isinstance(xc[bi], list):
+ text_seq = xc[bi][0]
+ else:
+ text_seq = xc[bi]
+ lines = "\n".join(
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
+ )
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def partialclass(cls, *args, **kwargs):
+ class NewCls(cls):
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+ return NewCls
+
+
+def make_path_absolute(path):
+ fs, p = fsspec.core.url_to_fs(path)
+ if fs.protocol == "file":
+ return os.path.abspath(p)
+ return path
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def isheatmap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+
+ return x.ndim == 2
+
+
+def isneighbors(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def expand_dims_like(x, y):
+ while x.dim() != y.dim():
+ x = x.unsqueeze(-1)
+ return x
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False, invalidate_cache=True):
+ module, cls = string.rsplit(".", 1)
+ if invalidate_cache:
+ importlib.invalidate_caches()
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+ )
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def load_model_from_config(config, ckpt, verbose=True, freeze=True):
+ print(f"Loading model from {ckpt}")
+ if ckpt.endswith("ckpt"):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ elif ckpt.endswith("safetensors"):
+ sd = load_safetensors(ckpt)
+ else:
+ raise NotImplementedError
+
+ model = instantiate_from_config(config.model)
+
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if freeze:
+ for param in model.parameters():
+ param.requires_grad = False
+
+ model.eval()
+ return model
+
+
+def get_configs_path() -> str:
+ """
+ Get the `configs` directory.
+ For a working copy, this is the one in the root of the repository,
+ but for an installed copy, it's in the `sgm` package (see pyproject.toml).
+ """
+ this_dir = os.path.dirname(__file__)
+ candidates = (
+ os.path.join(this_dir, "configs"),
+ os.path.join(this_dir, "..", "configs"),
+ )
+ for candidate in candidates:
+ candidate = os.path.abspath(candidate)
+ if os.path.isdir(candidate):
+ return candidate
+ raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
+
+
+def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
+ """
+ Will return the result of a recursive get attribute call.
+ E.g.:
+ a.b.c
+ = getattr(getattr(a, "b"), "c")
+ = get_nested_attribute(a, "b.c")
+ If any part of the attribute call is an integer x with current obj a, will
+ try to call a[x] instead of a.x first.
+ """
+ attributes = attribute_path.split(".")
+ if depth is not None and depth > 0:
+ attributes = attributes[:depth]
+ assert len(attributes) > 0, "At least one attribute should be selected"
+ current_attribute = obj
+ current_key = None
+ for level, attribute in enumerate(attributes):
+ current_key = ".".join(attributes[: level + 1])
+ try:
+ id_ = int(attribute)
+ current_attribute = current_attribute[id_]
+ except ValueError:
+ current_attribute = getattr(current_attribute, attribute)
+
+ return (current_attribute, current_key) if return_key else current_attribute
diff --git a/nsr/__init__.py b/nsr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..96fc389d1e72bcd1fe2264273e13838dde0495b9
--- /dev/null
+++ b/nsr/__init__.py
@@ -0,0 +1,16 @@
+# triplane, tensorF etc.
+from .train_util import TrainLoop3DRec, TrainLoop3DRecTrajVis
+from .train_util_cvD import TrainLoop3DcvD
+
+# train ffhq
+from .cvD.nvsD_canoD import TrainLoop3DcvD_nvsD_canoD, TrainLoop3DcvD_nvsD_canoD_eg3d
+from .train_util_with_eg3d import TrainLoop3DRecEG3D
+# from .train_util_with_eg3d_real import TrainLoop3DRecEG3DReal, TrainLoop3DRecEG3DRealOnly
+# from .train_util_with_eg3d_real_D import TrainLoop3DRecEG3DRealOnl_D
+
+# * difffusion trainer
+from .train_util_diffusion import TrainLoop3DDiffusion
+
+# import lsgm
+from .lsgm import *
+from .lsgm import crossattn_cldm_objv
\ No newline at end of file
diff --git a/nsr/__pycache__/__init__.cpython-39.pyc b/nsr/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7775ebbc03b98d83117341c2ea88b43ff139bd6c
Binary files /dev/null and b/nsr/__pycache__/__init__.cpython-39.pyc differ
diff --git a/nsr/__pycache__/augment.cpython-39.pyc b/nsr/__pycache__/augment.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7dbac03d2fa02666f7e1eda5f9b016b783424ce5
Binary files /dev/null and b/nsr/__pycache__/augment.cpython-39.pyc differ
diff --git a/nsr/__pycache__/camera_utils.cpython-39.pyc b/nsr/__pycache__/camera_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cd760274f3bb7aa04de248c2cf7c3ce540a1045
Binary files /dev/null and b/nsr/__pycache__/camera_utils.cpython-39.pyc differ
diff --git a/nsr/__pycache__/common_blks.cpython-39.pyc b/nsr/__pycache__/common_blks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e9a034db0efa5e1acd4bc569e73988fb493b330
Binary files /dev/null and b/nsr/__pycache__/common_blks.cpython-39.pyc differ
diff --git a/nsr/__pycache__/confnet.cpython-39.pyc b/nsr/__pycache__/confnet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5479bf189f36a108cfe902c7729d5fa92027fa52
Binary files /dev/null and b/nsr/__pycache__/confnet.cpython-39.pyc differ
diff --git a/nsr/__pycache__/dual_discriminator.cpython-39.pyc b/nsr/__pycache__/dual_discriminator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31b36f7cad20ebd46737c62f1ec2176f0c797222
Binary files /dev/null and b/nsr/__pycache__/dual_discriminator.cpython-39.pyc differ
diff --git a/nsr/__pycache__/networks_stylegan2.cpython-39.pyc b/nsr/__pycache__/networks_stylegan2.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13dd0887c59d306f7ac530a2a05fa2b838410c18
Binary files /dev/null and b/nsr/__pycache__/networks_stylegan2.cpython-39.pyc differ
diff --git a/nsr/__pycache__/networks_stylegan3.cpython-39.pyc b/nsr/__pycache__/networks_stylegan3.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e280da4f12bdb912554357680f70016cbc8df382
Binary files /dev/null and b/nsr/__pycache__/networks_stylegan3.cpython-39.pyc differ
diff --git a/nsr/__pycache__/script_util.cpython-39.pyc b/nsr/__pycache__/script_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09daffcf76b2817e4986ac718016f073997ba9bd
Binary files /dev/null and b/nsr/__pycache__/script_util.cpython-39.pyc differ
diff --git a/nsr/__pycache__/superresolution.cpython-39.pyc b/nsr/__pycache__/superresolution.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ec74e75a9b62ceb31befd01b806816431be8326
Binary files /dev/null and b/nsr/__pycache__/superresolution.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_nv_util.cpython-39.pyc b/nsr/__pycache__/train_nv_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8fbc5b4b975853df4c6ecb3a20f279582f939c6
Binary files /dev/null and b/nsr/__pycache__/train_nv_util.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util.cpython-39.pyc b/nsr/__pycache__/train_util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1d4945b156e9ef6466002907ad2e934397422d9
Binary files /dev/null and b/nsr/__pycache__/train_util.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_cvD.cpython-39.pyc b/nsr/__pycache__/train_util_cvD.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b1e8e050f9f14ed0db095b818ef6cacc8d1e85f
Binary files /dev/null and b/nsr/__pycache__/train_util_cvD.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_diffusion.cpython-39.pyc b/nsr/__pycache__/train_util_diffusion.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b70bf9a5483785e4747ca2af6c2f63592d18cb9
Binary files /dev/null and b/nsr/__pycache__/train_util_diffusion.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_diffusion_dit.cpython-39.pyc b/nsr/__pycache__/train_util_diffusion_dit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4a4eda223038a48c1233ba53bfb5c9f9f2ff1db
Binary files /dev/null and b/nsr/__pycache__/train_util_diffusion_dit.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_diffusion_single_stage.cpython-39.pyc b/nsr/__pycache__/train_util_diffusion_single_stage.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b901d2007b89c2ee0baaa1f17b2d10a7c9125e1d
Binary files /dev/null and b/nsr/__pycache__/train_util_diffusion_single_stage.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_diffusion_single_stage_sds.cpython-39.pyc b/nsr/__pycache__/train_util_diffusion_single_stage_sds.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8aa3c207d865f9ba48bdf50e48d246429c373329
Binary files /dev/null and b/nsr/__pycache__/train_util_diffusion_single_stage_sds.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_with_eg3d.cpython-39.pyc b/nsr/__pycache__/train_util_with_eg3d.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf7743776c835f316c2de2d2347ab0b848222b32
Binary files /dev/null and b/nsr/__pycache__/train_util_with_eg3d.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_with_eg3d_hybrid.cpython-39.pyc b/nsr/__pycache__/train_util_with_eg3d_hybrid.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d28e461eef54880350a21d8f06a75e75ea7d81f8
Binary files /dev/null and b/nsr/__pycache__/train_util_with_eg3d_hybrid.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_with_eg3d_hybrid_eg3dD.cpython-39.pyc b/nsr/__pycache__/train_util_with_eg3d_hybrid_eg3dD.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0befc05dd2a86d9510ee8937d7e463b55a638cca
Binary files /dev/null and b/nsr/__pycache__/train_util_with_eg3d_hybrid_eg3dD.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_with_eg3d_real.cpython-39.pyc b/nsr/__pycache__/train_util_with_eg3d_real.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93b4268322a6cc747a84e09302544a11ffaee457
Binary files /dev/null and b/nsr/__pycache__/train_util_with_eg3d_real.cpython-39.pyc differ
diff --git a/nsr/__pycache__/train_util_with_eg3d_real_D.cpython-39.pyc b/nsr/__pycache__/train_util_with_eg3d_real_D.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0cd86530471286cd380ee64f755fdb35287a1b26
Binary files /dev/null and b/nsr/__pycache__/train_util_with_eg3d_real_D.cpython-39.pyc differ
diff --git a/nsr/__pycache__/triplane.cpython-39.pyc b/nsr/__pycache__/triplane.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be8d7f94493f9f579a265ee32e79cb68b7f7bf68
Binary files /dev/null and b/nsr/__pycache__/triplane.cpython-39.pyc differ
diff --git a/nsr/augment.py b/nsr/augment.py
new file mode 100755
index 0000000000000000000000000000000000000000..c84a70bd6fc96ea1c58af9fbda7eb0b498ada7b1
--- /dev/null
+++ b/nsr/augment.py
@@ -0,0 +1,431 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import scipy.signal
+import torch
+from utils.torch_utils import persistence
+from utils.torch_utils import misc
+from utils.torch_utils.ops import upfirdn2d
+from utils.torch_utils.ops import grid_sample_gradfix
+from utils.torch_utils.ops import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+# Coefficients of various wavelet decomposition low-pass filters.
+
+wavelets = {
+ 'haar': [0.7071067811865476, 0.7071067811865476],
+ 'db1': [0.7071067811865476, 0.7071067811865476],
+ 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
+ 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
+ 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
+ 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
+ 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
+ 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
+ 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
+ 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
+ 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
+ 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
+}
+
+#----------------------------------------------------------------------------
+# Helpers for constructing transformation matrices.
+
+def matrix(*rows, device=None):
+ assert all(len(row) == len(rows[0]) for row in rows)
+ elems = [x for row in rows for x in row]
+ ref = [x for x in elems if isinstance(x, torch.Tensor)]
+ if len(ref) == 0:
+ return misc.constant(np.asarray(rows), device=device)
+ assert device is None or device == ref[0].device
+ elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
+ return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
+
+def translate2d(tx, ty, **kwargs):
+ return matrix(
+ [1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1],
+ **kwargs)
+
+def translate3d(tx, ty, tz, **kwargs):
+ return matrix(
+ [1, 0, 0, tx],
+ [0, 1, 0, ty],
+ [0, 0, 1, tz],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def scale2d(sx, sy, **kwargs):
+ return matrix(
+ [sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1],
+ **kwargs)
+
+def scale3d(sx, sy, sz, **kwargs):
+ return matrix(
+ [sx, 0, 0, 0],
+ [0, sy, 0, 0],
+ [0, 0, sz, 0],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def rotate2d(theta, **kwargs):
+ return matrix(
+ [torch.cos(theta), torch.sin(-theta), 0],
+ [torch.sin(theta), torch.cos(theta), 0],
+ [0, 0, 1],
+ **kwargs)
+
+def rotate3d(v, theta, **kwargs):
+ vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
+ s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
+ return matrix(
+ [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
+ [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
+ [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def translate2d_inv(tx, ty, **kwargs):
+ return translate2d(-tx, -ty, **kwargs)
+
+def scale2d_inv(sx, sy, **kwargs):
+ return scale2d(1 / sx, 1 / sy, **kwargs)
+
+def rotate2d_inv(theta, **kwargs):
+ return rotate2d(-theta, **kwargs)
+
+#----------------------------------------------------------------------------
+# Versatile image augmentation pipeline from the paper
+# "Training Generative Adversarial Networks with Limited Data".
+#
+# All augmentations are disabled by default; individual augmentations can
+# be enabled by setting their probability multipliers to 1.
+
+@persistence.persistent_class
+class AugmentPipe(torch.nn.Module):
+ def __init__(self,
+ xflip=0, rotate90=0, xint=0, xint_max=0.125,
+ scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
+ brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
+ imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
+ noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
+ ):
+ super().__init__()
+ self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
+
+ # Pixel blitting.
+ self.xflip = float(xflip) # Probability multiplier for x-flip.
+ self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
+ self.xint = float(xint) # Probability multiplier for integer translation.
+ self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
+
+ # General geometric transformations.
+ self.scale = float(scale) # Probability multiplier for isotropic scaling.
+ self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
+ self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
+ self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
+ self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
+ self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
+ self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions.
+
+ # Color transformations.
+ self.brightness = float(brightness) # Probability multiplier for brightness.
+ self.contrast = float(contrast) # Probability multiplier for contrast.
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
+ self.hue = float(hue) # Probability multiplier for hue rotation.
+ self.saturation = float(saturation) # Probability multiplier for saturation.
+ self.brightness_std = float(brightness_std) # Standard deviation of brightness.
+ self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
+ self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
+ self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
+
+ # Image-space filtering.
+ self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
+ self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
+ self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
+
+ # Image-space corruptions.
+ self.noise = float(noise) # Probability multiplier for additive RGB noise.
+ self.cutout = float(cutout) # Probability multiplier for cutout.
+ self.noise_std = float(noise_std) # Standard deviation of additive RGB noise.
+ self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions.
+
+ # Setup orthogonal lowpass filter for geometric augmentations.
+ self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
+
+ # Construct filter bank for image-space filtering.
+ Hz_lo = np.asarray(wavelets['sym2']) # H(z)
+ Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
+ Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
+ Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
+ Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
+ for i in range(1, Hz_fbank.shape[0]):
+ Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
+ Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
+ Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
+ self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
+
+ def forward(self, images, debug_percentile=None):
+ assert isinstance(images, torch.Tensor) and images.ndim == 4
+ batch_size, num_channels, height, width = images.shape
+ device = images.device
+ if debug_percentile is not None:
+ debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
+
+ # -------------------------------------
+ # Select parameters for pixel blitting.
+ # -------------------------------------
+
+ # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
+ I_3 = torch.eye(3, device=device)
+ G_inv = I_3
+
+ # Apply x-flip with probability (xflip * strength).
+ if self.xflip > 0:
+ i = torch.floor(torch.rand([batch_size], device=device) * 2)
+ i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 2))
+ G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)
+
+ # Apply 90 degree rotations with probability (rotate90 * strength).
+ if self.rotate90 > 0:
+ i = torch.floor(torch.rand([batch_size], device=device) * 4)
+ i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 4))
+ G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)
+
+ # Apply integer translation with probability (xint * strength).
+ if self.xint > 0:
+ t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
+ if debug_percentile is not None:
+ t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
+ G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
+
+ # --------------------------------------------------------
+ # Select parameters for general geometric transformations.
+ # --------------------------------------------------------
+
+ # Apply isotropic scaling with probability (scale * strength).
+ if self.scale > 0:
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
+ s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
+ G_inv = G_inv @ scale2d_inv(s, s)
+
+ # Apply pre-rotation with probability p_rot.
+ p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
+ if self.rotate > 0:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
+ G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
+
+ # Apply anisotropic scaling with probability (aniso * strength).
+ if self.aniso > 0:
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
+ s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
+ G_inv = G_inv @ scale2d_inv(s, 1 / s)
+
+ # Apply post-rotation with probability p_rot.
+ if self.rotate > 0:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.zeros_like(theta)
+ G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
+
+ # Apply fractional translation with probability (xfrac * strength).
+ if self.xfrac > 0:
+ t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
+ if debug_percentile is not None:
+ t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
+ G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
+
+ # ----------------------------------
+ # Execute geometric transformations.
+ # ----------------------------------
+
+ # Execute if the transform is not identity.
+ if G_inv is not I_3:
+
+ # Calculate padding.
+ cx = (width - 1) / 2
+ cy = (height - 1) / 2
+ cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
+ cp = G_inv @ cp.t() # [batch, xyz, idx]
+ Hz_pad = self.Hz_geom.shape[0] // 4
+ margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
+ margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
+ margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
+ margin = margin.max(misc.constant([0, 0] * 2, device=device))
+ margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
+ mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
+
+ # Pad image and adjust origin.
+ images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
+ G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
+
+ # Upsample.
+ images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
+ G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
+ G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
+
+ # Execute transformation.
+ shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
+ G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
+ grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
+ images = grid_sample_gradfix.grid_sample(images, grid)
+
+ # Downsample and crop.
+ images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
+
+ # --------------------------------------------
+ # Select parameters for color transformations.
+ # --------------------------------------------
+
+ # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
+ I_4 = torch.eye(4, device=device)
+ C = I_4
+
+ # Apply brightness with probability (brightness * strength).
+ if self.brightness > 0:
+ b = torch.randn([batch_size], device=device) * self.brightness_std
+ b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
+ if debug_percentile is not None:
+ b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
+ C = translate3d(b, b, b) @ C
+
+ # Apply contrast with probability (contrast * strength).
+ if self.contrast > 0:
+ c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
+ c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
+ if debug_percentile is not None:
+ c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
+ C = scale3d(c, c, c) @ C
+
+ # Apply luma flip with probability (lumaflip * strength).
+ v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
+ if self.lumaflip > 0:
+ i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
+ i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 2))
+ C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.
+
+ # Apply hue rotation with probability (hue * strength).
+ if self.hue > 0 and num_channels > 1:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
+ theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
+ C = rotate3d(v, theta) @ C # Rotate around v.
+
+ # Apply saturation with probability (saturation * strength).
+ if self.saturation > 0 and num_channels > 1:
+ s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
+ s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
+ C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
+
+ # ------------------------------
+ # Execute color transformations.
+ # ------------------------------
+
+ # Execute if the transform is not identity.
+ if C is not I_4:
+ images = images.reshape([batch_size, num_channels, height * width])
+ if num_channels == 3:
+ images = C[:, :3, :3] @ images + C[:, :3, 3:]
+ elif num_channels == 1:
+ C = C[:, :3, :].mean(dim=1, keepdims=True)
+ images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
+ else:
+ raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
+ images = images.reshape([batch_size, num_channels, height, width])
+
+ # ----------------------
+ # Image-space filtering.
+ # ----------------------
+
+ if self.imgfilter > 0:
+ num_bands = self.Hz_fbank.shape[0]
+ assert len(self.imgfilter_bands) == num_bands
+ expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
+
+ # Apply amplification for each band with probability (imgfilter * strength * band_strength).
+ g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
+ for i, band_strength in enumerate(self.imgfilter_bands):
+ t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
+ t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
+ if debug_percentile is not None:
+ t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
+ t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
+ t[:, i] = t_i # Replace i'th element.
+ t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
+ g = g * t # Accumulate into global gain.
+
+ # Construct combined amplification filter.
+ Hz_prime = g @ self.Hz_fbank # [batch, tap]
+ Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
+ Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
+
+ # Apply filter.
+ p = self.Hz_fbank.shape[1] // 2
+ images = images.reshape([1, batch_size * num_channels, height, width])
+ images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
+ images = images.reshape([batch_size, num_channels, height, width])
+
+ # ------------------------
+ # Image-space corruptions.
+ # ------------------------
+
+ # Apply additive RGB noise with probability (noise * strength).
+ if self.noise > 0:
+ sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
+ sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
+ if debug_percentile is not None:
+ sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
+ images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma
+
+ # Apply cutout with probability (cutout * strength).
+ if self.cutout > 0:
+ size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
+ size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
+ center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
+ if debug_percentile is not None:
+ size = torch.full_like(size, self.cutout_size)
+ center = torch.full_like(center, debug_percentile)
+ coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
+ coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
+ mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
+ mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
+ mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
+ images = images * mask
+
+ return images
+
+#----------------------------------------------------------------------------
diff --git a/nsr/camera_utils.py b/nsr/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f2d3125daf4f6de74a5605f0a0ff730b41bbe03
--- /dev/null
+++ b/nsr/camera_utils.py
@@ -0,0 +1,193 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""
+Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+
+from nsr.volumetric_rendering import math_utils
+
+
+class GaussianCameraPoseSampler:
+ """
+ Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
+ Camera is specified as looking at the origin.
+ If horizontal and vertical stddev (specified in radians) are zero, gives a
+ deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
+ The coordinate system is specified with y-up, z-forward, x-left.
+ Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
+ vertical mean is the polar angle (angle from the y axis) in radians.
+ A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
+
+ Example:
+ For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+ cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean,
+ vertical_mean,
+ horizontal_stddev=0,
+ vertical_stddev=0,
+ radius=1,
+ batch_size=1,
+ device='cpu'):
+ h = torch.randn((batch_size, 1),
+ device=device) * horizontal_stddev + horizontal_mean
+ v = torch.randn(
+ (batch_size, 1), device=device) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2 * v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi -
+ theta)
+ camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi -
+ theta)
+ camera_origins[:, 1:2] = radius * torch.cos(phi)
+
+ forward_vectors = math_utils.normalize_vecs(-camera_origins)
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+class LookAtPoseSampler:
+ """
+ Same as GaussianCameraPoseSampler, except the
+ camera is specified as looking at 'lookat_position', a 3-vector.
+
+ Example:
+ For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+ cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean,
+ vertical_mean,
+ lookat_position,
+ horizontal_stddev=0.,
+ vertical_stddev=0.,
+ radius=1,
+ batch_size=1,
+ device='cpu'):
+ h = torch.randn((batch_size, 1),
+ device=device) * horizontal_stddev + horizontal_mean
+ v = torch.randn(
+ (batch_size, 1), device=device) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2 * v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi -
+ theta)
+ camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi -
+ theta)
+ camera_origins[:, 1:2] = radius * torch.cos(phi)
+
+ # forward_vectors = math_utils.normalize_vecs(-camera_origins)
+ forward_vectors = math_utils.normalize_vecs(lookat_position -
+ camera_origins)
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+class UniformCameraPoseSampler:
+ """
+ Same as GaussianCameraPoseSampler, except the
+ pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev.
+
+ Example:
+ For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
+
+ cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean,
+ vertical_mean,
+ horizontal_stddev=0,
+ vertical_stddev=0,
+ radius=1,
+ batch_size=1,
+ device='cpu'):
+ h = (torch.rand((batch_size, 1), device=device) * 2 -
+ 1) * horizontal_stddev + horizontal_mean
+ v = (torch.rand((batch_size, 1), device=device) * 2 -
+ 1) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2 * v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi -
+ theta)
+ camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi -
+ theta)
+ camera_origins[:, 1:2] = radius * torch.cos(phi)
+
+ forward_vectors = math_utils.normalize_vecs(-camera_origins)
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+def create_cam2world_matrix(forward_vector, origin):
+ """
+ Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
+ Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.
+ """
+
+ forward_vector = math_utils.normalize_vecs(forward_vector)
+ up_vector = torch.tensor([0, 1, 0],
+ dtype=torch.float,
+ device=origin.device).expand_as(forward_vector)
+
+ right_vector = -math_utils.normalize_vecs(
+ torch.cross(up_vector, forward_vector, dim=-1))
+ up_vector = math_utils.normalize_vecs(
+ torch.cross(forward_vector, right_vector, dim=-1))
+
+ rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(
+ forward_vector.shape[0], 1, 1)
+ rotation_matrix[:, :3, :3] = torch.stack(
+ (right_vector, up_vector, forward_vector), axis=-1)
+
+ translation_matrix = torch.eye(4,
+ device=origin.device).unsqueeze(0).repeat(
+ forward_vector.shape[0], 1, 1)
+ translation_matrix[:, :3, 3] = origin
+ cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
+ assert (cam2world.shape[1:] == (4, 4))
+ return cam2world
+
+
+def FOV_to_intrinsics(fov_degrees, device='cpu'):
+ """
+ Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
+ Note the intrinsics are returned as normalized by image size, rather than in pixel units.
+ Assumes principal point is at image center.
+ """
+
+ focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
+ intrinsics = torch.tensor(
+ [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]],
+ device=device)
+ return intrinsics
diff --git a/nsr/common_blks.py b/nsr/common_blks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2789aea5a060c5a94d21e8a69cd566e6211a76e5
--- /dev/null
+++ b/nsr/common_blks.py
@@ -0,0 +1,216 @@
+
+
+# https://github.com/sxyu/pixel-nerf/blob/master/src/model/resnetfc.py
+from torch import nn
+import torch
+
+from vit.vision_transformer import Mlp, DropPath
+
+
+# Resnet Blocks
+class ResnetBlockFC(nn.Module):
+ """
+ Fully connected ResNet Block class.
+ Taken from DVR code.
+ :param size_in (int): input dimension
+ :param size_out (int): output dimension
+ :param size_h (int): hidden dimension
+ """
+ def __init__(self, size_in, size_out=None, size_h=None, beta=0.0, init_as_zero=False):
+ super().__init__()
+ # Attributes
+ if size_out is None:
+ size_out = size_in
+
+ if size_h is None:
+ size_h = min(size_in, size_out)
+
+ self.size_in = size_in
+ self.size_h = size_h
+ self.size_out = size_out
+ # Submodules
+ self.fc_0 = nn.Linear(size_in, size_h)
+ self.fc_1 = nn.Linear(size_h, size_out)
+
+ # Init
+ nn.init.constant_(self.fc_0.bias, 0.0)
+ if init_as_zero:
+ nn.init.zeros_(self.fc_0.weight)
+ else:
+ nn.init.kaiming_normal_(self.fc_0.weight, a=0, mode="fan_in")
+ nn.init.constant_(self.fc_1.bias, 0.0)
+ nn.init.zeros_(self.fc_1.weight)
+
+ if beta > 0:
+ self.activation = nn.Softplus(beta=beta)
+ else:
+ self.activation = nn.ReLU()
+
+ if size_in == size_out:
+ self.shortcut = None
+ else:
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
+ # nn.init.constant_(self.shortcut.bias, 0.0)
+ nn.init.kaiming_normal_(self.shortcut.weight, a=0, mode="fan_in")
+
+ def forward(self, x):
+ # with profiler.record_function("resblock"):
+ net = self.fc_0(self.activation(x))
+ dx = self.fc_1(self.activation(net))
+
+ if self.shortcut is not None:
+ x_s = self.shortcut(x)
+ else:
+ x_s = x
+ return x_s + dx
+
+
+
+
+# Resnet Blocks
+class ResnetBlockFCViT(nn.Module):
+ """
+ Fully connected ResNet Block class.
+ Taken from DVR code.
+ :param size_in (int): input dimension
+ :param size_out (int): output dimension
+ :param size_h (int): hidden dimension
+ """
+ def __init__(self, size_in, size_out=None, size_h=None, beta=0.0, init_as_zero=False):
+ super().__init__()
+ # Attributes
+ if size_out is None:
+ size_out = size_in
+
+ if size_h is None:
+ size_h = min(size_in, size_out)
+
+ self.size_in = size_in
+ self.size_h = size_h
+ self.size_out = size_out
+ # Submodules
+ self.fc_0 = nn.Linear(size_in, size_h)
+ self.fc_1 = nn.Linear(size_h, size_out)
+
+ # Init
+ nn.init.constant_(self.fc_0.bias, 0.0)
+ if init_as_zero:
+ nn.init.zeros_(self.fc_0.weight)
+ else:
+ nn.init.kaiming_normal_(self.fc_0.weight, a=0, mode="fan_in")
+ nn.init.constant_(self.fc_1.bias, 0.0)
+ nn.init.zeros_(self.fc_1.weight)
+
+ if beta > 0:
+ self.activation = nn.Softplus(beta=beta)
+ else:
+ self.activation = nn.ReLU()
+
+ if size_in == size_out:
+ self.shortcut = None
+ else:
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
+ # nn.init.constant_(self.shortcut.bias, 0.0)
+ nn.init.kaiming_normal_(self.shortcut.weight, a=0, mode="fan_in")
+
+ def forward(self, x):
+ # with profiler.record_function("resblock"):
+ net = self.fc_0(self.activation(x))
+ dx = self.fc_1(self.activation(net))
+
+ if self.shortcut is not None:
+ x_s = self.shortcut(x)
+ else:
+ x_s = x
+ return x_s + dx
+
+
+# class Block(nn.Module):
+# def __init__(self,
+# dim,
+# num_heads,
+# mlp_ratio=4.,
+# qkv_bias=False,
+# qk_scale=None,
+# drop=0.,
+# attn_drop=0.,
+# drop_path=0.,
+# act_layer=nn.GELU,
+# norm_layer=nn.LayerNorm):
+# super().__init__()
+# self.norm1 = norm_layer(dim)
+# self.attn = Attention(dim,
+# num_heads=num_heads,
+# qkv_bias=qkv_bias,
+# qk_scale=qk_scale,
+# attn_drop=attn_drop,
+# proj_drop=drop)
+# self.drop_path = DropPath(
+# drop_path) if drop_path > 0. else nn.Identity()
+# self.norm2 = norm_layer(dim)
+# mlp_hidden_dim = int(dim * mlp_ratio)
+# self.mlp = Mlp(in_features=dim,
+# hidden_features=mlp_hidden_dim,
+# act_layer=act_layer,
+# drop=drop)
+
+# def forward(self, x, return_attention=False):
+# y, attn = self.attn(self.norm1(x))
+# if return_attention:
+# return attn
+# x = x + self.drop_path(y)
+# x = x + self.drop_path(self.mlp(self.norm2(x)))
+# return x
+
+
+
+
+class ResMlp(nn.Module):
+ def __init__(self,
+
+ size_in,
+ size_out=None,
+ size_h=None,
+ drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+
+ # Attributes
+ if size_out is None:
+ size_out = size_in
+ if size_h is None:
+ size_h = min(size_in, size_out)
+ self.size_in = size_in
+ self.size_h = size_h
+ self.size_out = size_out
+
+ # Submodules
+ self.norm1 = norm_layer(size_in) # ? how to use
+
+ self.mlp = Mlp(in_features=size_in,
+ out_features=size_out,
+ act_layer=act_layer,
+ drop=drop)
+
+ # Residual shortcuts
+ if size_in == size_out:
+ self.shortcut = None
+ else:
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
+ self.norm2 = norm_layer(size_in)
+
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ dx = self.mlp(self.norm1(x))
+
+ if self.shortcut is not None:
+ x_s = self.shortcut(self.norm2(x))
+ else:
+ x_s = x
+
+ return x_s + self.drop_path(dx)
\ No newline at end of file
diff --git a/nsr/confnet.py b/nsr/confnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..65d908ffaf5c5131aaef75bac769c23b8a760b2f
--- /dev/null
+++ b/nsr/confnet.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+
+EPS = 1e-7
+
+class ConfNet(nn.Module):
+ def __init__(self, cin=3, cout=1, zdim=128, nf=64):
+ super(ConfNet, self).__init__()
+ ## downsampling
+ network = [
+ nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
+ nn.GroupNorm(16, nf),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
+ nn.GroupNorm(16*2, nf*2),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
+ nn.GroupNorm(16*4, nf*4),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
+ nn.ReLU(inplace=True)]
+ ## upsampling
+ network += [
+ nn.ConvTranspose2d(zdim, nf*8, kernel_size=4, padding=0, bias=False), # 1x1 -> 4x4
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(nf*8, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 4x4 -> 8x8
+ nn.GroupNorm(16*4, nf*4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(nf*4, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 16x16
+ nn.GroupNorm(16*2, nf*2),
+ nn.ReLU(inplace=True)]
+ self.network = nn.Sequential(*network)
+
+ # ! only the symmetric confidence is required
+ # out_net1 = [
+ # nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 32x32
+ # nn.GroupNorm(16, nf),
+ # nn.ReLU(inplace=True),
+ # nn.ConvTranspose2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 64x64
+ # nn.GroupNorm(16, nf),
+ # nn.ReLU(inplace=True),
+ # nn.Conv2d(nf, 2, kernel_size=5, stride=1, padding=2, bias=False), # 64x64
+ # # nn.Conv2d(nf, 1, kernel_size=5, stride=1, padding=2, bias=False), # 64x64
+ # nn.Softplus()
+ # ]
+ # self.out_net1 = nn.Sequential(*out_net1)
+
+ # ! for perceptual loss
+ out_net2 = [nn.Conv2d(nf*2, 2, kernel_size=3, stride=1, padding=1, bias=False), # 16x16
+ nn.Softplus()
+ # nn.Sigmoid()
+ ]
+ self.out_net2 = nn.Sequential(*out_net2)
+
+ def forward(self, input):
+ out = self.network(input)
+ # return self.out_net1(out)
+ return self.out_net2(out)
+ # return self.out_net1(out), self.out_net2(out)
\ No newline at end of file
diff --git a/nsr/cvD/__init__.py b/nsr/cvD/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nsr/cvD/__pycache__/__init__.cpython-39.pyc b/nsr/cvD/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a41e868d5a8a79adbf1a56bac12308877dd31ba
Binary files /dev/null and b/nsr/cvD/__pycache__/__init__.cpython-39.pyc differ
diff --git a/nsr/cvD/__pycache__/canoD.cpython-39.pyc b/nsr/cvD/__pycache__/canoD.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f8f953224e2c13e00fc3b28604c8833fa32db74
Binary files /dev/null and b/nsr/cvD/__pycache__/canoD.cpython-39.pyc differ
diff --git a/nsr/cvD/__pycache__/nvsD.cpython-39.pyc b/nsr/cvD/__pycache__/nvsD.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dab57efb2716afbd8d8ea92aa87934ae69b14da2
Binary files /dev/null and b/nsr/cvD/__pycache__/nvsD.cpython-39.pyc differ
diff --git a/nsr/cvD/__pycache__/nvsD_canoD.cpython-39.pyc b/nsr/cvD/__pycache__/nvsD_canoD.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ec69e89aaf3fabcd95523f211399993c6b57462
Binary files /dev/null and b/nsr/cvD/__pycache__/nvsD_canoD.cpython-39.pyc differ
diff --git a/nsr/cvD/__pycache__/nvsD_canoD_mask.cpython-39.pyc b/nsr/cvD/__pycache__/nvsD_canoD_mask.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..063c8fecaa209ec3364fbe1ecef034fe5a618ced
Binary files /dev/null and b/nsr/cvD/__pycache__/nvsD_canoD_mask.cpython-39.pyc differ
diff --git a/nsr/cvD/__pycache__/nvsD_canoD_multiview.cpython-39.pyc b/nsr/cvD/__pycache__/nvsD_canoD_multiview.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b3a731bc521c75bad15484f9af728ecef240c27
Binary files /dev/null and b/nsr/cvD/__pycache__/nvsD_canoD_multiview.cpython-39.pyc differ
diff --git a/nsr/cvD/__pycache__/nvsD_nosr.cpython-39.pyc b/nsr/cvD/__pycache__/nvsD_nosr.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fbfe744686164b44ad047a656113784255e4d1b
Binary files /dev/null and b/nsr/cvD/__pycache__/nvsD_nosr.cpython-39.pyc differ
diff --git a/nsr/cvD/nvsD_canoD.py b/nsr/cvD/nvsD_canoD.py
new file mode 100644
index 0000000000000000000000000000000000000000..e44b0fdcb9d6701bb3a13e73ea67c519bccda4a5
--- /dev/null
+++ b/nsr/cvD/nvsD_canoD.py
@@ -0,0 +1,1021 @@
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+import torchvision
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from tqdm import tqdm
+
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion import dist_util, logger
+from guided_diffusion.train_util import (calc_average_loss,
+ log_rec3d_loss_dict,
+ find_resume_checkpoint)
+
+from torch.optim import AdamW
+
+from ..train_util import TrainLoopBasic, TrainLoop3DRec
+import vision_aided_loss
+from dnnlib.util import calculate_adaptive_weight
+
+def flip_yaw(pose_matrix):
+ flipped = pose_matrix.clone()
+ flipped[:, 0, 1] *= -1
+ flipped[:, 0, 2] *= -1
+ flipped[:, 1, 0] *= -1
+ flipped[:, 2, 0] *= -1
+ flipped[:, 0, 3] *= -1
+ # st()
+ return flipped
+
+
+def get_blob_logdir():
+ # You can change this to be a separate path to save checkpoints to
+ # a blobstore or some external drive.
+ return logger.get_dir()
+
+
+from ..train_util_cvD import TrainLoop3DcvD
+# from .nvD import
+
+
+class TrainLoop3DcvD_nvsD_canoD(TrainLoop3DcvD):
+ # class TrainLoop3DcvD_nvsD_canoD():
+
+ def __init__(self,
+ *,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ use_amp=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ use_amp=use_amp,
+ **kwargs)
+
+ device = dist_util.dev()
+
+ self.cano_cvD = vision_aided_loss.Discriminator(
+ cv_type='clip', loss_type='multilevel_sigmoid_s',
+ device=device).to(device)
+ self.cano_cvD.cv_ensemble.requires_grad_(
+ False) # Freeze feature extractor
+ # self.cano_cvD.train()
+
+ cvD_model_params = list(self.cano_cvD.parameters())
+ SR_TRAINING = False
+ if SR_TRAINING: # replace the conv1 with 6 channel input
+ # width, patch_size = self.nvs_cvD.cv_ensemble
+ vision_width, vision_patch_size = [
+ self.cano_cvD.cv_ensemble.models[0].model.conv1.weight.shape[k]
+ for k in [0, -1]
+ ]
+ self.cano_cvD.cv_ensemble.models[0].model.conv1 = th.nn.Conv2d(
+ in_channels=6,
+ out_channels=vision_width,
+ kernel_size=vision_patch_size,
+ stride=vision_patch_size,
+ bias=False).to(dist_util.dev())
+ cvD_model_params += list(
+ self.cano_cvD.cv_ensemble.models[0].model.conv1.parameters())
+
+ self.cano_cvD.cv_ensemble.models[
+ 0].image_mean = self.cano_cvD.cv_ensemble.models[
+ 0].image_mean.repeat(2)
+ self.cano_cvD.cv_ensemble.models[
+ 0].image_std = self.cano_cvD.cv_ensemble.models[
+ 0].image_std.repeat(2)
+
+ # logger.log(f'cano_cvD_model_params: {cvD_model_params}')
+
+ self._load_and_sync_parameters(model=self.cano_cvD,
+ model_name='cano_cvD')
+
+ self.mp_trainer_canonical_cvD = MixedPrecisionTrainer(
+ model=self.cano_cvD,
+ use_fp16=self.use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ model_name='canonical_cvD',
+ use_amp=use_amp,
+ model_params=cvD_model_params)
+
+ # cano_lr = 2e-5 * (lr / 1e-5) # D_lr=2e-4 in cvD by default
+ # cano_lr = 5e-5 * (lr / 1e-5) # D_lr=2e-4 in cvD by default
+ cano_lr = 2e-4 * (
+ lr / 1e-5) # D_lr=2e-4 in cvD by default. 1e-4 still overfitting
+ self.opt_cano_cvD = AdamW(
+ self.mp_trainer_canonical_cvD.master_params,
+ lr=cano_lr, # same as the G
+ betas=(0, 0.999),
+ eps=1e-8) # dlr in biggan cfg
+
+ logger.log(f'cpt_cano_cvD lr: {cano_lr}')
+
+ if self.use_ddp:
+ self.ddp_cano_cvD = DDP(
+ self.cano_cvD,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ )
+ else:
+ self.ddp_cano_cvD = self.cano_cvD
+
+ th.cuda.empty_cache()
+
+ def run_step(self, batch, step='g_step'):
+ # self.forward_backward(batch)
+
+ if step == 'g_step_rec':
+ self.forward_G_rec(batch)
+ took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)
+
+ if took_step_g_rec:
+ self._update_ema() # g_ema
+
+ elif step == 'd_step_rec':
+ self.forward_D(batch, behaviour='rec')
+ # _ = self.mp_trainer_cvD.optimize(self.opt_cvD)
+ _ = self.mp_trainer_canonical_cvD.optimize(self.opt_cano_cvD)
+
+ elif step == 'g_step_nvs':
+ self.forward_G_nvs(batch)
+ took_step_g_nvs = self.mp_trainer_rec.optimize(self.opt)
+
+ if took_step_g_nvs:
+ self._update_ema() # g_ema
+
+ elif step == 'd_step_nvs':
+ self.forward_D(batch, behaviour='nvs')
+ _ = self.mp_trainer_cvD.optimize(self.opt_cvD)
+ # _ = self.mp_trainer_canonical_cvD.optimize(self.opt_cano_cvD)
+
+ self._anneal_lr()
+ self.log_step()
+
+ def run_loop(self):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # batch, cond = next(self.data)
+ # if batch is None:
+ batch = next(self.data)
+
+ if self.novel_view_poses is None:
+ self.novel_view_poses = th.roll(batch['c'], 1, 0).to(
+ dist_util.dev()) # save for eval visualization use
+
+ self.run_step(batch, 'g_step_rec')
+
+ # if self.step % 2 == 0:
+ batch = next(self.data)
+ self.run_step(batch, 'd_step_rec')
+
+ # if self.step % 2 == 1:
+ batch = next(self.data)
+ self.run_step(batch, 'g_step_nvs')
+
+ batch = next(self.data)
+ self.run_step(batch, 'd_step_nvs')
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ # if self.step % self.eval_interval == 0 and self.step != 0:
+ if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ self.eval_loop()
+ # self.eval_novelview_loop()
+ # let all processes sync up before starting with a new epoch of training
+ th.cuda.empty_cache()
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0:
+ self.save()
+ self.save(self.mp_trainer_cvD, self.mp_trainer_cvD.model_name)
+ self.save(self.mp_trainer_canonical_cvD,
+ self.mp_trainer_canonical_cvD.model_name)
+
+ dist_util.synchronize()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+
+ self.save()
+ self.save(self.mp_trainer_cvD,
+ self.mp_trainer_cvD.model_name)
+ self.save(self.mp_trainer_canonical_cvD,
+ self.mp_trainer_canonical_cvD.model_name)
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+ self.save(self.mp_trainer_canonical_cvD, 'cvD')
+
+ def forward_D(self, batch, behaviour): # update D
+ self.mp_trainer_canonical_cvD.zero_grad()
+ self.mp_trainer_cvD.zero_grad()
+
+ self.rec_model.requires_grad_(False)
+ # self.ddp_model.requires_grad_(False)
+
+ # update two D
+ if behaviour == 'nvs':
+ self.ddp_nvs_cvD.requires_grad_(True)
+ self.ddp_cano_cvD.requires_grad_(False)
+ else: # update rec canonical D
+ self.ddp_nvs_cvD.requires_grad_(False)
+ self.ddp_cano_cvD.requires_grad_(True)
+
+ batch_size = batch['img'].shape[0]
+
+ # * sample a new batch for D training
+ for i in range(0, batch_size, self.microbatch):
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev()).contiguous()
+ for k, v in batch.items()
+ }
+
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_canonical_cvD.use_amp):
+
+ novel_view_c = th.cat([micro['c'][1:], micro['c'][:1]])
+
+ latent = self.rec_model(img=micro['img_to_encoder'],
+ behaviour='enc_dec_wo_triplane')
+
+ cano_pred = self.rec_model(latent=latent,
+ c=micro['c'],
+ behaviour='triplane_dec')
+
+ # TODO, optimize with one encoder, and two triplane decoder
+ # FIXME: quit autocast to runbackward
+ if behaviour == 'rec':
+
+ if 'image_sr' in cano_pred:
+ # try concat them in batch
+ d_loss = self.run_D_Diter(
+ real=th.cat([
+ th.nn.functional.interpolate(
+ micro['img'],
+ size=micro['img_sr'].shape[2:],
+ mode='bilinear',
+ align_corners=False,
+ antialias=True),
+ micro['img_sr'],
+ ],
+ dim=1),
+ fake=th.cat([
+ th.nn.functional.interpolate(
+ cano_pred['image_raw'],
+ size=cano_pred['image_sr'].shape[2:],
+ mode='bilinear',
+ align_corners=False,
+ antialias=True),
+ cano_pred['image_sr'],
+ ],
+ dim=1),
+ D=self.ddp_cano_cvD) # TODO, add SR for FFHQ
+
+ else:
+ d_loss = self.run_D_Diter(
+ real=micro['img'],
+ fake=cano_pred['image_raw'],
+ D=self.ddp_cano_cvD) # TODO, add SR for FFHQ
+
+ log_rec3d_loss_dict(
+ {'vision_aided_loss/D_cano': d_loss})
+ # self.mp_trainer_canonical_cvD.backward(d_loss)
+ else:
+ assert behaviour == 'nvs'
+
+ nvs_pred = self.rec_model(latent=latent,
+ c=novel_view_c,
+ behaviour='triplane_dec')
+
+ if 'image_sr' in nvs_pred:
+
+ d_loss = self.run_D_Diter(
+ real=th.cat([
+ th.nn.functional.interpolate(
+ cano_pred['image_raw'],
+ size=cano_pred['image_sr'].shape[2:],
+ mode='bilinear',
+ align_corners=False,
+ antialias=True),
+ cano_pred['image_sr'],
+ ],
+ dim=1),
+ fake=th.cat([
+ th.nn.functional.interpolate(
+ nvs_pred['image_raw'],
+ size=nvs_pred['image_sr'].shape[2:],
+ mode='bilinear',
+ align_corners=False,
+ antialias=True),
+ nvs_pred['image_sr'],
+ ],
+ dim=1),
+ D=self.ddp_nvs_cvD) # TODO, add SR for FFHQ
+
+ else:
+ d_loss = self.run_D_Diter(
+ real=cano_pred['image_raw'],
+ fake=nvs_pred['image_raw'],
+ D=self.ddp_nvs_cvD) # TODO, add SR for FFHQ
+
+ log_rec3d_loss_dict(
+ {'vision_aided_loss/D_nvs': d_loss})
+ # self.mp_trainer_cvD.backward(d_loss)
+
+ if behaviour == 'rec':
+ self.mp_trainer_canonical_cvD.backward(d_loss)
+ else:
+ assert behaviour == 'nvs'
+ self.mp_trainer_cvD.backward(d_loss)
+
+ def forward_G_rec(self, batch): # update G
+
+ self.mp_trainer_rec.zero_grad()
+ self.rec_model.requires_grad_(True)
+
+ self.ddp_cano_cvD.requires_grad_(False)
+ self.ddp_nvs_cvD.requires_grad_(False)
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev()).contiguous()
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ pred = self.rec_model(
+ img=micro['img_to_encoder'], c=micro['c']
+ ) # render novel view for first half of the batch for D loss
+
+ target_for_rec = micro
+ cano_pred = pred
+
+ # if last_batch or not self.use_ddp:
+ # loss, loss_dict = self.loss_class(cano_pred,
+ # target_for_rec,
+ # test_mode=False,
+ # step=self.step +
+ # self.resume_step)
+ # else:
+ with self.rec_model.no_sync(): # type: ignore
+ loss, loss_dict, fg_mask = self.loss_class(cano_pred,
+ target_for_rec,
+ test_mode=False,
+ step=self.step +
+ self.resume_step,
+ return_fg_mask=True)
+
+ # cano_pred_img = cano_pred['image_raw']
+
+ if self.loss_class.opt.symmetry_loss:
+ pose, intrinsics = micro['c'][:, :16].reshape(
+ -1, 4, 4), micro['c'][:, 16:]
+ flipped_pose = flip_yaw(pose)
+ mirror_c = th.cat(
+ [flipped_pose.reshape(-1, 16), intrinsics], -1)
+
+ nvs_pred = self.rec_model(latent={
+ k: v
+ for k, v in pred.items() if 'latent' in k
+ },
+ c=mirror_c,
+ behaviour='triplane_dec',
+ return_raw_only=True)
+ # cano_pred_img = th.cat([cano_pred_img, nvs_pred['image_raw']], 0)
+
+ # concat data for supervision
+ nvs_gt = {
+ k: th.flip(target_for_rec[k], [-1])
+ for k in
+ ['img'] # fliplr leads to wrong color; B 3 H W shape
+ }
+ flipped_fg_mask = th.flip(fg_mask, [-1])
+ if 'conf_sigma' in pred:
+ conf_sigma = th.flip(pred['conf_sigma'], [-1])
+ conf_sigma = th.nn.AdaptiveAvgPool2d(fg_mask.shape[-2:])(conf_sigma) # dynamically resize to target img size
+ else:
+ conf_sigma=None
+
+ with self.rec_model.no_sync(): # type: ignore
+ loss_symm, loss_dict_symm = self.loss_class.calc_2d_rec_loss(
+ nvs_pred['image_raw'],
+ nvs_gt['img'],
+ flipped_fg_mask,
+ # test_mode=True,
+ test_mode=False,
+ step=self.step + self.resume_step,
+ conf_sigma=conf_sigma,
+ )
+ loss += (loss_symm * 1.0) # as in unsup3d
+ # if conf_sigma is not None:
+ # conf_loss = th.nn.functional.mse_loss(conf_sigma, flipped_fg_mask) * 0.2
+ # loss += conf_loss # a log that regularizes all confidence to 1
+ # loss_dict[f'conf_loss'] = conf_loss
+ for k, v in loss_dict_symm.items():
+ loss_dict[f'{k}_symm'] = v
+
+
+ # add cvD supervision
+ # ! TODO
+
+ if 'image_sr' in cano_pred:
+ # concat both resolution
+ vision_aided_loss = self.ddp_cano_cvD(
+ th.cat([
+ th.nn.functional.interpolate(
+ cano_pred['image_raw'],
+ size=cano_pred['image_sr'].shape[2:],
+ mode='bilinear',
+ align_corners=False,
+ antialias=True),
+ cano_pred['image_sr'],
+ ],
+ dim=1), # 6 channel input
+ for_G=True).mean() # [B, 1] shape
+
+ else:
+ vision_aided_loss = self.ddp_cano_cvD(
+ cano_pred['image_raw'],
+ for_G=True).mean() # [B, 1] shape
+
+ # last_layer = self.rec_model.module.decoder.triplane_decoder.decoder.net[ # type: ignore
+ # -1].weight # type: ignore
+
+ d_weight = th.tensor(self.loss_class.opt.rec_cvD_lambda).to(
+ dist_util.dev())
+ # d_weight = calculate_adaptive_weight(
+ # loss,
+ # vision_aided_loss,
+ # last_layer,
+ # disc_weight_max=0.1) * self.loss_class.opt.rec_cvD_lambda
+ loss += vision_aided_loss * d_weight
+
+ loss_dict.update({
+ 'vision_aided_loss/G_rec':
+ (vision_aided_loss * d_weight).detach(),
+ 'd_weight':
+ d_weight
+ })
+
+ log_rec3d_loss_dict(loss_dict)
+
+ self.mp_trainer_rec.backward(
+ loss) # no nvs cvD loss, following VQ3D
+
+ # DDP some parameters no grad:
+ # for name, p in self.ddp_model.named_parameters():
+ # if p.grad is None:
+ # print(f"(in rec)found rec unused param: {name}")
+
+ # ! move to other places, add tensorboard
+
+ # if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ # with th.no_grad():
+ # # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ # gt_depth = micro['depth']
+ # if gt_depth.ndim == 3:
+ # gt_depth = gt_depth.unsqueeze(1)
+ # gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ # gt_depth.min())
+ # # if True:
+ # pred_depth = pred['image_depth']
+ # pred_depth = (pred_depth - pred_depth.min()) / (
+ # pred_depth.max() - pred_depth.min())
+ # pred_img = pred['image_raw']
+ # gt_img = micro['img']
+
+ # if 'image_sr' in pred:
+ # if pred['image_sr'].shape[-1] == 512:
+ # pred_img = th.cat(
+ # [self.pool_512(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat(
+ # [self.pool_512(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # pred_depth = self.pool_512(pred_depth)
+ # gt_depth = self.pool_512(gt_depth)
+
+ # elif pred['image_sr'].shape[-1] == 256:
+ # pred_img = th.cat(
+ # [self.pool_256(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat(
+ # [self.pool_256(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # pred_depth = self.pool_256(pred_depth)
+ # gt_depth = self.pool_256(gt_depth)
+
+ # else:
+ # pred_img = th.cat(
+ # [self.pool_128(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat(
+ # [self.pool_128(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # gt_depth = self.pool_128(gt_depth)
+ # pred_depth = self.pool_128(pred_depth)
+ # else:
+ # gt_img = self.pool_64(gt_img)
+ # gt_depth = self.pool_64(gt_depth)
+
+ # gt_vis = th.cat(
+ # [gt_img, gt_depth.repeat_interleave(3, dim=1)],
+ # dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ # pred_vis = th.cat(
+ # [pred_img,
+ # pred_depth.repeat_interleave(3, dim=1)],
+ # dim=-1) # B, 3, H, W
+
+ # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ # 1, 2, 0).cpu() # ! pred in range[-1, 1]
+ # # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ # vis = vis.numpy() * 127.5 + 127.5
+ # vis = vis.clip(0, 255).astype(np.uint8)
+ # Image.fromarray(vis).save(
+ # f'{logger.get_dir()}/{self.step+self.resume_step}_rec.jpg'
+ # )
+ # print(
+ # 'log vis to: ',
+ # f'{logger.get_dir()}/{self.step+self.resume_step}_rec.jpg'
+ # )
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ def norm_depth(pred_depth): # to [-1,1]
+ # pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ return -(pred_depth * 2 - 1)
+
+ pred_img = pred['image_raw'].clip(-1,1)
+ gt_img = micro['img']
+
+ # infer novel view also
+ pred_nv_img = self.rec_model(
+ img=micro['img_to_encoder'],
+ c=self.novel_view_poses) # pred: (B, 3, 64, 64)
+
+ # if 'depth' in micro:
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = norm_depth(gt_depth)
+ # gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ # gt_depth.min())
+ # if True:
+ if 'image_depth' in pred:
+ # pred_depth = pred['image_depth']
+ # pred_depth = (pred_depth - pred_depth.min()) / (
+ # pred_depth.max() - pred_depth.min())
+ pred_depth = norm_depth(pred['image_depth'])
+ pred_nv_depth = norm_depth(
+ pred_nv_img['image_depth'])
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+ pred_nv_depth = th.zeros_like(gt_depth)
+
+ if 'image_sr' in pred:
+ if pred['image_sr'].shape[-1] == 512:
+ pred_img = th.cat(
+ [self.pool_512(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ elif pred['image_sr'].shape[-1] == 256:
+ pred_img = th.cat(
+ [self.pool_256(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_256(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_256(pred_depth)
+ gt_depth = self.pool_256(gt_depth)
+
+ else:
+ pred_img = th.cat(
+ [self.pool_128(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_128(micro['img']), micro['img_sr']],
+ dim=-1)
+ gt_depth = self.pool_128(gt_depth)
+ pred_depth = self.pool_128(pred_depth)
+
+ if gt_img.shape[-1] == 64:
+ gt_depth = self.pool_64(gt_depth)
+ elif gt_img.shape[-1] == 128:
+ gt_depth = self.pool_128(gt_depth)
+ # else:
+ # gt_depth = self.pool_64(gt_depth)
+
+ # st()
+ pred_vis = th.cat(
+ [pred_img,
+ pred_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # B, 3, H, W
+
+ pred_vis_nv = th.cat([
+ pred_nv_img['image_raw'].clip(-1,1),
+ pred_nv_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+ pred_vis = th.cat([pred_vis, pred_vis_nv],
+ dim=-2) # cat in H dim
+
+ gt_vis = th.cat(
+ [gt_img, gt_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ vis = th.cat([gt_vis, pred_vis], dim=-2)
+ # .permute(
+ # 0, 2, 3, 1).cpu()
+ vis_tensor = torchvision.utils.make_grid(
+ vis, nrow=vis.shape[-1] // 64) # HWC
+ torchvision.utils.save_image(
+ vis_tensor,
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg', normalize=True, value_range=(-1,1))
+ # vis = vis.numpy() * 127.5 + 127.5
+ # vis = vis.clip(0, 255).astype(np.uint8)
+
+ # Image.fromarray(vis).save(
+ # f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+
+ logger.log(
+ 'log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+
+
+ def forward_G_nvs(self, batch): # update G
+
+ self.mp_trainer_rec.zero_grad()
+ self.rec_model.requires_grad_(True)
+
+ self.ddp_cano_cvD.requires_grad_(False)
+ self.ddp_nvs_cvD.requires_grad_(False) # only use novel view D
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev()).contiguous()
+ for k, v in batch.items()
+ }
+
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ nvs_pred = self.rec_model(
+ img=micro['img_to_encoder'],
+ c=th.cat([
+ micro['c'][1:],
+ micro['c'][:1],
+ ])) # ! render novel views only for D loss
+
+ # add cvD supervision
+
+ if 'image_sr' in nvs_pred:
+ # concat sr and raw results
+ vision_aided_loss = self.ddp_nvs_cvD(
+ # pred_nv['image_sr'],
+ # 0.5 * pred_nv['image_sr'] + 0.5 * th.nn.functional.interpolate(pred_nv['image_raw'], size=pred_nv['image_sr'].shape[2:], mode='bilinear'),
+ th.cat([
+ th.nn.functional.interpolate(
+ nvs_pred['image_raw'],
+ size=nvs_pred['image_sr'].shape[2:],
+ mode='bilinear',
+ align_corners=False,
+ antialias=True),
+ nvs_pred['image_sr'],
+ ],
+ dim=1),
+ for_G=True).mean() # ! for debugging
+
+ # supervise sr only
+ # vision_aided_loss = self.ddp_nvs_cvD(
+ # # pred_nv['image_sr'],
+ # # 0.5 * pred_nv['image_sr'] + 0.5 * th.nn.functional.interpolate(pred_nv['image_raw'], size=pred_nv['image_sr'].shape[2:], mode='bilinear'),
+ # th.cat([nvs_pred['image_sr'],
+ # th.nn.functional.interpolate(nvs_pred['image_raw'], size=nvs_pred['image_sr'].shape[2:], mode='bilinear',
+ # align_corners=False,
+ # antialias=True),]),
+ # for_G=True).mean() # ! for debugging
+
+ # pred_nv['image_raw'], for_G=True).mean() # [B, 1] shape
+ else:
+ vision_aided_loss = self.ddp_nvs_cvD(
+ nvs_pred['image_raw'],
+ for_G=True).mean() # [B, 1] shape
+
+ loss = vision_aided_loss * self.loss_class.opt.nvs_cvD_lambda
+
+ log_rec3d_loss_dict({
+ 'vision_aided_loss/G_nvs': loss
+ # vision_aided_loss * self.loss_class.opt.nvs_cvD_lambda,
+ })
+
+ self.mp_trainer_rec.backward(loss)
+
+ # ! move to other places, add tensorboard
+
+ # if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ if dist_util.get_rank() == 0 and self.step % 500 == 1:
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ def norm_depth(pred_depth): # to [-1,1]
+ # pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ return -(pred_depth * 2 - 1)
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = norm_depth(gt_depth)
+
+ # if True:
+ # pred_depth = nvs_pred['image_depth']
+ # pred_depth = (pred_depth - pred_depth.min()) / (
+ # pred_depth.max() - pred_depth.min())
+ pred_depth = norm_depth(nvs_pred['image_depth'])
+ pred_img = nvs_pred['image_raw']
+ gt_img = micro['img']
+
+ if 'image_sr' in nvs_pred:
+
+ if nvs_pred['image_sr'].shape[-1] == 512:
+ pred_img = th.cat([
+ self.pool_512(pred_img), nvs_pred['image_sr']
+ ],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ elif nvs_pred['image_sr'].shape[-1] == 256:
+ pred_img = th.cat([
+ self.pool_256(pred_img), nvs_pred['image_sr']
+ ],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_256(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_256(pred_depth)
+ gt_depth = self.pool_256(gt_depth)
+
+ else:
+ pred_img = th.cat([
+ self.pool_128(pred_img), nvs_pred['image_sr']
+ ],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_128(micro['img']), micro['img_sr']],
+ dim=-1)
+ gt_depth = self.pool_128(gt_depth)
+ pred_depth = self.pool_128(pred_depth)
+
+
+ if gt_img.shape[-1] == 64:
+ gt_depth = self.pool_64(gt_depth)
+ elif gt_img.shape[-1] == 128:
+ gt_depth = self.pool_128(gt_depth)
+
+ # else:
+ # gt_img = self.pool_64(gt_img)
+ # gt_depth = self.pool_64(gt_depth)
+
+ gt_vis = th.cat(
+ [gt_img, gt_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ pred_vis = th.cat(
+ [pred_img,
+ pred_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # B, 3, H, W
+
+ # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ # 1, 2, 0).cpu() # ! pred in range[-1, 1]
+ vis = th.cat([gt_vis, pred_vis], dim=-2)
+
+ vis = torchvision.utils.make_grid(
+ vis,
+ normalize=True,
+ scale_each=True,
+ value_range=(-1, 1)).cpu().permute(1, 2, 0) # H W 3
+ vis = vis.numpy() * 255
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ # print(vis.shape)
+
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}_nvs.jpg'
+ )
+ print(
+ 'log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}_nvs.jpg'
+ )
+
+class TrainLoop3DcvD_nvsD_canoD_eg3d(TrainLoop3DcvD_nvsD_canoD):
+ def __init__(self, *, rec_model, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, iterations=10001, load_submodule_name='', ignore_resume_opt=False, use_amp=False, **kwargs):
+ super().__init__(rec_model=rec_model, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, load_submodule_name=load_submodule_name, ignore_resume_opt=ignore_resume_opt, use_amp=use_amp, **kwargs)
+ self.rendering_kwargs = self.rec_model.module.decoder.triplane_decoder.rendering_kwargs # type: ignore
+ self._prepare_nvs_pose() # for eval novelview visualization
+
+ @th.inference_mode()
+ def eval_novelview_loop(self):
+ # novel view synthesis given evaluation camera trajectory
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_batch_{i}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+
+ for idx, c in enumerate(self.all_nvs_params):
+ pred = self.rec_model(img=micro['img_to_encoder'],
+ c=c.unsqueeze(0).repeat_interleave(micro['img'].shape[0], 0)) # pred: (B, 3, 64, 64)
+ # c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ if 'image_sr' in pred:
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 256:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_256(pred['image_raw']), pred['image_sr'],
+ self.pool_256(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred['image_sr']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+
+ # st()
+ pred_vis = th.cat([
+ self.pool_128(micro['img']),
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ # ! cooncat h dim
+ pred_vis = pred_vis.permute(0,2,3,1).flatten(0,1) # H W 3
+
+ # vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ # vis = pred_vis.permute(1,2,0).cpu().numpy()
+ vis = pred_vis.cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ # for j in range(vis.shape[0]):
+ # video_out.append_data(vis[j])
+ video_out.append_data(vis)
+
+ video_out.close()
+
+ th.cuda.empty_cache()
+
+
+ def _prepare_nvs_pose(self):
+ from nsr.camera_utils import LookAtPoseSampler, FOV_to_intrinsics
+
+ device = dist_util.dev()
+
+ fov_deg = 18.837 # for ffhq/afhq
+ intrinsics = FOV_to_intrinsics(fov_deg, device=device)
+
+ all_nvs_params = []
+
+ pitch_range = 0.25
+ yaw_range = 0.35
+ num_keyframes = 10 # how many nv poses to sample from
+ w_frames = 1
+
+ cam_pivot = th.Tensor(self.rendering_kwargs.get('avg_camera_pivot')).to(device)
+ cam_radius = self.rendering_kwargs.get('avg_camera_radius')
+
+ for frame_idx in range(num_keyframes):
+
+ cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.sin(2 * 3.14 * frame_idx / (num_keyframes * w_frames)),
+ 3.14/2 -0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / (num_keyframes * w_frames)),
+ cam_pivot, radius=cam_radius, device=device)
+
+ camera_params = th.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
+
+ all_nvs_params.append(camera_params)
+
+ self.all_nvs_params = th.cat(all_nvs_params, 0)
\ No newline at end of file
diff --git a/nsr/dual_discriminator.py b/nsr/dual_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c8f39fa4857d7c41a017c9a9d736dadb99f5c2a
--- /dev/null
+++ b/nsr/dual_discriminator.py
@@ -0,0 +1,480 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Discriminator architectures from the paper
+"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
+
+import numpy as np
+import torch
+from utils.torch_utils import persistence
+from utils.torch_utils.ops import upfirdn2d
+from .networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue
+from pdb import set_trace as st
+
+
+@persistence.persistent_class
+class SingleDiscriminator(torch.nn.Module):
+ def __init__(
+ self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base=32768, # Overall multiplier for the number of channels.
+ channel_max=512, # Maximum number of channels in any layer.
+ num_fp16_res=4, # Use FP16 for the N highest resolutions.
+ conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
+ sr_upsample_factor=1, # Ignored for SingleDiscriminator
+ block_kwargs={}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs={}, # Arguments for MappingNetwork.
+ epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [
+ 2**i for i in range(self.img_resolution_log2, 2, -1)
+ ]
+ channels_dict = {
+ res: min(channel_base // res, channel_max)
+ for res in self.block_resolutions + [4]
+ }
+ fp16_resolution = max(2**(self.img_resolution_log2 + 1 - num_fp16_res),
+ 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels,
+ architecture=architecture,
+ conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels,
+ tmp_channels,
+ out_channels,
+ resolution=res,
+ first_layer_idx=cur_layer_idx,
+ use_fp16=use_fp16,
+ **block_kwargs,
+ **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0,
+ c_dim=c_dim,
+ w_dim=cmap_dim,
+ num_ws=None,
+ w_avg_beta=None,
+ **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4],
+ cmap_dim=cmap_dim,
+ resolution=4,
+ **epilogue_kwargs,
+ **common_kwargs)
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ img = img['image']
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+
+#----------------------------------------------------------------------------
+
+
+def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
+ if filter_mode == 'antialiased':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor,
+ size=(size, size),
+ mode='bilinear',
+ align_corners=False,
+ antialias=True)
+ elif filter_mode == 'classic':
+ ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
+ ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64,
+ size=(size * 2 + 2,
+ size * 2 + 2),
+ mode='bilinear',
+ align_corners=False)
+ ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64,
+ f,
+ down=2,
+ flip_filter=True,
+ padding=-1)
+ elif filter_mode == 'none':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor,
+ size=(size, size),
+ mode='bilinear',
+ align_corners=False)
+ elif type(filter_mode) == float:
+ assert 0 < filter_mode < 1
+
+ filtered = torch.nn.functional.interpolate(image_orig_tensor,
+ size=(size, size),
+ mode='bilinear',
+ align_corners=False,
+ antialias=True)
+ aliased = torch.nn.functional.interpolate(image_orig_tensor,
+ size=(size, size),
+ mode='bilinear',
+ align_corners=False,
+ antialias=False)
+ ada_filtered_64 = (1 -
+ filter_mode) * aliased + (filter_mode) * filtered
+
+ return ada_filtered_64
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class DualDiscriminator(torch.nn.Module):
+ def __init__(
+ self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base=32768, # Overall multiplier for the number of channels.
+ channel_max=512, # Maximum number of channels in any layer.
+ num_fp16_res=4, # Use FP16 for the N highest resolutions.
+ conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
+ disc_c_noise=0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
+ block_kwargs={}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs={}, # Arguments for MappingNetwork.
+ epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ # img_channels *= 2
+ if img_channels == 3:
+ img_channels *= 2
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [
+ 2**i for i in range(self.img_resolution_log2, 2, -1)
+ ]
+ channels_dict = {
+ res: min(channel_base // res, channel_max)
+ for res in self.block_resolutions + [4]
+ }
+ fp16_resolution = max(2**(self.img_resolution_log2 + 1 - num_fp16_res),
+ 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels,
+ architecture=architecture,
+ conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels,
+ tmp_channels,
+ out_channels,
+ resolution=res,
+ first_layer_idx=cur_layer_idx,
+ use_fp16=use_fp16,
+ **block_kwargs,
+ **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0,
+ c_dim=c_dim,
+ w_dim=cmap_dim,
+ num_ws=None,
+ w_avg_beta=None,
+ **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4],
+ cmap_dim=cmap_dim,
+ resolution=4,
+ **epilogue_kwargs,
+ **common_kwargs)
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter([1, 3, 3, 1]))
+ self.disc_c_noise = disc_c_noise
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ image_raw = filtered_resizing(img['image_raw'],
+ # size=img['image'].shape[-1],
+ size=img['image_sr'].shape[-1],
+ f=self.resample_filter)
+ # img = torch.cat([img['image'], image_raw], 1)
+ img = torch.cat([img['image_sr'], image_raw], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ if self.disc_c_noise > 0:
+ c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+
+@persistence.persistent_class
+class GeoDualDiscriminator(DualDiscriminator):
+ def __init__(self, c_dim, img_resolution, img_channels, architecture='resnet', channel_base=32768, channel_max=512, num_fp16_res=4, conv_clamp=256, cmap_dim=None, disc_c_noise=0, block_kwargs={}, mapping_kwargs={}, epilogue_kwargs={}, normal_condition=False):
+ super().__init__(c_dim, img_resolution, img_channels, architecture, channel_base, channel_max, num_fp16_res, conv_clamp, cmap_dim, disc_c_noise, block_kwargs, mapping_kwargs, epilogue_kwargs)
+ self.normal_condition = normal_condition
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ image= img['image']
+ image_raw = filtered_resizing(img['image_raw'],
+ size=img['image'].shape[-1],
+ f=self.resample_filter)
+ D_input_img = torch.cat([image, image_raw], 1)
+
+ image_depth = filtered_resizing(img['image_depth'], size=img['image'].shape[-1], f=self.resample_filter)
+ if self.normal_condition and 'normal' in img:
+ image_normal = filtered_resizing(img['normal'], size=img['image'].shape[-1], f=self.resample_filter)
+ D_input_img = torch.cat([D_input_img, image_depth, image_normal], 1)
+ else:
+ D_input_img = torch.cat([D_input_img, image_depth], 1)
+
+ img = D_input_img
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ if self.disc_c_noise > 0:
+ c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class DummyDualDiscriminator(torch.nn.Module):
+ def __init__(
+ self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base=32768, # Overall multiplier for the number of channels.
+ channel_max=512, # Maximum number of channels in any layer.
+ num_fp16_res=4, # Use FP16 for the N highest resolutions.
+ conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs={}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs={}, # Arguments for MappingNetwork.
+ epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ img_channels *= 2
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [
+ 2**i for i in range(self.img_resolution_log2, 2, -1)
+ ]
+ channels_dict = {
+ res: min(channel_base // res, channel_max)
+ for res in self.block_resolutions + [4]
+ }
+ fp16_resolution = max(2**(self.img_resolution_log2 + 1 - num_fp16_res),
+ 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels,
+ architecture=architecture,
+ conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels,
+ tmp_channels,
+ out_channels,
+ resolution=res,
+ first_layer_idx=cur_layer_idx,
+ use_fp16=use_fp16,
+ **block_kwargs,
+ **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0,
+ c_dim=c_dim,
+ w_dim=cmap_dim,
+ num_ws=None,
+ w_avg_beta=None,
+ **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4],
+ cmap_dim=cmap_dim,
+ resolution=4,
+ **epilogue_kwargs,
+ **common_kwargs)
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter([1, 3, 3, 1]))
+
+ self.raw_fade = 1
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ self.raw_fade = max(0, self.raw_fade - 1 / (500000 / 32))
+
+ image_raw = filtered_resizing(img['image_raw'],
+ size=img['image'].shape[-1],
+ f=self.resample_filter) * self.raw_fade
+ img = torch.cat([img['image'], image_raw], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+
+#----------------------------------------------------------------------------
+
+# panohead
+# Tri-discriminator: upsampled image, super-resolved image, and segmentation mask
+# V2: first concatenate imgs and seg mask, using only one conv block
+@persistence.persistent_class
+class MaskDualDiscriminatorV2(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ seg_resolution, # Input resolution.
+ seg_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ disc_c_noise = 0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ img_channels = img_channels * 2 + seg_channels
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+ self.disc_c_noise = disc_c_noise
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
+ seg = filtered_resizing(img['image_mask'], size=img['image'].shape[-1], f=self.resample_filter)
+ seg = 2 * seg - 1 # normalize to [-1,1]
+ img = torch.cat([img['image'], image_raw, seg], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ if self.disc_c_noise > 0: c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'c_dim={self.c_dim:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'seg_resolution={self.seg_resolution:d}, seg_channels={self.seg_channels:d}'])
\ No newline at end of file
diff --git a/nsr/losses/__init__.py b/nsr/losses/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..1447a4305c77bfbe3d50a0d27bb9d2d893ec7621
--- /dev/null
+++ b/nsr/losses/__init__.py
@@ -0,0 +1,10 @@
+# 2d reconstruction losses
+from .id_loss import IDLoss
+# from .lms import HeatmapLoss # for faces
+# from .lpips_deprecated.lpips import LPIPS
+
+# manage import
+__all__ = [
+ # 'LPIPS',
+ 'IDLoss',
+]
diff --git a/nsr/losses/__pycache__/__init__.cpython-39.pyc b/nsr/losses/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13306737bfd24d6fa403e400693b1629859acb86
Binary files /dev/null and b/nsr/losses/__pycache__/__init__.cpython-39.pyc differ
diff --git a/nsr/losses/__pycache__/builder.cpython-39.pyc b/nsr/losses/__pycache__/builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e916fc36e3a9377608d2304e41aac5209cba00ae
Binary files /dev/null and b/nsr/losses/__pycache__/builder.cpython-39.pyc differ
diff --git a/nsr/losses/__pycache__/helpers.cpython-39.pyc b/nsr/losses/__pycache__/helpers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..227f099ba874773e74b067b12ac3a5ece7906943
Binary files /dev/null and b/nsr/losses/__pycache__/helpers.cpython-39.pyc differ
diff --git a/nsr/losses/__pycache__/id_loss.cpython-39.pyc b/nsr/losses/__pycache__/id_loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1fe8504f21dfcaf36ff7688e0353d258ba2dca69
Binary files /dev/null and b/nsr/losses/__pycache__/id_loss.cpython-39.pyc differ
diff --git a/nsr/losses/__pycache__/model_irse.cpython-39.pyc b/nsr/losses/__pycache__/model_irse.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87c611e42aa036f7e4aff2a51c5448ce143a7158
Binary files /dev/null and b/nsr/losses/__pycache__/model_irse.cpython-39.pyc differ
diff --git a/nsr/losses/__pycache__/paths_config.cpython-39.pyc b/nsr/losses/__pycache__/paths_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6f10ddd838afee6585642014c6de0ae5e51bd17
Binary files /dev/null and b/nsr/losses/__pycache__/paths_config.cpython-39.pyc differ
diff --git a/nsr/losses/__pycache__/sdfstudio_losses.cpython-39.pyc b/nsr/losses/__pycache__/sdfstudio_losses.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e1976d56d75ede382a118eb710a4c77d1ffb07c
Binary files /dev/null and b/nsr/losses/__pycache__/sdfstudio_losses.cpython-39.pyc differ
diff --git a/nsr/losses/__pycache__/vqperceptual.cpython-39.pyc b/nsr/losses/__pycache__/vqperceptual.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f9466272b393bf2b58c3bc9b1bfcb35b734cf62
Binary files /dev/null and b/nsr/losses/__pycache__/vqperceptual.cpython-39.pyc differ
diff --git a/nsr/losses/builder.py b/nsr/losses/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e391c16e00a5b71cf7fae4f461b051f5f064130c
--- /dev/null
+++ b/nsr/losses/builder.py
@@ -0,0 +1,999 @@
+EPS = 1e-7
+
+import kornia
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+import torchvision
+from guided_diffusion import dist_util, logger
+from pdb import set_trace as st
+from torch.nn import functional as F
+import numpy as np
+import torch
+import torch.nn as nn
+import lpips
+
+from . import *
+
+from .sdfstudio_losses import ScaleAndShiftInvariantLoss
+from ldm.util import default, instantiate_from_config
+from .vqperceptual import hinge_d_loss, vanilla_d_loss
+from torch.autograd import Variable
+
+from math import exp
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+ return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+ return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+ C1 = 0.01 ** 2
+ C2 = 0.03 ** 2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+# Main loss function used for ZoeDepth. Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7)
+def extract_key(prediction, key):
+ if isinstance(prediction, dict):
+ return prediction[key]
+ return prediction
+
+
+class SILogLoss(nn.Module):
+ """SILog loss (pixel-wise)"""
+
+ def __init__(self, beta=0.15):
+ super(SILogLoss, self).__init__()
+ self.name = 'SILog'
+ self.beta = beta
+
+ def forward(self,
+ input,
+ target,
+ mask=None,
+ interpolate=True,
+ return_interpolated=False):
+ # input = extract_key(input, KEY_OUTPUT)
+ if input.shape[-1] != target.shape[-1] and interpolate:
+ input = nn.functional.interpolate(input,
+ target.shape[-2:],
+ mode='bilinear',
+ align_corners=True)
+ intr_input = input
+ else:
+ intr_input = input
+
+ if target.ndim == 3:
+ target = target.unsqueeze(1)
+
+ if mask is not None:
+ if mask.ndim == 3:
+ mask = mask.unsqueeze(1)
+
+ input = input[mask]
+ target = target[mask]
+
+ # with torch.amp.autocast(enabled=False): # amp causes NaNs in this loss function
+
+ alpha = 1e-7
+ g = torch.log(input + alpha) - torch.log(target + alpha)
+
+ # n, c, h, w = g.shape
+ # norm = 1/(h*w)
+ # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2
+
+ Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2)
+
+ loss = 10 * torch.sqrt(Dg)
+
+ if torch.isnan(loss):
+ print("Nan SILog loss")
+ print("input:", input.shape)
+ print("target:", target.shape)
+ print("G", torch.sum(torch.isnan(g)))
+ print("Input min max", torch.min(input), torch.max(input))
+ print("Target min max", torch.min(target), torch.max(target))
+ print("Dg", torch.isnan(Dg))
+ print("loss", torch.isnan(loss))
+
+ if not return_interpolated:
+ return loss
+
+ return loss, intr_input
+
+
+def get_outnorm(x: torch.Tensor, out_norm: str = '') -> torch.Tensor:
+ """ Common function to get a loss normalization value. Can
+ normalize by either the batch size ('b'), the number of
+ channels ('c'), the image size ('i') or combinations
+ ('bi', 'bci', etc)
+ """
+ # b, c, h, w = x.size()
+ img_shape = x.shape
+
+ if not out_norm:
+ return 1
+
+ norm = 1
+ if 'b' in out_norm:
+ # normalize by batch size
+ # norm /= b
+ norm /= img_shape[0]
+ if 'c' in out_norm:
+ # normalize by the number of channels
+ # norm /= c
+ norm /= img_shape[-3]
+ if 'i' in out_norm:
+ # normalize by image/map size
+ # norm /= h*w
+ norm /= img_shape[-1] * img_shape[-2]
+
+ return norm
+
+
+class CharbonnierLoss(torch.nn.Module):
+ """Charbonnier Loss (L1)"""
+
+ def __init__(self, eps=1e-6, out_norm: str = 'bci'):
+ super(CharbonnierLoss, self).__init__()
+ self.eps = eps
+ self.out_norm = out_norm
+
+ def forward(self, x, y):
+ norm = get_outnorm(x, self.out_norm)
+ loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
+ return loss * norm
+
+
+def feature_vae_loss(feature):
+ # kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
+
+ # feature dim: B C H W
+ mu = feature.mean(1)
+ var = feature.var(1)
+ log_var = torch.log(var)
+ kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu**2 - var, dim=1), dim=0)
+ return kld
+
+
+def kl_coeff(step, total_step, constant_step, min_kl_coeff, max_kl_coeff):
+ # return max(min(max_kl_coeff * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff)
+ kl_lambda = max(
+ min(
+ min_kl_coeff + (max_kl_coeff - min_kl_coeff) *
+ (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff)
+ return torch.tensor(kl_lambda, device=dist_util.dev())
+
+
+def depth_smoothness_loss(alpha_pred, depth_pred):
+ # from PesonNeRF paper.
+ # all Tensor shape B 1 H W
+ geom_loss = (
+ alpha_pred[..., :-1] * alpha_pred[..., 1:] * (
+ depth_pred[..., :-1] - depth_pred[..., 1:] # W dim
+ ).square()).mean() # mean of ([8, 1, 64, 63])
+
+ geom_loss += (alpha_pred[..., :-1, :] * alpha_pred[..., 1:, :] *
+ (depth_pred[..., :-1, :] - depth_pred[..., 1:, :]).square()
+ ).mean() # H dim, ([8, 1, 63, 64])
+
+ return geom_loss
+
+
+# https://github.com/elliottwu/unsup3d/blob/master/unsup3d/networks.py#L140
+class LPIPSLoss(torch.nn.Module):
+
+ def __init__(
+ self,
+ loss_weight=1.0,
+ use_input_norm=True,
+ range_norm=True,
+ # n1p1_input=True,
+ ):
+ super(LPIPSLoss, self).__init__()
+ self.perceptual = lpips.LPIPS(net="alex", spatial=False).eval()
+ self.loss_weight = loss_weight
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ # if self.use_input_norm:
+ # # the mean is for image with range [0, 1]
+ # self.register_buffer(
+ # 'mean',
+ # torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # # the std is for image with range [0, 1]
+ # self.register_buffer(
+ # 'std',
+ # torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, pred, target, conf_sigma_percl=None):
+ # st()
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+ return self.loss_weight * lpips_loss.mean()
+
+
+# mask-aware perceptual loss
+class PerceptualLoss(nn.Module):
+
+ def __init__(self, requires_grad=False):
+ super(PerceptualLoss, self).__init__()
+ mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406])
+ std_rgb = torch.FloatTensor([0.229, 0.224, 0.225])
+ self.register_buffer('mean_rgb', mean_rgb)
+ self.register_buffer('std_rgb', std_rgb)
+
+ vgg_pretrained_features = torchvision.models.vgg16(
+ pretrained=True).features
+ self.slice1 = nn.Sequential()
+ self.slice2 = nn.Sequential()
+ self.slice3 = nn.Sequential()
+ self.slice4 = nn.Sequential()
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def normalize(self, x):
+ out = x / 2 + 0.5
+ out = (out - self.mean_rgb.view(1, 3, 1, 1)) / self.std_rgb.view(
+ 1, 3, 1, 1)
+ return out
+
+ def __call__(self, im1, im2, mask=None, conf_sigma=None):
+ im = torch.cat([im1, im2], 0)
+ im = self.normalize(im) # normalize input
+
+ ## compute features
+ feats = []
+ f = self.slice1(im)
+ feats += [torch.chunk(f, 2, dim=0)]
+ f = self.slice2(f)
+ feats += [torch.chunk(f, 2, dim=0)]
+ f = self.slice3(f)
+ feats += [torch.chunk(f, 2, dim=0)]
+ f = self.slice4(f)
+ feats += [torch.chunk(f, 2, dim=0)]
+
+ losses = []
+ for f1, f2 in feats[2:3]: # use relu3_3 features only
+ loss = (f1 - f2)**2
+ if conf_sigma is not None:
+ loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma +
+ EPS).log()
+ if mask is not None:
+ b, c, h, w = loss.shape
+ _, _, hm, wm = mask.shape
+ sh, sw = hm // h, wm // w
+ mask0 = nn.functional.avg_pool2d(mask,
+ kernel_size=(sh, sw),
+ stride=(sh,
+ sw)).expand_as(loss)
+ loss = (loss * mask0).sum() / mask0.sum()
+ else:
+ loss = loss.mean()
+ losses += [loss]
+ return sum(losses)
+
+
+# add confidence support, unsup3d version
+def photometric_loss_laplace(im1, im2, mask=None, conf_sigma=None):
+ loss = (im1 - im2).abs()
+ # loss = (im1 - im2).square()
+ if conf_sigma is not None:
+ loss = loss * 2**0.5 / (conf_sigma + EPS) + (conf_sigma + EPS).log()
+
+ if mask is not None:
+ mask = mask.expand_as(loss)
+ loss = (loss * mask).sum() / mask.sum()
+
+ else:
+ loss = loss.mean()
+
+ return loss
+
+
+# gaussian likelihood version, What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?
+# also used in the mask-aware vgg loss
+def photometric_loss(im1, im2, mask=None, conf_sigma=None):
+ # loss = torch.nn.functional.mse_loss(im1, im2, reduce='none')
+ loss = (im1 - im2).square()
+
+ if conf_sigma is not None:
+ loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma + EPS).log()
+
+ if mask is not None:
+ mask = mask.expand_as(loss)
+ loss = (loss * mask).sum() / mask.sum()
+
+ else:
+ loss = loss.mean()
+
+ return loss
+
+
+class E3DGELossClass(torch.nn.Module):
+
+ def __init__(self, device, opt) -> None:
+ super().__init__()
+
+ self.opt = opt
+ self.device = device
+ self.criterionImg = {
+ 'mse': torch.nn.MSELoss(),
+ 'l1': torch.nn.L1Loss(),
+ 'charbonnier': CharbonnierLoss(),
+ }[opt.color_criterion]
+
+ self.criterion_latent = {
+ 'mse': torch.nn.MSELoss(),
+ 'l1': torch.nn.L1Loss(),
+ 'vae': feature_vae_loss
+ }[opt.latent_criterion]
+
+ # self.criterionLPIPS = LPIPS(net_type='alex', device=device).eval()
+ if opt.lpips_lambda > 0:
+ self.criterionLPIPS = LPIPSLoss(loss_weight=opt.lpips_lambda)
+ # self.criterionLPIPS = torch.nn.MSELoss()
+
+ if opt.id_lambda > 0:
+ self.criterionID = IDLoss(device=device).eval()
+ self.id_loss_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+
+ # define 3d rec loss, for occupancy
+ # self.criterion3d_rec = torch.nn.SmoothL1Loss(reduction='none')
+ # self.criterion_alpha = torch.nn.SmoothL1Loss()
+
+ # self.criterion3d_rec = torch.nn.MSELoss(reduction='none')
+ self.criterion_alpha = torch.nn.L1Loss()
+
+ if self.opt.depth_lambda > 0:
+ self.criterion3d_rec = ScaleAndShiftInvariantLoss(alpha=0.5,
+ scales=1)
+ else:
+ self.criterion3d_rec = torch.nn.SmoothL1Loss(reduction='none')
+
+ # self.silog_loss = SILogLoss()
+
+ logger.log('init loss class finished', )
+
+ def calc_scale_invariant_depth_loss(self, pred_depth: torch.Tensor,
+ gt_depth: torch.Tensor,
+ gt_depth_mask: torch.Tensor):
+ """apply 3d shape reconstruction supervision. Basically supervise the depth with L1 loss
+ """
+
+ shape_loss_dict = {}
+ assert gt_depth_mask is not None
+ shape_loss = self.criterion3d_rec(pred_depth, gt_depth, gt_depth_mask)
+
+ if shape_loss > 0.2: # hinge loss, avoid ood gradient
+ shape_loss = torch.zeros_like(shape_loss)
+ else:
+ shape_loss *= self.opt.depth_lambda
+
+ shape_loss_dict['loss_depth'] = shape_loss
+ shape_loss_dict['depth_fgratio'] = gt_depth_mask.mean()
+
+ # return l_si, shape_loss_dict
+ return shape_loss, shape_loss_dict
+
+ def calc_depth_loss(self, pred_depth: torch.Tensor, gt_depth: torch.Tensor,
+ gt_depth_mask: torch.Tensor):
+ """apply 3d shape reconstruction supervision. Basically supervise the depth with L1 loss
+ """
+
+ shape_loss_dict = {}
+ shape_loss = self.criterion3d_rec(pred_depth, gt_depth)
+ if gt_depth_mask is not None:
+ # pred_depth *= gt_depth_mask
+ # gt_depth *= gt_depth_mask
+ shape_loss *= gt_depth_mask
+ shape_loss = shape_loss.sum() / gt_depth_mask.sum()
+ # else:
+ # shape_loss /= pred_depth.numel()
+ # l_si = self.silog_loss(pred_depth, gt_depth, mask=None, interpolate=True, return_interpolated=False)
+
+ # l_si *= self.opt.depth_lambda
+ # shape_loss_dict['loss_depth'] = l_si
+ shape_loss_dict['loss_depth'] = shape_loss.clamp(
+ min=0, max=0.1) * self.opt.depth_lambda
+
+ # return l_si, shape_loss_dict
+ return shape_loss, shape_loss_dict
+
+ @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
+ def calc_alpha_loss(self, pred_alpha, gt_depth_mask):
+ # return self.criterionImg(alpha, gt_depth_mask.float())
+
+ if gt_depth_mask.ndim == 3:
+ gt_depth_mask = gt_depth_mask.unsqueeze(1)
+
+ if gt_depth_mask.shape[1] == 3:
+ gt_depth_mask = gt_depth_mask[:, 0:1, ...] # B 1 H W
+
+ assert pred_alpha.shape == gt_depth_mask.shape
+
+ alpha_loss = self.criterion_alpha(pred_alpha, gt_depth_mask)
+
+ return alpha_loss
+
+ @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
+ def calc_mask_mse_loss(
+ self,
+ input,
+ gt,
+ gt_depth_mask,
+ # conf_sigma=None,
+ conf_sigma_l1=None,
+ # conf_sigma_percl=None,
+ use_fg_ratio=False):
+ if gt_depth_mask.ndim == 3:
+ gt_depth_mask = gt_depth_mask.unsqueeze(1).repeat_interleave(3, 1)
+ else:
+ assert gt_depth_mask.shape == input.shape
+ gt_depth_mask = gt_depth_mask.float()
+
+ if conf_sigma_l1 is None:
+ rec_loss = torch.nn.functional.mse_loss(
+ input.float(), gt.float(),
+ reduction='none') # 'sum' already divide by batch size n
+ else:
+ rec_loss = photometric_loss(
+ input, gt, gt_depth_mask, conf_sigma_l1
+ ) # ! only cauclate laplace on the foreground, or bg confidence low, large gradient.
+ return rec_loss
+ # rec_loss = torch.nn.functional.l1_loss( # for laplace loss
+ # input.float(), gt.float(),
+ # reduction='none') # 'sum' already divide by batch size n
+ # gt_depth_mask = torch.ones_like(gt_depth_mask) # ! DEBUGGING
+
+ # if conf_sigma is not None: # from unsup3d, but a L2 version
+ # rec_loss = rec_loss * 2**0.5 / (conf_sigma + EPS) + (conf_sigma +
+ # EPS).log()
+ # return rec_loss.mean()
+ # rec_loss = torch.exp(-(rec_loss * 2**0.5 / (conf_sigma + EPS))) * 1/(conf_sigma +
+ # EPS) / (2**0.5)
+
+ fg_size = gt_depth_mask.sum()
+ # fg_ratio = fg_size / torch.ones_like(gt_depth_mask).sum() if use_fg_ratio else 1
+ fg_loss = rec_loss * gt_depth_mask
+ fg_loss = fg_loss.sum() / fg_size # * fg_ratio
+
+ if self.opt.bg_lamdba > 0:
+ bg_loss = rec_loss * (1 - gt_depth_mask)
+ bg_loss = bg_loss.sum() / (1 - gt_depth_mask).sum()
+ rec_loss = fg_loss + bg_loss * self.opt.bg_lamdba
+ else:
+ rec_loss = fg_loss
+
+ return rec_loss
+
+ @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
+ def calc_2d_rec_loss(
+ self,
+ input,
+ gt,
+ depth_fg_mask,
+ test_mode=True,
+ step=1,
+ ignore_lpips=False,
+ # conf_sigma=None,
+ conf_sigma_l1=None,
+ conf_sigma_percl=None,
+ ):
+ opt = self.opt
+ loss_dict = {}
+
+ # logger.log(test_mode)
+ # logger.log(input.min(), input.max(), gt.min(), gt.max())
+ if test_mode or not opt.fg_mse:
+ rec_loss = self.criterionImg(input, gt)
+ else:
+ rec_loss = self.calc_mask_mse_loss(
+ input,
+ gt,
+ depth_fg_mask,
+ conf_sigma_l1=conf_sigma_l1,
+ )
+ # conf_sigma_percl=conf_sigma_percl)
+ # conf_sigma)
+
+ # if step == 300:
+ # st()
+
+ if opt.lpips_lambda > 0 and step >= opt.lpips_delay_iter and not ignore_lpips: # tricky solution to avoid NAN in LPIPS loss
+
+ # with torch.autocast(device_type='cuda',
+ # dtype=torch.float16,
+ # enabled=False):
+ # if test_mode or not opt.fg_mse: # no need to calculate background lpips for ease of computation
+ lpips_loss = self.criterionLPIPS(
+ input,
+ gt,
+ conf_sigma_percl=conf_sigma_percl,
+ )
+ # else: # fg lpips
+ # assert depth_fg_mask.shape == input.shape
+ # lpips_loss = self.criterionLPIPS(
+ # input.contiguous() * depth_fg_mask,
+ # gt.contiguous() * depth_fg_mask).mean()
+ else:
+ lpips_loss = torch.tensor(0., device=input.device)
+
+ if opt.ssim_lambda > 0:
+ loss_ssim = self.ssim_loss(input, gt) #?
+ else:
+ loss_ssim = torch.tensor(0., device=input.device)
+
+ loss_psnr = self.psnr((input / 2 + 0.5), (gt / 2 + 0.5), 1.0)
+
+ if opt.id_lambda > 0:
+ loss_id = self._calc_loss_id(input, gt)
+ else:
+ loss_id = torch.tensor(0., device=input.device)
+
+ if opt.l1_lambda > 0:
+ loss_l1 = F.l1_loss(input, gt)
+ else:
+ loss_l1 = torch.tensor(0., device=input.device)
+
+ # loss = rec_loss * opt.l2_lambda + lpips_loss * opt.lpips_lambda + loss_id * opt.id_lambda + loss_ssim * opt.ssim_lambda
+ loss = rec_loss * opt.l2_lambda + lpips_loss + loss_id * opt.id_lambda + loss_ssim * opt.ssim_lambda + opt.l1_lambda * loss_l1
+
+ # if return_dict:
+ loss_dict['loss_l2'] = rec_loss
+ loss_dict['loss_id'] = loss_id
+ loss_dict['loss_lpips'] = lpips_loss
+ loss_dict['loss'] = loss
+ loss_dict['loss_ssim'] = loss_ssim
+
+ # metrics to report, not involved in training
+ loss_dict['mae'] = loss_l1
+ loss_dict['PSNR'] = loss_psnr
+ loss_dict['SSIM'] = 1 - loss_ssim # Todo
+ loss_dict['ID_SIM'] = 1 - loss_id
+
+ return loss, loss_dict
+
+ @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
+ def calc_shape_rec_loss(
+ self,
+ pred_shape: dict,
+ gt_shape: dict,
+ device,
+ ):
+ """apply 3d shape reconstruction supervision. Basically supervise the densities with L1 loss
+
+ Args:
+ pred_shape (dict): dict contains reconstructed shape information
+ gt_shape (dict): dict contains gt shape information
+ supervise_sdf (bool, optional): whether supervise sdf rec. Defaults to True.
+ supervise_surface_normal (bool, optional): whether supervise surface rec. Defaults to False.
+
+ Returns:
+ dict: shape reconstruction loss
+ """
+
+ shape_loss_dict = {}
+ shape_loss = 0
+ # assert supervise_sdf or supervise_surface_normal, 'should at least supervise one types of shape reconstruction'
+ # todo, add weights
+
+ if self.opt.shape_uniform_lambda > 0:
+ shape_loss_dict['coarse'] = self.criterion3d_rec(
+ pred_shape['coarse_densities'].squeeze(),
+ gt_shape['coarse_densities'].squeeze())
+ shape_loss += shape_loss_dict[
+ 'coarse'] * self.opt.shape_uniform_lambda
+
+ if self.opt.shape_importance_lambda > 0:
+ shape_loss_dict['fine'] = self.criterion3d_rec(
+ pred_shape['fine_densities'].squeeze(), # ? how to supervise
+ gt_shape['fine_densities'].squeeze())
+ shape_loss += shape_loss_dict[
+ 'fine'] * self.opt.shape_importance_lambda
+
+ loss_depth = self.criterion_alpha(pred_shape['image_depth'],
+ gt_shape['image_depth'])
+
+ shape_loss += loss_depth * self.opt.shape_depth_lambda
+ shape_loss_dict.update(dict(loss_depth=loss_depth))
+ # TODO, add on surface pts supervision ?
+
+ return shape_loss, shape_loss_dict
+
+ @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
+ def psnr(self, input, target, max_val):
+ return kornia.metrics.psnr(input, target, max_val)
+
+ # @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
+ def ssim_loss(self, img1, img2, window_size=11, size_average=True):
+ channel = img1.size(-3)
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return 1 - _ssim(img1, img2, window, window_size, channel, size_average)
+
+ @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
+ def forward(self,
+ pred,
+ gt,
+ test_mode=True,
+ step=1,
+ return_fg_mask=False,
+ conf_sigma_l1=None,
+ conf_sigma_percl=None,
+ *args,
+ **kwargs):
+
+ with torch.autocast(device_type='cuda',
+ dtype=torch.float16,
+ enabled=False):
+
+ loss = torch.tensor(0., device=self.device)
+ loss_dict = {}
+
+ # balance rec_loss with logvar
+ # if 'depth_mask' in gt:
+ if self.opt.online_mask:
+ # https://github.com/elliottwu/unsup3d/blob/dc961410d61684561f19525c2f7e9ee6f4dacb91/unsup3d/model.py#L193
+ margin = (self.opt.max_depth - self.opt.min_depth) / 2
+ fg_mask = (pred['image_depth']
+ < self.opt.max_depth + margin).float() # B 1 H W
+ fg_mask = fg_mask.repeat_interleave(3, 1).float()
+ else:
+ if 'depth_mask' in gt:
+ fg_mask = gt['depth_mask'].unsqueeze(1).repeat_interleave(
+ 3, 1).float()
+ else:
+ fg_mask = None
+
+ loss_2d, loss_2d_dict = self.calc_2d_rec_loss(
+ pred['image_raw'],
+ gt['img'],
+ fg_mask,
+ test_mode=test_mode,
+ step=step,
+ ignore_lpips=False,
+ conf_sigma_l1=conf_sigma_l1,
+ conf_sigma_percl=conf_sigma_percl)
+ # ignore_lpips=self.opt.fg_mse)
+
+ if self.opt.kl_lambda > 0:
+ # assert 'posterior' in pred, 'logvar' in pred
+ assert 'posterior' in pred
+ kl_loss = pred['posterior'].kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ if self.opt.kl_anneal:
+ kl_lambda = kl_coeff(
+ step=step,
+ constant_step=5e3, # 1w steps
+ total_step=25e3, # 5w steps in total
+ min_kl_coeff=max(1e-9, self.opt.kl_lambda / 1e4),
+ max_kl_coeff=self.opt.kl_lambda)
+ loss_dict['kl_lambda'] = kl_lambda
+ else:
+ loss_dict['kl_lambda'] = torch.tensor(
+ self.opt.kl_lambda, device=dist_util.dev())
+
+ # loss_dict['kl_loss'] = kl_loss * kl_lambda
+ # loss_dict['kl_loss'] = kl_loss * kl_lambda
+ loss_dict['kl_loss'] = kl_loss * loss_dict['kl_lambda']
+ loss += loss_dict['kl_loss']
+
+ # nll_loss = loss_2d / torch.exp(pred['logvar']) + pred['logvar'] # nll_loss
+ nll_loss = loss_2d
+ loss += nll_loss
+
+ loss_dict.update(dict(nll_loss=nll_loss))
+
+ # loss_dict['latent_mu'] = pred['latent_normalized'].mean()
+ # loss_dict['latent_max'] = pred['latent_normalized'].max()
+ # loss_dict['latent_min'] = pred['latent_normalized'].min()
+ # loss_dict['latent_std'] = pred['latent_normalized'].std()
+ loss_dict['latent_mu'] = pred[
+ 'latent_normalized_2Ddiffusion'].mean()
+ loss_dict['latent_max'] = pred[
+ 'latent_normalized_2Ddiffusion'].max()
+ loss_dict['latent_min'] = pred[
+ 'latent_normalized_2Ddiffusion'].min()
+ loss_dict['latent_std'] = pred[
+ 'latent_normalized_2Ddiffusion'].std()
+
+ else:
+ loss += loss_2d
+
+ # if 'image_sr' in pred and pred['image_sr'].shape==gt['img_sr']:
+ if 'image_sr' in pred:
+
+ if 'depth_mask_sr' in gt:
+ depth_mask_sr = gt['depth_mask_sr'].unsqueeze(
+ 1).repeat_interleave(3, 1).float()
+ else:
+ depth_mask_sr = None
+
+ loss_sr, loss_sr_dict = self.calc_2d_rec_loss(
+ pred['image_sr'],
+ gt['img_sr'],
+ depth_fg_mask=depth_mask_sr,
+ # test_mode=test_mode,
+ test_mode=True,
+ step=step)
+ loss_sr_lambda = 1
+ if step < self.opt.sr_delay_iter:
+ loss_sr_lambda = 0
+ loss += loss_sr * loss_sr_lambda
+ for k, v in loss_sr_dict.items():
+ loss_dict['sr_' + k] = v * loss_sr_lambda
+
+ if self.opt.depth_lambda > 0: # TODO, switch to scale-agnostic depth loss
+ assert 'depth' in gt
+ pred_depth = pred['image_depth']
+ if pred_depth.ndim == 4:
+ pred_depth = pred_depth.squeeze(1) # B H W
+
+ # loss_3d, shape_loss_dict = self.calc_depth_loss(
+ # pred_depth, gt['depth'], fg_mask[:, 0, ...])
+ _, shape_loss_dict = self.calc_scale_invariant_depth_loss(
+ pred_depth, gt['depth'], fg_mask[:, 0, ...])
+ loss += shape_loss_dict['loss_depth']
+ loss_dict.update(shape_loss_dict)
+
+ # if self.opt.latent_lambda > 0: # make sure the latent suits diffusion learning
+ # latent_mu = pred['latent'].mean()
+ # loss_latent = self.criterion_latent(
+ # latent_mu, torch.zeros_like(
+ # latent_mu)) # only regularize the mean value here
+ # loss_dict['loss_latent'] = loss_latent
+ # loss += loss_latent * self.opt.latent_lambda
+
+ if 'image_mask' in pred:
+ pred_alpha = pred['image_mask'] # B 1 H W
+ else:
+ N, _, H, W = pred['image_depth'].shape
+ pred_alpha = pred['weights_samples'].permute(0, 2, 1).reshape(
+ N, 1, H, W)
+
+ if self.opt.alpha_lambda > 0 and 'image_depth' in pred:
+ loss_alpha = self.calc_alpha_loss(pred_alpha, fg_mask)
+ loss_dict['loss_alpha'] = loss_alpha * self.opt.alpha_lambda
+ loss += loss_alpha * self.opt.alpha_lambda
+
+ if self.opt.depth_smoothness_lambda > 0:
+ loss_depth_smoothness = depth_smoothness_loss(
+ pred_alpha,
+ pred['image_depth']) * self.opt.depth_smoothness_lambda
+ loss_dict['loss_depth_smoothness'] = loss_depth_smoothness
+ loss += loss_depth_smoothness
+
+ loss_2d_dict['all_loss'] = loss
+ loss_dict.update(loss_2d_dict)
+
+ # if return_fg_mask:
+ return loss, loss_dict, fg_mask
+ # else:
+ # return loss, loss_dict
+
+ def _calc_loss_id(self, input, gt):
+ if input.shape[-1] != 256:
+ arcface_input = self.id_loss_pool(input)
+ id_loss_gt = self.id_loss_pool(gt)
+ else:
+ arcface_input = input
+ id_loss_gt = gt
+
+ loss_id, _, _ = self.criterionID(arcface_input, id_loss_gt, id_loss_gt)
+
+ return loss_id
+
+ def calc_2d_rec_loss_misaligned(self, input, gt):
+ """id loss + vgg loss
+
+ Args:
+ input (_type_): _description_
+ gt (_type_): _description_
+ depth_mask (_type_): _description_
+ test_mode (bool, optional): _description_. Defaults to True.
+ """
+ opt = self.opt
+ loss_dict = {}
+
+ if opt.lpips_lambda > 0:
+ with torch.autocast(
+ device_type='cuda', dtype=torch.float16,
+ enabled=False): # close AMP for lpips to avoid nan
+ lpips_loss = self.criterionLPIPS(input, gt)
+ else:
+ lpips_loss = torch.tensor(0., device=input.device)
+
+ if opt.id_lambda > 0:
+ loss_id = self._calc_loss_id(input, gt)
+ else:
+ loss_id = torch.tensor(0., device=input.device)
+
+ loss_dict['loss_id_real'] = loss_id
+ loss_dict['loss_lpips_real'] = lpips_loss
+
+ loss = lpips_loss * opt.lpips_lambda + loss_id * opt.id_lambda
+
+ return loss, loss_dict
+
+
+class E3DGE_with_AdvLoss(E3DGELossClass):
+ # adapted from sgm/modules/autoencoding/losses/discriminator_loss.py
+
+ def __init__(
+ self,
+ device,
+ opt,
+ discriminator_config: Optional[Dict] = None,
+ disc_num_layers: int = 3,
+ disc_in_channels: int = 3,
+ disc_start: int = 0,
+ disc_loss: str = "hinge",
+ disc_factor: float = 1.0,
+ disc_weight: float = 1.0,
+ regularization_weights: Union[None, Dict[str, float]] = None,
+ # additional_log_keys: Optional[List[str]] = None,
+ ) -> None:
+ super().__init__(
+ device,
+ opt,
+ )
+
+ # ! initialize GAN loss
+ discriminator_config = default(
+ discriminator_config,
+ {
+ "target":
+ "nsr.losses.disc.NLayerDiscriminator",
+ "params": {
+ "input_nc": disc_in_channels,
+ "n_layers": disc_num_layers,
+ "use_actnorm": False,
+ },
+ },
+ )
+
+ self.discriminator = instantiate_from_config(
+ discriminator_config).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight # self.regularization_weights = default(regularization_weights, {})
+
+ # self.forward_keys = [
+ # "optimizer_idx",
+ # "global_step",
+ # "last_layer",
+ # "split",
+ # "regularization_log",
+ # ]
+
+ # self.additional_log_keys = set(default(additional_log_keys, []))
+ # self.additional_log_keys.update(set(
+ # self.regularization_weights.keys()))
+
+ def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
+ return self.discriminator.parameters()
+
+ def forward(self,
+ pred,
+ gt,
+ behaviour: str,
+ test_mode=True,
+ step=1,
+ return_fg_mask=False,
+ conf_sigma_l1=None,
+ conf_sigma_percl=None,
+ *args,
+ **kwargs):
+
+ # now the GAN part
+ reconstructions = pred['image_raw']
+ inputs = gt['img']
+
+ if behaviour == 'g_step':
+
+ nll_loss, loss_dict, fg_mask = super().forward(
+ pred,
+ gt,
+ test_mode,
+ step,
+ return_fg_mask,
+ conf_sigma_l1,
+ conf_sigma_percl,
+ *args,
+ **kwargs)
+
+ # generator update
+ if step >= self.discriminator_iter_start or not self.training:
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ g_loss = -torch.mean(logits_fake)
+ if self.training:
+ d_weight = torch.tensor(self.discriminator_weight)
+ # d_weight = self.calculate_adaptive_weight(
+ # nll_loss, g_loss, last_layer=last_layer)
+ else:
+ d_weight = torch.tensor(1.0)
+ else:
+ d_weight = torch.tensor(0.0)
+ g_loss = torch.tensor(0.0, requires_grad=True)
+
+ loss = nll_loss + d_weight * self.disc_factor * g_loss
+
+ # TODO
+ loss_dict.update({
+ f"loss/g": g_loss.detach().mean(),
+ })
+
+ # return loss, log
+ return loss, loss_dict, fg_mask
+
+ elif behaviour == 'd_step':
+ # second pass for discriminator update
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(
+ reconstructions.contiguous().detach())
+
+ if step >= self.discriminator_iter_start or not self.training:
+ d_loss = self.disc_factor * self.disc_loss(
+ logits_real, logits_fake)
+ else:
+ d_loss = torch.tensor(0.0, requires_grad=True)
+
+ loss_dict = {}
+
+ loss_dict.update({
+ "loss/disc": d_loss.clone().detach().mean(),
+ "logits/real": logits_real.detach().mean(),
+ "logits/fake": logits_fake.detach().mean(),
+ })
+
+ return d_loss, loss_dict, None
+ else:
+ raise NotImplementedError(f"Unknown optimizer behaviour {behaviour}")
diff --git a/nsr/losses/helpers.py b/nsr/losses/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..abaf044f5fb0d7ec98b21e820cb8291560024e3b
--- /dev/null
+++ b/nsr/losses/helpers.py
@@ -0,0 +1,378 @@
+from collections import namedtuple
+from pdb import set_trace as st
+import torch
+import numpy as np
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+"""
+ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+# from nsr.networks_stylegan2 import FullyConnectedLayer as EqualLinear
+
+# class GradualStyleBlock(Module):
+
+# def __init__(self, in_c, out_c, spatial):
+# super(GradualStyleBlock, self).__init__()
+# self.out_c = out_c
+# self.spatial = spatial
+# num_pools = int(np.log2(spatial))
+# modules = []
+# modules += [
+# Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
+# nn.LeakyReLU()
+# ]
+# for i in range(num_pools - 1):
+# modules += [
+# Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
+# nn.LeakyReLU()
+# ]
+# self.convs = nn.Sequential(*modules)
+# self.linear = EqualLinear(out_c, out_c, lr_multiplier=1)
+
+# def forward(self, x):
+# x = self.convs(x)
+# x = x.reshape(-1, self.out_c)
+# x = self.linear(x)
+# return x
+
+
+# from project.models.model import ModulatedConv2d
+class DemodulatedConv2d(nn.Module):
+ def __init__(self,
+ in_channel,
+ out_channel,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias=False,
+ dilation=1):
+ super().__init__()
+ # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/411. fix droplet issue
+
+ self.eps = 1e-8
+
+ if not isinstance(kernel_size, tuple):
+ self.kernel_size = (kernel_size, kernel_size)
+ else:
+ self.kernel_size = kernel_size
+
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ self.weight = nn.Parameter(
+ # torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+ torch.randn(1, out_channel, in_channel, *kernel_size))
+ self.bias = None
+ if bias:
+ self.bias = nn.Parameter(torch.randn(out_channel))
+
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+
+ def forward(self, input):
+ batch, in_channel, height, width = input.shape
+
+ demod = torch.rsqrt(self.weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ demod = demod.repeat_interleave(batch, 0)
+ weight = self.weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+ weight = weight.view(
+ # batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ batch * self.out_channel,
+ in_channel,
+ *self.kernel_size)
+
+ input = input.view(1, batch * in_channel, height, width)
+ if self.bias is None:
+ out = F.conv2d(input,
+ weight,
+ padding=self.padding,
+ groups=batch,
+ dilation=self.dilation,
+ stride=self.stride)
+ else:
+ out = F.conv2d(input,
+ weight,
+ bias=self.bias,
+ padding=self.padding,
+ groups=batch,
+ dilation=self.dilation,
+ stride=self.stride)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ return out
+
+
+class Flatten(Module):
+ def forward(self, input):
+ return input.reshape(input.size(0), -1)
+
+
+def l2_norm(input, axis=1):
+ norm = torch.norm(input, 2, axis, True)
+ output = torch.div(input, norm)
+ return output
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+ """ A named tuple describing a ResNet block. """
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+ return [Bottleneck(in_channel, depth, stride)
+ ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+ if num_layers == 50:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=14),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 100:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=13),
+ get_block(in_channel=128, depth=256, num_units=30),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 152:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=8),
+ get_block(in_channel=128, depth=256, num_units=36),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ else:
+ raise ValueError(
+ "Invalid number of layers: {}. Must be one of [50, 100, 152]".
+ format(num_layers))
+ return blocks
+
+
+class SEModule(Module):
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2d(1)
+ self.fc1 = Conv2d(channels,
+ channels // reduction,
+ kernel_size=1,
+ padding=0,
+ bias=False)
+ self.relu = ReLU(inplace=True)
+ self.fc2 = Conv2d(channels // reduction,
+ channels,
+ kernel_size=1,
+ padding=0,
+ bias=False)
+ self.sigmoid = Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = self.avg_pool(x)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+ return module_input * x
+
+
+class bottleneck_IR(Module):
+ def __init__(self,
+ in_channel,
+ depth,
+ stride,
+ norm_layer=None,
+ demodulate=False):
+ super(bottleneck_IR, self).__init__()
+ if norm_layer is None:
+ norm_layer = BatchNorm2d
+ if demodulate:
+ conv2d = DemodulatedConv2d
+ else:
+ conv2d = Conv2d
+
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ # Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ norm_layer(depth))
+
+
+# BatchNorm2d(depth)
+ self.res_layer = Sequential(
+ # BatchNorm2d(in_channel),
+ norm_layer(in_channel),
+ # Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ PReLU(depth),
+ # Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ norm_layer(depth))
+ # BatchNorm2d(depth))
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+class bottleneck_IR_SE(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR_SE, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth))
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ BatchNorm2d(depth), SEModule(depth, 16))
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+def _upsample_add(x, y):
+ """Upsample and add two feature maps.
+ Args:
+ x: (Variable) top feature map to be upsampled.
+ y: (Variable) lateral feature map.
+ Returns:
+ (Variable) added feature map.
+ Note in PyTorch, when input size is odd, the upsampled feature map
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
+ maybe not equal to the lateral feature map size.
+ e.g.
+ original input size: [N,_,15,15] ->
+ conv2d feature map size: [N,_,8,8] ->
+ upsampled feature map size: [N,_,16,16]
+ So we choose bilinear upsample which supports arbitrary output sizes.
+ """
+ _, _, H, W = y.size()
+ return F.interpolate(x, size=(H, W), mode='bilinear',
+ align_corners=True) + y
+
+
+# from NeuRay
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ padding_mode='reflect')
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ padding_mode='reflect')
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self,
+ dim_in,
+ dim_out,
+ dim_inter=None,
+ use_norm=True,
+ norm_layer=nn.BatchNorm2d,
+ bias=False):
+ super().__init__()
+ if dim_inter is None:
+ dim_inter = dim_out
+
+ if use_norm:
+ self.conv = nn.Sequential(
+ norm_layer(dim_in),
+ nn.ReLU(True),
+ nn.Conv2d(dim_in,
+ dim_inter,
+ 3,
+ 1,
+ 1,
+ bias=bias,
+ padding_mode='reflect'),
+ norm_layer(dim_inter),
+ nn.ReLU(True),
+ nn.Conv2d(dim_inter,
+ dim_out,
+ 3,
+ 1,
+ 1,
+ bias=bias,
+ padding_mode='reflect'),
+ )
+ else:
+ self.conv = nn.Sequential(
+ nn.ReLU(True),
+ nn.Conv2d(dim_in, dim_inter, 3, 1, 1),
+ nn.ReLU(True),
+ nn.Conv2d(dim_inter, dim_out, 3, 1, 1),
+ )
+
+ self.short_cut = None
+ if dim_in != dim_out:
+ self.short_cut = nn.Conv2d(dim_in, dim_out, 1, 1)
+
+ def forward(self, feats):
+ feats_out = self.conv(feats)
+ if self.short_cut is not None:
+ feats_out = self.short_cut(feats) + feats_out
+ else:
+ feats_out = feats_out + feats
+ return feats_out
+
+
+class conv(nn.Module):
+ def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
+ super(conv, self).__init__()
+ self.kernel_size = kernel_size
+ self.conv = nn.Conv2d(num_in_layers,
+ num_out_layers,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(self.kernel_size - 1) // 2,
+ padding_mode='reflect')
+ self.bn = nn.InstanceNorm2d(num_out_layers,
+ track_running_stats=False,
+ affine=True)
+
+ def forward(self, x):
+ return F.elu(self.bn(self.conv(x)), inplace=True)
+
+
+class upconv(nn.Module):
+ def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
+ super(upconv, self).__init__()
+ self.scale = scale
+ self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1)
+
+ def forward(self, x):
+ x = nn.functional.interpolate(x,
+ scale_factor=self.scale,
+ align_corners=True,
+ mode='bilinear')
+ return self.conv(x)
+
diff --git a/nsr/losses/id_loss.py b/nsr/losses/id_loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..8b0cc5f3421a6ff74048963d9734eae8f6226d7e
--- /dev/null
+++ b/nsr/losses/id_loss.py
@@ -0,0 +1,63 @@
+import torch
+from pdb import set_trace as st
+from torch import nn
+from .model_irse import Backbone
+from .paths_config import model_paths
+
+
+class IDLoss(nn.Module):
+
+ def __init__(self, device):
+ # super(IDLoss, self).__init__()
+ super().__init__()
+ print('Loading ResNet ArcFace')
+ self.facenet = Backbone(input_size=112,
+ num_layers=50,
+ drop_ratio=0.6,
+ mode='ir_se').to(device)
+ # self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
+ try:
+ face_net_model = torch.load(model_paths['ir_se50'],
+ map_location=device)
+ except Exception as e:
+ face_net_model = torch.load(model_paths['ir_se50_hwc'],
+ map_location=device)
+
+ self.facenet.load_state_dict(face_net_model)
+
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+ self.facenet.eval()
+
+ def extract_feats(self, x):
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
+ x = self.face_pool(x)
+ x_feats = self.facenet(x)
+ return x_feats
+
+ def forward(self, y_hat, y, x):
+ n_samples, _, H, W = x.shape
+ assert H == W == 256, 'idloss needs 256*256 input images'
+
+ x_feats = self.extract_feats(x)
+ y_feats = self.extract_feats(y) # Otherwise use the feature from there
+ y_hat_feats = self.extract_feats(y_hat)
+ y_feats = y_feats.detach()
+ loss = 0
+ sim_improvement = 0
+ id_logs = []
+ count = 0
+ for i in range(n_samples):
+ diff_target = y_hat_feats[i].dot(y_feats[i])
+ diff_input = y_hat_feats[i].dot(x_feats[i])
+ diff_views = y_feats[i].dot(x_feats[i])
+ id_logs.append({
+ 'diff_target': float(diff_target),
+ 'diff_input': float(diff_input),
+ 'diff_views': float(diff_views)
+ })
+ loss += 1 - diff_target
+ id_diff = float(diff_target) - float(diff_views)
+ sim_improvement += id_diff
+ count += 1
+
+ return loss / count, sim_improvement / count, id_logs
diff --git a/nsr/losses/lms.py b/nsr/losses/lms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8e71f89d6e429f4e81f11675042beb09c4009ff
--- /dev/null
+++ b/nsr/losses/lms.py
@@ -0,0 +1,94 @@
+# ------------------------------------------------------------------------------
+# https://github.dev/HRNet/HigherHRNet-Human-Pose-Estimation
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (leoxiaobin@gmail.com)
+# Modified by Bowen Cheng (bcheng9@illinois.edu)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import logging
+
+import torch
+import torch.nn as nn
+from pdb import set_trace as st
+
+logger = logging.getLogger(__name__)
+
+
+class HeatmapGenerator():
+ def __init__(self, heatmap_size, num_joints=68, sigma=2):
+ self.heatmap_size = heatmap_size
+ # self.image_size = image_size
+ self.num_joints = num_joints
+ if sigma < 0:
+ sigma = self.heatmap_size / 64
+ self.sigma = sigma
+ size = 6 * sigma + 3
+ x = np.arange(0, size, 1, float)
+ y = x[:, np.newaxis]
+ x0, y0 = 3 * sigma + 1, 3 * sigma + 1
+ self.g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
+
+ # def __call__(self, joints, image_size: np.ndarray):
+ def __call__(self, joints, image_size: int):
+ """generate heatmap gt from joints
+
+ Args:
+ joints (np.ndarray): N,3
+
+ Returns:
+ hms: N,H,W
+ """
+ hms = np.zeros((self.num_joints, self.heatmap_size, self.heatmap_size),
+ dtype=np.float32)
+ sigma = self.sigma
+
+ # feat_stride = image_size / [self.heatmap_size, self.heatmap_size]
+ feat_stride = image_size / self.heatmap_size
+ for idx, pt in enumerate(joints):
+ # for idx, pt in enumerate(p):
+ if pt[2] > 0:
+ # x = int(pt[0] / feat_stride[0] + 0.5)
+ # y = int(pt[1] / feat_stride[1] + 0.5) # normalize joints to heatmap size
+ x = int(pt[0] / feat_stride + 0.5)
+ y = int(pt[1] / feat_stride +
+ 0.5) # normalize joints to heatmap size
+ if x < 0 or y < 0 or \
+ x >= self.heatmap_size or y >= self.heatmap_size:
+ continue
+
+ ul = int(np.round(x - 3 * sigma - 1)), int(
+ np.round(y - 3 * sigma - 1))
+ br = int(np.round(x + 3 * sigma + 2)), int(
+ np.round(y + 3 * sigma + 2))
+
+ c, d = max(0, -ul[0]), min(br[0], self.heatmap_size) - ul[0]
+ a, b = max(0, -ul[1]), min(br[1], self.heatmap_size) - ul[1]
+
+ cc, dd = max(0, ul[0]), min(br[0], self.heatmap_size)
+ aa, bb = max(0, ul[1]), min(br[1], self.heatmap_size)
+ hms[idx, aa:bb, cc:dd] = np.maximum(hms[idx, aa:bb, cc:dd],
+ self.g[a:b, c:d])
+ return hms
+
+
+class HeatmapLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, pred, gt, mask=None):
+ # todo, add mask
+ assert pred.size() == gt.size()
+ loss = ((pred - gt)**2)
+ if mask is not None:
+ loss = loss * mask[:, None, :, :].expand_as(pred)
+ # loss = loss.mean(dim=3).mean(dim=2).mean(dim=1)
+ loss = loss.mean()
+ # loss = loss.mean(dim=3).mean(dim=2).sum(dim=1)
+ return loss
diff --git a/nsr/losses/lpips/__pycache__/__init__.cpython-39.pyc b/nsr/losses/lpips/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f9192e56b92eaed484b15aa4143550e09543f2b
Binary files /dev/null and b/nsr/losses/lpips/__pycache__/__init__.cpython-39.pyc differ
diff --git a/nsr/losses/lpips/__pycache__/lpips.cpython-39.pyc b/nsr/losses/lpips/__pycache__/lpips.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffc80d0c88061907fde00c55eaaace7a7a109186
Binary files /dev/null and b/nsr/losses/lpips/__pycache__/lpips.cpython-39.pyc differ
diff --git a/nsr/losses/lpips/__pycache__/networks.cpython-39.pyc b/nsr/losses/lpips/__pycache__/networks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f6f39b4a1f2d2d42f77e8e0b94ce911e79b7306
Binary files /dev/null and b/nsr/losses/lpips/__pycache__/networks.cpython-39.pyc differ
diff --git a/nsr/losses/lpips/__pycache__/utils.cpython-39.pyc b/nsr/losses/lpips/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..779af69091a64dc52212caa4f39a98283805f13a
Binary files /dev/null and b/nsr/losses/lpips/__pycache__/utils.cpython-39.pyc differ
diff --git a/nsr/losses/model_irse.py b/nsr/losses/model_irse.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c5a0f6d927eb8523a6df7589bd5e05c936b9669
--- /dev/null
+++ b/nsr/losses/model_irse.py
@@ -0,0 +1,110 @@
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+"""
+Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Backbone(Module):
+ def __init__(self,
+ input_size,
+ num_layers,
+ mode='ir',
+ drop_ratio=0.4,
+ affine=True):
+ super(Backbone, self).__init__()
+ assert input_size in [112, 224], "input_size should be 112 or 224"
+ assert num_layers in [50, 100,
+ 152], "num_layers should be 50, 100 or 152"
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64), PReLU(64))
+ if input_size == 112:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio), Flatten(),
+ Linear(512 * 7 * 7, 512),
+ BatchNorm1d(512, affine=affine))
+ else:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio), Flatten(),
+ Linear(512 * 14 * 14, 512),
+ BatchNorm1d(512, affine=affine))
+
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(
+ unit_module(bottleneck.in_channel, bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+ x = self.body(x)
+ x = self.output_layer(x)
+ return l2_norm(x)
+
+
+def IR_50(input_size):
+ """Constructs a ir-50 model."""
+ model = Backbone(input_size,
+ num_layers=50,
+ mode='ir',
+ drop_ratio=0.4,
+ affine=False)
+ return model
+
+
+def IR_101(input_size):
+ """Constructs a ir-101 model."""
+ model = Backbone(input_size,
+ num_layers=100,
+ mode='ir',
+ drop_ratio=0.4,
+ affine=False)
+ return model
+
+
+def IR_152(input_size):
+ """Constructs a ir-152 model."""
+ model = Backbone(input_size,
+ num_layers=152,
+ mode='ir',
+ drop_ratio=0.4,
+ affine=False)
+ return model
+
+
+def IR_SE_50(input_size):
+ """Constructs a ir_se-50 model."""
+ model = Backbone(input_size,
+ num_layers=50,
+ mode='ir_se',
+ drop_ratio=0.4,
+ affine=False)
+ return model
+
+
+def IR_SE_101(input_size):
+ """Constructs a ir_se-101 model."""
+ model = Backbone(input_size,
+ num_layers=100,
+ mode='ir_se',
+ drop_ratio=0.4,
+ affine=False)
+ return model
+
+
+def IR_SE_152(input_size):
+ """Constructs a ir_se-152 model."""
+ model = Backbone(input_size,
+ num_layers=152,
+ mode='ir_se',
+ drop_ratio=0.4,
+ affine=False)
+ return model
diff --git a/nsr/losses/paths_config.py b/nsr/losses/paths_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..efd4e8094ee6bbf1acb0c82c22528c31e4b5480e
--- /dev/null
+++ b/nsr/losses/paths_config.py
@@ -0,0 +1,24 @@
+model_paths = {
+ 'ir_se50': 'pretrained_models/model_ir_se50.pth',
+ 'resnet34': 'pretrained_models/resnet34-333f7ec4.pth',
+ 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
+ 'stylegan_cars': 'pretrained_models/stylegan2-car-config-f.pt',
+ 'stylegan_church': 'pretrained_models/stylegan2-church-config-f.pt',
+ 'stylegan_horse': 'pretrained_models/stylegan2-horse-config-f.pt',
+ 'stylegan_ada_wild': 'pretrained_models/afhqwild.pt',
+ 'stylegan_toonify': 'pretrained_models/ffhq_cartoon_blended.pt',
+ 'shape_predictor':
+ 'pretrained_models/shape_predictor_68_face_landmarks.dat',
+ 'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
+ 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
+ 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
+ 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
+ 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pt'
+}
+
+project_basedir = '/mnt/lustre/yslan/Repo/Research/SIGA22/BaseModels/StyleSDF'
+
+for k, v in model_paths.items():
+ model_paths[k] = f'{project_basedir}/project/utils/misc/' + model_paths[k]
+
+model_paths['ir_se50_hwc'] = '/home/yslan/datasets/model_ir_se50.pth'
diff --git a/nsr/losses/sdfstudio_losses.py b/nsr/losses/sdfstudio_losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e560b07c8baec71febb8b4235d09e2cb68b2d0c
--- /dev/null
+++ b/nsr/losses/sdfstudio_losses.py
@@ -0,0 +1,771 @@
+# Copyright 2022 The Nerfstudio Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Collection of Losses.
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchtyping import TensorType
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+# from nerfstudio.cameras.rays import RaySamples
+# from nerfstudio.field_components.field_heads import FieldHeadNames
+
+L1Loss = nn.L1Loss
+MSELoss = nn.MSELoss
+
+LOSSES = {"L1": L1Loss, "MSE": MSELoss}
+
+EPS = 1.0e-7
+
+
+def outer(
+ t0_starts: TensorType[..., "num_samples_0"],
+ t0_ends: TensorType[..., "num_samples_0"],
+ t1_starts: TensorType[..., "num_samples_1"],
+ t1_ends: TensorType[..., "num_samples_1"],
+ y1: TensorType[..., "num_samples_1"],
+) -> TensorType[..., "num_samples_0"]:
+ """Faster version of
+
+ https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L117
+ https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L64
+
+ Args:
+ t0_starts: start of the interval edges
+ t0_ends: end of the interval edges
+ t1_starts: start of the interval edges
+ t1_ends: end of the interval edges
+ y1: weights
+ """
+ cy1 = torch.cat([torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1)
+
+ idx_lo = torch.searchsorted(t1_starts.contiguous(), t0_starts.contiguous(), side="right") - 1
+ idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1)
+ idx_hi = torch.searchsorted(t1_ends.contiguous(), t0_ends.contiguous(), side="right")
+ idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1)
+ cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1)
+ cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1)
+ y0_outer = cy1_hi - cy1_lo
+
+ return y0_outer
+
+
+def lossfun_outer(
+ t: TensorType[..., "num_samples+1"],
+ w: TensorType[..., "num_samples"],
+ t_env: TensorType[..., "num_samples+1"],
+ w_env: TensorType[..., "num_samples"],
+):
+ """
+ https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L136
+ https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L80
+
+ Args:
+ t: interval edges
+ w: weights
+ t_env: interval edges of the upper bound enveloping historgram
+ w_env: weights that should upper bound the inner (t,w) histogram
+ """
+ w_outer = outer(t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env)
+ return torch.clip(w - w_outer, min=0) ** 2 / (w + EPS)
+
+
+def ray_samples_to_sdist(ray_samples):
+ """Convert ray samples to s space"""
+ starts = ray_samples.spacing_starts
+ ends = ray_samples.spacing_ends
+ sdist = torch.cat([starts[..., 0], ends[..., -1:, 0]], dim=-1) # (num_rays, num_samples + 1)
+ return sdist
+
+
+def interlevel_loss(weights_list, ray_samples_list):
+ """Calculates the proposal loss in the MipNeRF-360 paper.
+
+ https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/model.py#L515
+ https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/train_utils.py#L133
+ """
+ c = ray_samples_to_sdist(ray_samples_list[-1]).detach()
+ w = weights_list[-1][..., 0].detach()
+ loss_interlevel = 0.0
+ for ray_samples, weights in zip(ray_samples_list[:-1], weights_list[:-1]):
+ sdist = ray_samples_to_sdist(ray_samples)
+ cp = sdist # (num_rays, num_samples + 1)
+ wp = weights[..., 0] # (num_rays, num_samples)
+ loss_interlevel += torch.mean(lossfun_outer(c, w, cp, wp))
+ return loss_interlevel
+
+
+## zip-NeRF losses
+def blur_stepfun(x, y, r):
+ x_c = torch.cat([x - r, x + r], dim=-1)
+ x_r, x_idx = torch.sort(x_c, dim=-1)
+ zeros = torch.zeros_like(y[:, :1])
+ y_1 = (torch.cat([y, zeros], dim=-1) - torch.cat([zeros, y], dim=-1)) / (2 * r)
+ x_idx = x_idx[:, :-1]
+ y_2 = torch.cat([y_1, -y_1], dim=-1)[
+ torch.arange(x_idx.shape[0]).reshape(-1, 1).expand(x_idx.shape).to(x_idx.device), x_idx
+ ]
+
+ y_r = torch.cumsum((x_r[:, 1:] - x_r[:, :-1]) * torch.cumsum(y_2, dim=-1), dim=-1)
+ y_r = torch.cat([zeros, y_r], dim=-1)
+ return x_r, y_r
+
+
+def interlevel_loss_zip(weights_list, ray_samples_list):
+ """Calculates the proposal loss in the Zip-NeRF paper."""
+ c = ray_samples_to_sdist(ray_samples_list[-1]).detach()
+ w = weights_list[-1][..., 0].detach()
+
+ # 1. normalize
+ w_normalize = w / (c[:, 1:] - c[:, :-1])
+
+ loss_interlevel = 0.0
+ for ray_samples, weights, r in zip(ray_samples_list[:-1], weights_list[:-1], [0.03, 0.003]):
+ # 2. step blur with different r
+ x_r, y_r = blur_stepfun(c, w_normalize, r)
+ y_r = torch.clip(y_r, min=0)
+ assert (y_r >= 0.0).all()
+
+ # 3. accumulate
+ y_cum = torch.cumsum((y_r[:, 1:] + y_r[:, :-1]) * 0.5 * (x_r[:, 1:] - x_r[:, :-1]), dim=-1)
+ y_cum = torch.cat([torch.zeros_like(y_cum[:, :1]), y_cum], dim=-1)
+
+ # 4 loss
+ sdist = ray_samples_to_sdist(ray_samples)
+ cp = sdist # (num_rays, num_samples + 1)
+ wp = weights[..., 0] # (num_rays, num_samples)
+
+ # resample
+ inds = torch.searchsorted(x_r, cp, side="right")
+ below = torch.clamp(inds - 1, 0, x_r.shape[-1] - 1)
+ above = torch.clamp(inds, 0, x_r.shape[-1] - 1)
+ cdf_g0 = torch.gather(x_r, -1, below)
+ bins_g0 = torch.gather(y_cum, -1, below)
+ cdf_g1 = torch.gather(x_r, -1, above)
+ bins_g1 = torch.gather(y_cum, -1, above)
+
+ t = torch.clip(torch.nan_to_num((cp - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
+ bins = bins_g0 + t * (bins_g1 - bins_g0)
+
+ w_gt = bins[:, 1:] - bins[:, :-1]
+
+ # TODO here might be unstable when wp is very small
+ loss_interlevel += torch.mean(torch.clip(w_gt - wp, min=0) ** 2 / (wp + 1e-5))
+
+ return loss_interlevel
+
+
+# Verified
+def lossfun_distortion(t, w):
+ """
+ https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L142
+ https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L266
+ """
+ ut = (t[..., 1:] + t[..., :-1]) / 2
+ dut = torch.abs(ut[..., :, None] - ut[..., None, :])
+ loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1)
+
+ loss_intra = torch.sum(w**2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3
+
+ return loss_inter + loss_intra
+
+
+def distortion_loss(weights_list, ray_samples_list):
+ """From mipnerf360"""
+ c = ray_samples_to_sdist(ray_samples_list[-1])
+ w = weights_list[-1][..., 0]
+ loss = torch.mean(lossfun_distortion(c, w))
+ return loss
+
+
+# def nerfstudio_distortion_loss(
+# ray_samples: RaySamples,
+# densities: TensorType["bs":..., "num_samples", 1] = None,
+# weights: TensorType["bs":..., "num_samples", 1] = None,
+# ) -> TensorType["bs":..., 1]:
+# """Ray based distortion loss proposed in MipNeRF-360. Returns distortion Loss.
+
+# .. math::
+
+# \\mathcal{L}(\\mathbf{s}, \\mathbf{w}) =\\iint\\limits_{-\\infty}^{\\,\\,\\,\\infty}
+# \\mathbf{w}_\\mathbf{s}(u)\\mathbf{w}_\\mathbf{s}(v)|u - v|\\,d_{u}\\,d_{v}
+
+# where :math:`\\mathbf{w}_\\mathbf{s}(u)=\\sum_i w_i \\mathbb{1}_{[\\mathbf{s}_i, \\mathbf{s}_{i+1})}(u)`
+# is the weight at location :math:`u` between bin locations :math:`s_i` and :math:`s_{i+1}`.
+
+# Args:
+# ray_samples: Ray samples to compute loss over
+# densities: Predicted sample densities
+# weights: Predicted weights from densities and sample locations
+# """
+# if torch.is_tensor(densities):
+# assert not torch.is_tensor(weights), "Cannot use both densities and weights"
+# # Compute the weight at each sample location
+# weights = ray_samples.get_weights(densities)
+# if torch.is_tensor(weights):
+# assert not torch.is_tensor(densities), "Cannot use both densities and weights"
+
+# starts = ray_samples.spacing_starts
+# ends = ray_samples.spacing_ends
+
+# assert starts is not None and ends is not None, "Ray samples must have spacing starts and ends"
+# midpoints = (starts + ends) / 2.0 # (..., num_samples, 1)
+
+# loss = (
+# weights * weights[..., None, :, 0] * torch.abs(midpoints - midpoints[..., None, :, 0])
+# ) # (..., num_samples, num_samples)
+# loss = torch.sum(loss, dim=(-1, -2))[..., None] # (..., num_samples)
+# loss = loss + 1 / 3.0 * torch.sum(weights**2 * (ends - starts), dim=-2)
+
+# return loss
+
+
+def orientation_loss(
+ weights: TensorType["bs":..., "num_samples", 1],
+ normals: TensorType["bs":..., "num_samples", 3],
+ viewdirs: TensorType["bs":..., 3],
+):
+ """Orientation loss proposed in Ref-NeRF.
+ Loss that encourages that all visible normals are facing towards the camera.
+ """
+ w = weights
+ n = normals
+ v = viewdirs
+ n_dot_v = (n * v[..., None, :]).sum(axis=-1)
+ return (w[..., 0] * torch.fmin(torch.zeros_like(n_dot_v), n_dot_v) ** 2).sum(dim=-1)
+
+
+def pred_normal_loss(
+ weights: TensorType["bs":..., "num_samples", 1],
+ normals: TensorType["bs":..., "num_samples", 3],
+ pred_normals: TensorType["bs":..., "num_samples", 3],
+):
+ """Loss between normals calculated from density and normals from prediction network."""
+ return (weights[..., 0] * (1.0 - torch.sum(normals * pred_normals, dim=-1))).sum(dim=-1)
+
+
+def monosdf_normal_loss(normal_pred: torch.Tensor, normal_gt: torch.Tensor):
+ """normal consistency loss as monosdf
+
+ Args:
+ normal_pred (torch.Tensor): volume rendered normal
+ normal_gt (torch.Tensor): monocular normal
+ """
+ normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1)
+ normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1)
+ l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean()
+ cos = (1.0 - torch.sum(normal_pred * normal_gt, dim=-1)).mean()
+ return l1 + cos
+
+
+# copy from MiDaS
+def compute_scale_and_shift(prediction, target, mask):
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
+ a_00 = torch.sum(mask * prediction * prediction, (1, 2))
+ a_01 = torch.sum(mask * prediction, (1, 2))
+ a_11 = torch.sum(mask, (1, 2))
+
+ # right hand side: b = [b_0, b_1]
+ b_0 = torch.sum(mask * prediction * target, (1, 2))
+ b_1 = torch.sum(mask * target, (1, 2))
+
+ # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
+ x_0 = torch.zeros_like(b_0)
+ x_1 = torch.zeros_like(b_1)
+
+ det = a_00 * a_11 - a_01 * a_01
+ valid = det.nonzero()
+
+ x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
+ x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
+
+ return x_0, x_1
+
+
+def reduction_batch_based(image_loss, M):
+ # average of all valid pixels of the batch
+
+ # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
+ divisor = torch.sum(M)
+
+ if divisor == 0:
+ return 0
+ else:
+ return torch.sum(image_loss) / divisor
+
+
+def reduction_image_based(image_loss, M):
+ # mean of average of valid pixels of an image
+
+ # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
+ valid = M.nonzero()
+
+ image_loss[valid] = image_loss[valid] / M[valid]
+
+ return torch.mean(image_loss)
+
+
+def mse_loss(prediction, target, mask, reduction=reduction_batch_based):
+ M = torch.sum(mask, (1, 2))
+ res = prediction - target
+ image_loss = torch.sum(mask * res * res, (1, 2))
+
+ return reduction(image_loss, 2 * M)
+
+
+def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
+ M = torch.sum(mask, (1, 2))
+
+ diff = prediction - target
+ diff = torch.mul(mask, diff)
+
+ grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
+ mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
+ grad_x = torch.mul(mask_x, grad_x)
+
+ grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
+ mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
+ grad_y = torch.mul(mask_y, grad_y)
+
+ image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
+
+ return reduction(image_loss, M)
+
+
+class MiDaSMSELoss(nn.Module):
+ def __init__(self, reduction="batch-based"):
+ super().__init__()
+
+ if reduction == "batch-based":
+ self.__reduction = reduction_batch_based
+ else:
+ self.__reduction = reduction_image_based
+
+ def forward(self, prediction, target, mask):
+ return mse_loss(prediction, target, mask, reduction=self.__reduction)
+
+
+class GradientLoss(nn.Module):
+ def __init__(self, scales=4, reduction="batch-based"):
+ super().__init__()
+
+ if reduction == "batch-based":
+ self.__reduction = reduction_batch_based
+ else:
+ self.__reduction = reduction_image_based
+
+ self.__scales = scales
+
+ def forward(self, prediction, target, mask):
+ total = 0
+
+ for scale in range(self.__scales):
+ step = pow(2, scale)
+
+ total += gradient_loss(
+ prediction[:, ::step, ::step],
+ target[:, ::step, ::step],
+ mask[:, ::step, ::step],
+ reduction=self.__reduction,
+ )
+
+ return total
+
+
+class ScaleAndShiftInvariantLoss(nn.Module):
+ def __init__(self, alpha=0.5, scales=4, reduction="batch-based"):
+ super().__init__()
+
+ self.__data_loss = MiDaSMSELoss(reduction=reduction)
+ self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
+ self.__alpha = alpha
+
+ self.__prediction_ssi = None
+
+ def forward(self, prediction, target, mask):
+ scale, shift = compute_scale_and_shift(prediction, target, mask)
+ self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
+
+ total = self.__data_loss(self.__prediction_ssi, target, mask)
+ if self.__alpha > 0:
+ total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)
+
+ return total
+
+ def __get_prediction_ssi(self):
+ return self.__prediction_ssi
+
+ prediction_ssi = property(__get_prediction_ssi)
+
+
+# end copy
+
+
+# copy from https://github.com/svip-lab/Indoor-SfMLearner/blob/0d682b7ce292484e5e3e2161fc9fc07e2f5ca8d1/layers.py#L218
+class SSIM(nn.Module):
+ """Layer to compute the SSIM loss between a pair of images"""
+
+ def __init__(self, patch_size):
+ super(SSIM, self).__init__()
+ self.mu_x_pool = nn.AvgPool2d(patch_size, 1)
+ self.mu_y_pool = nn.AvgPool2d(patch_size, 1)
+ self.sig_x_pool = nn.AvgPool2d(patch_size, 1)
+ self.sig_y_pool = nn.AvgPool2d(patch_size, 1)
+ self.sig_xy_pool = nn.AvgPool2d(patch_size, 1)
+
+ self.refl = nn.ReflectionPad2d(patch_size // 2)
+
+ self.C1 = 0.01**2
+ self.C2 = 0.03**2
+
+ def forward(self, x, y):
+ x = self.refl(x)
+ y = self.refl(y)
+
+ mu_x = self.mu_x_pool(x)
+ mu_y = self.mu_y_pool(y)
+
+ sigma_x = self.sig_x_pool(x**2) - mu_x**2
+ sigma_y = self.sig_y_pool(y**2) - mu_y**2
+ sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
+
+ SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
+ SSIM_d = (mu_x**2 + mu_y**2 + self.C1) * (sigma_x + sigma_y + self.C2)
+
+ return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
+
+
+# TODO test different losses
+class NCC(nn.Module):
+ """Layer to compute the normalization cross correlation (NCC) of patches"""
+
+ def __init__(self, patch_size: int = 11, min_patch_variance: float = 0.01):
+ super(NCC, self).__init__()
+ self.patch_size = patch_size
+ self.min_patch_variance = min_patch_variance
+
+ def forward(self, x, y):
+ # TODO if we use gray image we should do it right after loading the image to save computations
+ # to gray image
+ x = torch.mean(x, dim=1)
+ y = torch.mean(y, dim=1)
+
+ x_mean = torch.mean(x, dim=(1, 2), keepdim=True)
+ y_mean = torch.mean(y, dim=(1, 2), keepdim=True)
+
+ x_normalized = x - x_mean
+ y_normalized = y - y_mean
+
+ norm = torch.sum(x_normalized * y_normalized, dim=(1, 2))
+ var = torch.square(x_normalized).sum(dim=(1, 2)) * torch.square(y_normalized).sum(dim=(1, 2))
+ denom = torch.sqrt(var + 1e-6)
+
+ ncc = norm / (denom + 1e-6)
+
+ # ignore pathces with low variances
+ not_valid = (torch.square(x_normalized).sum(dim=(1, 2)) < self.min_patch_variance) | (
+ torch.square(y_normalized).sum(dim=(1, 2)) < self.min_patch_variance
+ )
+ ncc[not_valid] = 1.0
+
+ score = 1 - ncc.clip(-1.0, 1.0) # 0->2: smaller, better
+ return score[:, None, None, None]
+
+
+class MultiViewLoss(nn.Module):
+ """compute multi-view consistency loss"""
+
+ def __init__(self, patch_size: int = 11, topk: int = 4, min_patch_variance: float = 0.01):
+ super(MultiViewLoss, self).__init__()
+ self.patch_size = patch_size
+ self.topk = topk
+ self.min_patch_variance = min_patch_variance
+ # TODO make metric configurable
+ # self.ssim = SSIM(patch_size=patch_size)
+ # self.ncc = NCC(patch_size=patch_size)
+ self.ssim = NCC(patch_size=patch_size, min_patch_variance=min_patch_variance)
+
+ self.iter = 0
+
+ def forward(self, patches: torch.Tensor, valid: torch.Tensor):
+ """take the mim
+
+ Args:
+ patches (torch.Tensor): _description_
+ valid (torch.Tensor): _description_
+
+ Returns:
+ _type_: _description_
+ """
+ num_imgs, num_rays, _, num_channels = patches.shape
+
+ if num_rays <= 0:
+ return torch.tensor(0.0).to(patches.device)
+
+ ref_patches = (
+ patches[:1, ...]
+ .reshape(1, num_rays, self.patch_size, self.patch_size, num_channels)
+ .expand(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels)
+ .reshape(-1, self.patch_size, self.patch_size, num_channels)
+ .permute(0, 3, 1, 2)
+ ) # [N_src*N_rays, 3, patch_size, patch_size]
+ src_patches = (
+ patches[1:, ...]
+ .reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels)
+ .reshape(-1, self.patch_size, self.patch_size, num_channels)
+ .permute(0, 3, 1, 2)
+ ) # [N_src*N_rays, 3, patch_size, patch_size]
+
+ # apply same reshape to the valid mask
+ src_patches_valid = (
+ valid[1:, ...]
+ .reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, 1)
+ .reshape(-1, self.patch_size, self.patch_size, 1)
+ .permute(0, 3, 1, 2)
+ ) # [N_src*N_rays, 1, patch_size, patch_size]
+
+ ssim = self.ssim(ref_patches.detach(), src_patches)
+ ssim = torch.mean(ssim, dim=(1, 2, 3))
+ ssim = ssim.reshape(num_imgs - 1, num_rays)
+
+ # ignore invalid patch by setting ssim error to very large value
+ ssim_valid = (
+ src_patches_valid.reshape(-1, self.patch_size * self.patch_size).all(dim=-1).reshape(num_imgs - 1, num_rays)
+ )
+ # we should mask the error after we select the topk value, otherwise we might select far way patches that happens to be inside the image
+ # ssim[torch.logical_not(ssim_valid)] = 1.1 # max ssim_error is 1
+
+ min_ssim, idx = torch.topk(ssim, k=self.topk, largest=False, dim=0, sorted=True)
+
+ min_ssim_valid = ssim_valid[idx, torch.arange(num_rays)[None].expand_as(idx)]
+ # TODO how to set this value for better visualization
+ min_ssim[torch.logical_not(min_ssim_valid)] = 0.0 # max ssim_error is 1
+
+ if False:
+ # visualization of topK error computations
+
+ import cv2
+ import numpy as np
+
+ vis_patch_num = num_rays
+ K = min(100, vis_patch_num)
+
+ image = (
+ patches[:, :vis_patch_num, :, :]
+ .reshape(-1, vis_patch_num, self.patch_size, self.patch_size, 3)
+ .permute(1, 2, 0, 3, 4)
+ .reshape(vis_patch_num * self.patch_size, -1, 3)
+ )
+
+ src_patches_reshaped = src_patches.reshape(
+ num_imgs - 1, num_rays, 3, self.patch_size, self.patch_size
+ ).permute(1, 0, 3, 4, 2)
+ idx = idx.permute(1, 0)
+
+ selected_patch = (
+ src_patches_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx]
+ .permute(0, 2, 1, 3, 4)
+ .reshape(num_rays, self.patch_size, self.topk * self.patch_size, 3)[:vis_patch_num]
+ .reshape(-1, self.topk * self.patch_size, 3)
+ )
+
+ # apply same reshape to the valid mask
+ src_patches_valid_reshaped = src_patches_valid.reshape(
+ num_imgs - 1, num_rays, 1, self.patch_size, self.patch_size
+ ).permute(1, 0, 3, 4, 2)
+
+ selected_patch_valid = (
+ src_patches_valid_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx]
+ .permute(0, 2, 1, 3, 4)
+ .reshape(num_rays, self.patch_size, self.topk * self.patch_size, 1)[:vis_patch_num]
+ .reshape(-1, self.topk * self.patch_size, 1)
+ )
+ # valid to image
+ selected_patch_valid = selected_patch_valid.expand_as(selected_patch).float()
+ # breakpoint()
+
+ image = torch.cat([selected_patch_valid, selected_patch, image], dim=1)
+ # select top rays with highest errors
+
+ image = image.reshape(num_rays, self.patch_size, -1, 3)
+
+ _, idx2 = torch.topk(
+ torch.sum(min_ssim, dim=0) / (min_ssim_valid.float().sum(dim=0) + 1e-6),
+ k=K,
+ largest=True,
+ dim=0,
+ sorted=True,
+ )
+
+ image = image[idx2].reshape(K * self.patch_size, -1, 3)
+
+ cv2.imwrite(f"vis/{self.iter}.png", (image.detach().cpu().numpy() * 255).astype(np.uint8)[..., ::-1])
+ self.iter += 1
+ if self.iter == 9:
+ breakpoint()
+
+ return torch.sum(min_ssim) / (min_ssim_valid.float().sum() + 1e-6)
+
+
+# sensor depth loss, adapted from https://github.com/dazinovic/neural-rgbd-surface-reconstruction/blob/main/losses.py
+# class SensorDepthLoss(nn.Module):
+# """Sensor Depth loss"""
+
+# def __init__(self, truncation: float):
+# super(SensorDepthLoss, self).__init__()
+# self.truncation = truncation # 0.05 * 0.3 5cm scaled
+
+# def forward(self, batch, outputs):
+# """take the mim
+
+# Args:
+# batch (Dict): inputs
+# outputs (Dict): outputs data from surface model
+
+# Returns:
+# l1_loss: l1 loss
+# freespace_loss: free space loss
+# sdf_loss: sdf loss
+# """
+# depth_pred = outputs["depth"]
+# depth_gt = batch["sensor_depth"].to(depth_pred.device)[..., None]
+# valid_gt_mask = depth_gt > 0.0
+
+# l1_loss = torch.sum(valid_gt_mask * torch.abs(depth_gt - depth_pred)) / (valid_gt_mask.sum() + 1e-6)
+
+# # free space loss and sdf loss
+# ray_samples = outputs["ray_samples"]
+# filed_outputs = outputs["field_outputs"]
+# pred_sdf = filed_outputs[FieldHeadNames.SDF][..., 0]
+# directions_norm = outputs["directions_norm"]
+
+# z_vals = ray_samples.frustums.starts[..., 0] / directions_norm
+
+# truncation = self.truncation
+# front_mask = valid_gt_mask & (z_vals < (depth_gt - truncation))
+# back_mask = valid_gt_mask & (z_vals > (depth_gt + truncation))
+# sdf_mask = valid_gt_mask & (~front_mask) & (~back_mask)
+
+# num_fs_samples = front_mask.sum()
+# num_sdf_samples = sdf_mask.sum()
+# num_samples = num_fs_samples + num_sdf_samples + 1e-6
+# fs_weight = 1.0 - num_fs_samples / num_samples
+# sdf_weight = 1.0 - num_sdf_samples / num_samples
+
+# free_space_loss = torch.mean((F.relu(truncation - pred_sdf) * front_mask) ** 2) * fs_weight
+
+# sdf_loss = torch.mean(((z_vals + pred_sdf) - depth_gt) ** 2 * sdf_mask) * sdf_weight
+
+# return l1_loss, free_space_loss, sdf_loss
+
+r"""Implements Stochastic Structural SIMilarity(S3IM) algorithm.
+It is proposed in the ICCV2023 paper
+`S3IM: Stochastic Structural SIMilarity and Its Unreasonable Effectiveness for Neural Fields`.
+
+Arguments:
+ s3im_kernel_size (int): kernel size in ssim's convolution(default: 4)
+ s3im_stride (int): stride in ssim's convolution(default: 4)
+ s3im_repeat_time (int): repeat time in re-shuffle virtual patch(default: 10)
+ s3im_patch_height (height): height of virtual patch(default: 64)
+"""
+
+class S3IM(torch.nn.Module):
+ def __init__(self, s3im_kernel_size = 4, s3im_stride=4, s3im_repeat_time=10, s3im_patch_height=64, size_average = True):
+ super(S3IM, self).__init__()
+ self.s3im_kernel_size = s3im_kernel_size
+ self.s3im_stride = s3im_stride
+ self.s3im_repeat_time = s3im_repeat_time
+ self.s3im_patch_height = s3im_patch_height
+ self.size_average = size_average
+ self.channel = 1
+ self.s3im_kernel = self.create_kernel(s3im_kernel_size, self.channel)
+
+
+ def gaussian(self, s3im_kernel_size, sigma):
+ gauss = torch.Tensor([exp(-(x - s3im_kernel_size//2)**2/float(2*sigma**2)) for x in range(s3im_kernel_size)])
+ return gauss/gauss.sum()
+
+ def create_kernel(self, s3im_kernel_size, channel):
+ _1D_window = self.gaussian(s3im_kernel_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ s3im_kernel = Variable(_2D_window.expand(channel, 1, s3im_kernel_size, s3im_kernel_size).contiguous())
+ return s3im_kernel
+
+ def _ssim(self, img1, img2, s3im_kernel, s3im_kernel_size, channel, size_average = True, s3im_stride=None):
+ mu1 = F.conv2d(img1, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride)
+ mu2 = F.conv2d(img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1*mu2
+
+ sigma1_sq = F.conv2d(img1*img1, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu1_sq
+ sigma2_sq = F.conv2d(img2*img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu2_sq
+ sigma12 = F.conv2d(img1*img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu1_mu2
+
+ C1 = 0.01**2
+ C2 = 0.03**2
+
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+ def ssim_loss(self, img1, img2):
+ """
+ img1, img2: torch.Tensor([b,c,h,w])
+ """
+ (_, channel, _, _) = img1.size()
+
+ if channel == self.channel and self.s3im_kernel.data.type() == img1.data.type():
+ s3im_kernel = self.s3im_kernel
+ else:
+ s3im_kernel = self.create_kernel(self.s3im_kernel_size, channel)
+
+ if img1.is_cuda:
+ s3im_kernel = s3im_kernel.cuda(img1.get_device())
+ s3im_kernel = s3im_kernel.type_as(img1)
+
+ self.s3im_kernel = s3im_kernel
+ self.channel = channel
+
+
+ return self._ssim(img1, img2, s3im_kernel, self.s3im_kernel_size, channel, self.size_average, s3im_stride=self.s3im_stride)
+
+ def forward(self, src_vec, tar_vec):
+ loss = 0.0
+ index_list = []
+ for i in range(self.s3im_repeat_time):
+ if i == 0:
+ tmp_index = torch.arange(len(tar_vec))
+ index_list.append(tmp_index)
+ else:
+ ran_idx = torch.randperm(len(tar_vec))
+ index_list.append(ran_idx)
+ res_index = torch.cat(index_list)
+ tar_all = tar_vec[res_index]
+ src_all = src_vec[res_index]
+ tar_patch = tar_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1)
+ src_patch = src_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1)
+ loss = (1 - self.ssim_loss(src_patch, tar_patch))
+ return loss
+
diff --git a/nsr/losses/vqperceptual.py b/nsr/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..6195f0a6ed7ee6fd32c1bccea071e6075e95ee43
--- /dev/null
+++ b/nsr/losses/vqperceptual.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn.functional as F
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real))
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
+ )
+ return d_loss
diff --git a/nsr/lsgm/__init__.py b/nsr/lsgm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dffd407f972e1a806773fe468915f69106e458d
--- /dev/null
+++ b/nsr/lsgm/__init__.py
@@ -0,0 +1,5 @@
+# sde diffusion
+from .train_util_diffusion_lsgm import TrainLoop3DDiffusionLSGM
+from .train_util_diffusion_vpsde import TrainLoop3DDiffusion_vpsde
+from .crossattn_cldm import TrainLoop3DDiffusionLSGM_crossattn
+from .train_util_diffusion_lsgm_noD_joint import TrainLoop3DDiffusionLSGMJointnoD, TrainLoop3DDiffusionLSGMJointnoD_ponly
diff --git a/nsr/lsgm/__pycache__/__init__.cpython-39.pyc b/nsr/lsgm/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..635e8c75a266d49d38f1da74b931f17fc8351199
Binary files /dev/null and b/nsr/lsgm/__pycache__/__init__.cpython-39.pyc differ
diff --git a/nsr/lsgm/__pycache__/controlLDM.cpython-39.pyc b/nsr/lsgm/__pycache__/controlLDM.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..90c25f5e3acc1b968d3b69e1ae303cd633084a73
Binary files /dev/null and b/nsr/lsgm/__pycache__/controlLDM.cpython-39.pyc differ
diff --git a/nsr/lsgm/__pycache__/crossattn_cldm.cpython-39.pyc b/nsr/lsgm/__pycache__/crossattn_cldm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38ba2cb7ea5f23444325838de5131bf9305322e5
Binary files /dev/null and b/nsr/lsgm/__pycache__/crossattn_cldm.cpython-39.pyc differ
diff --git a/nsr/lsgm/__pycache__/crossattn_cldm_objv.cpython-39.pyc b/nsr/lsgm/__pycache__/crossattn_cldm_objv.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d22b8c6528e123d210eccbec81148f12488bcf27
Binary files /dev/null and b/nsr/lsgm/__pycache__/crossattn_cldm_objv.cpython-39.pyc differ
diff --git a/nsr/lsgm/__pycache__/train_util_diffusion_lsgm.cpython-39.pyc b/nsr/lsgm/__pycache__/train_util_diffusion_lsgm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e735d7a3879f9d4de2139a70f6730d24b5db5702
Binary files /dev/null and b/nsr/lsgm/__pycache__/train_util_diffusion_lsgm.cpython-39.pyc differ
diff --git a/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_cvD_joint.cpython-39.pyc b/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_cvD_joint.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a86c16315ec9063a0c41d94a16b658ea392f3c65
Binary files /dev/null and b/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_cvD_joint.cpython-39.pyc differ
diff --git a/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_noD.cpython-39.pyc b/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_noD.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36f826596140e9ccd01dd734b7a99cdda672e439
Binary files /dev/null and b/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_noD.cpython-39.pyc differ
diff --git a/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_noD_joint.cpython-39.pyc b/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_noD_joint.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f78e55d13ac9ab22bf5b242fd41e925c1bb40456
Binary files /dev/null and b/nsr/lsgm/__pycache__/train_util_diffusion_lsgm_noD_joint.cpython-39.pyc differ
diff --git a/nsr/lsgm/__pycache__/train_util_diffusion_vpsde.cpython-39.pyc b/nsr/lsgm/__pycache__/train_util_diffusion_vpsde.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db083636e329819d4a7e6627cc3b92797940d7f9
Binary files /dev/null and b/nsr/lsgm/__pycache__/train_util_diffusion_vpsde.cpython-39.pyc differ
diff --git a/nsr/lsgm/crossattn_cldm.py b/nsr/lsgm/crossattn_cldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e68fb234dcb1002aea2f67ff3b6fed4ac651341
--- /dev/null
+++ b/nsr/lsgm/crossattn_cldm.py
@@ -0,0 +1,719 @@
+"""
+https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/diffusion/ddpm.py#L30
+"""
+import copy
+
+from matplotlib import pyplot as plt
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+from typing import Any
+import einops
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard.writer import SummaryWriter
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+# from .train_util import TrainLoop3DRec
+from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_loss_dict,
+ log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+from guided_diffusion.gaussian_diffusion import ModelMeanType
+
+from ldm.modules.encoders.modules import FrozenClipImageEmbedder, TextEmbedder, FrozenCLIPTextEmbedder
+
+import dnnlib
+from dnnlib.util import requires_grad
+from dnnlib.util import calculate_adaptive_weight
+
+from ..train_util_diffusion import TrainLoop3DDiffusion
+from ..cvD.nvsD_canoD import TrainLoop3DcvD_nvsD_canoD
+
+from guided_diffusion.continuous_diffusion_utils import get_mixed_prediction, different_p_q_objectives, kl_per_group_vada, kl_balancer
+# from .train_util_diffusion_lsgm_noD_joint import TrainLoop3DDiffusionLSGMJointnoD # joint diffusion and rec class
+# from .controlLDM import TrainLoop3DDiffusionLSGM_Control # joint diffusion and rec class
+from .train_util_diffusion_lsgm_noD_joint import TrainLoop3DDiffusionLSGMJointnoD # joint diffusion and rec class
+
+__conditioning_keys__ = {
+ 'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class TrainLoop3DDiffusionLSGM_crossattn(TrainLoop3DDiffusionLSGMJointnoD):
+
+ def __init__(self,
+ *,
+ rec_model,
+ denoise_model,
+ diffusion,
+ sde_diffusion,
+ control_model,
+ control_key,
+ only_mid_control,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ resume_cldm_checkpoint=None,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ schedule_sampler=None,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ ignore_resume_opt=False,
+ freeze_ae=False,
+ denoised_ae=True,
+ triplane_scaling_divider=10,
+ use_amp=False,
+ diffusion_input_size=224,
+ normalize_clip_encoding=False,
+ scale_clip_encoding=1.0,
+ cfg_dropout_prob=0.,
+ cond_key='img_sr',
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ denoise_model=denoise_model,
+ diffusion=diffusion,
+ sde_diffusion=sde_diffusion,
+ control_model=control_model,
+ control_key=control_key,
+ only_mid_control=only_mid_control,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ resume_cldm_checkpoint=resume_cldm_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ schedule_sampler=schedule_sampler,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ ignore_resume_opt=ignore_resume_opt,
+ freeze_ae=freeze_ae,
+ denoised_ae=denoised_ae,
+ triplane_scaling_divider=triplane_scaling_divider,
+ use_amp=use_amp,
+ diffusion_input_size=diffusion_input_size,
+ **kwargs)
+ self.conditioning_key = 'c_crossattn'
+ self.cond_key = cond_key
+ self.instantiate_cond_stage(normalize_clip_encoding,
+ scale_clip_encoding, cfg_dropout_prob)
+ requires_grad(self.rec_model, False)
+ self.rec_model.eval()
+ # self.normalize_clip_encoding = normalize_clip_encoding
+ # self.cfg_dropout_prob = cfg_dropout_prob
+
+ def instantiate_cond_stage(self, normalize_clip_encoding,
+ scale_clip_encoding, cfg_dropout_prob):
+ # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/diffusion/ddpm.py#L509C1-L509C46
+ # self.cond_stage_model.train = disabled_train # type: ignore
+ # st()
+ if self.cond_key == 'caption': # for objaverse training (with extracted cap3d caption)
+ self.cond_txt_model = TextEmbedder(dropout_prob=cfg_dropout_prob)
+ else: # zero-shot Text to 3D using normalized clip latent
+ self.cond_stage_model = FrozenClipImageEmbedder(
+ 'ViT-L/14',
+ dropout_prob=cfg_dropout_prob,
+ normalize_encoding=normalize_clip_encoding,
+ scale_clip_encoding=scale_clip_encoding)
+ self.cond_stage_model.freeze()
+
+ self.cond_txt_model = FrozenCLIPTextEmbedder(
+ dropout_prob=cfg_dropout_prob,
+ scale_clip_encoding=scale_clip_encoding)
+ self.cond_txt_model.freeze()
+
+ @th.no_grad()
+ def get_c_input(self,
+ batch,
+ bs=None,
+ use_text=False,
+ prompt="",
+ *args,
+ **kwargs):
+
+ # using clip to transform control to tokens for crossattn
+ cond_inp = None
+
+ if self.cond_key == 'caption':
+ c = self.cond_txt_model(
+ cond_inp, train=self.ddpm_model.training
+ ) # ! SD training text condition injection layer
+ # st() # check whether context repeat?
+ else: # zero shot
+ if use_text: # for test
+ assert prompt != ""
+ c = self.cond_txt_model.encode(prompt) # ! for test
+ # st()
+ else:
+
+ cond_inp = batch[self.cond_key]
+ if bs is not None:
+ cond_inp = cond_inp[:bs]
+
+ cond_inp = cond_inp.to(
+ memory_format=th.contiguous_format).float()
+ c = self.cond_stage_model(cond_inp) # BS 768
+
+ # return dict(c_concat=[control])
+ # return dict(c_crossattn=[c], c_concat=[control])
+ # return dict(__conditioning_keys__[self.cond_key]=)
+ # return {self.conditioning_key: [c], 'c_concat': [cond_inp]}
+ return {self.conditioning_key: c, 'c_concat': [cond_inp]}
+
+ # TODO, merge the APIs
+ def apply_model_inference(self, x_noisy, t, c, model_kwargs={}):
+ pred_params = self.ddp_ddpm_model(
+ x_noisy, t, **{
+ **model_kwargs, 'context': c['c_crossattn']
+ })
+ return pred_params
+
+ def apply_model(self, p_sample_batch, cond, model_kwargs={}):
+ return super().apply_model(
+ p_sample_batch, **{
+ **model_kwargs, 'context': cond['c_crossattn']
+ })
+
+ def run_step(self, batch, step='ldm_step'):
+
+ # if step == 'diffusion_step_rec':
+
+ if step == 'ldm_step':
+ self.ldm_train_step(batch)
+
+ # if took_step_ddpm:
+ # self._update_cldm_ema()
+
+ self._anneal_lr()
+ self.log_step()
+
+ def run_loop(self):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ # dist_util.synchronize()
+
+ batch = next(self.data)
+ self.run_step(batch, step='ldm_step')
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ # if self.step % self.eval_interval == 0 and self.step != 0:
+ if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ # self.eval_ddpm_sample()
+ # self.eval_cldm(use_ddim=True, unconditional_guidance_scale=7.5, prompt="") # during training, use image as condition
+ self.eval_cldm(use_ddim=False,
+ prompt="") # fix condition bug first
+ # if self.sde_diffusion.args.train_vae:
+ # self.eval_loop()
+
+ th.cuda.empty_cache()
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0:
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ # if self.sde_diffusion.args.train_vae:
+ # self.save(self.mp_trainer_rec,
+ # self.mp_trainer_rec.model_name)
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save(self.mp_trainer,
+ self.mp_trainer.model_name) # rec and ddpm all fixed.
+ # st()
+ # self.save(self.mp_trainer_canonical_cvD, 'cvD')
+
+ # ddpm + rec loss
+ def ldm_train_step(self, batch, behaviour='cano', *args, **kwargs):
+ """
+ add sds grad to all ae predicted x_0
+ """
+
+ # ! enable the gradient of both models
+ requires_grad(self.ddpm_model, True)
+
+ self.mp_trainer.zero_grad() # !!!!
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k:
+ v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp):
+
+ loss = th.tensor(0.).to(dist_util.dev())
+
+ vae_out = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='encoder_vae',
+ ) # pred: (B, 3, 64, 64)
+ eps = vae_out[self.latent_name]
+ # eps = vae_out.pop(self.latent_name)
+
+ if 'bg_plane' in vae_out:
+ eps = th.cat((eps, vae_out['bg_plane']),
+ dim=1) # include background, B 12+4 32 32
+
+ p_sample_batch = self.prepare_ddpm(eps)
+ cond = self.get_c_input(micro)
+
+ # ! running diffusion forward
+ ddpm_ret = self.apply_model(p_sample_batch, cond)
+ if self.sde_diffusion.args.p_rendering_loss:
+
+ target = micro
+ pred = self.ddp_rec_model(
+ # latent=vae_out,
+ latent={
+ # **vae_out,
+ self.latent_name: ddpm_ret['pred_x0_p'],
+ 'latent_name': self.latent_name
+ },
+ c=micro['c'],
+ behaviour=self.render_latent_behaviour)
+
+ # vae reconstruction loss
+ with self.ddp_control_model.no_sync(): # type: ignore
+ p_vae_recon_loss, rec_loss_dict = self.loss_class(
+ pred, target, test_mode=False)
+ log_rec3d_loss_dict(rec_loss_dict)
+ # log_rec3d_loss_dict(
+ # dict(p_vae_recon_loss=p_vae_recon_loss, ))
+ loss = p_vae_recon_loss + ddpm_ret[
+ 'p_eps_objective'] # TODO, add obj_weight_t_p?
+ else:
+ loss = ddpm_ret['p_eps_objective'].mean()
+
+ # =====================================================================
+
+ self.mp_trainer.backward(loss) # joint gradient descent
+
+ # update ddpm accordingly
+ self.mp_trainer.optimize(self.opt)
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ self.log_control_images(vae_out, p_sample_batch, micro, ddpm_ret)
+
+ @th.inference_mode()
+ def log_control_images(self, vae_out, p_sample_batch, micro, ddpm_ret):
+
+ eps_t_p, t_p, logsnr_p = (p_sample_batch[k] for k in (
+ 'eps_t_p',
+ 't_p',
+ 'logsnr_p',
+ ))
+ pred_eps_p = ddpm_ret['pred_eps_p']
+
+ vae_out.pop('posterior') # for calculating kl loss
+ vae_out_for_pred = {
+ k: v[0:1].to(dist_util.dev()) if isinstance(v, th.Tensor) else v
+ for k, v in vae_out.items()
+ }
+
+ pred = self.ddp_rec_model(latent=vae_out_for_pred,
+ c=micro['c'][0:1],
+ behaviour=self.render_latent_behaviour)
+ assert isinstance(pred, dict)
+
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ if 'depth' in micro:
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+ else:
+ gt_depth = th.zeros_like(gt_img[:, 0:1, ...])
+
+ if 'image_depth' in pred:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+
+ gt_img = self.pool_128(gt_img)
+ gt_depth = self.pool_128(gt_depth)
+ # cond = self.get_c_input(micro)
+ # hint = th.cat(cond['c_concat'], 1)
+
+ gt_vis = th.cat(
+ [
+ gt_img,
+ gt_img,
+ gt_img,
+ # self.pool_128(hint),
+ # gt_img,
+ gt_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1)[0:1] # TODO, fail to load depth. range [0, 1]
+
+ # eps_t_p_3D = eps_t_p.reshape(batch_size, eps_t_p.shape[1]//3, 3, -1) # B C 3 L
+
+ if 'bg_plane' in vae_out:
+ noised_latent = {
+ 'latent_normalized_2Ddiffusion':
+ eps_t_p[0:1, :12] * self.triplane_scaling_divider,
+ 'bg_plane':
+ eps_t_p[0:1, 12:16] * self.triplane_scaling_divider,
+ }
+ else:
+ noised_latent = {
+ 'latent_normalized_2Ddiffusion':
+ eps_t_p[0:1] * self.triplane_scaling_divider,
+ }
+
+ noised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=noised_latent,
+ # latent=eps_t_p[0:1] * self.
+ # triplane_scaling_divider, # TODO, how to define the scale automatically
+ behaviour=self.render_latent_behaviour)
+
+ pred_x0 = self.sde_diffusion._predict_x0_from_eps(
+ eps_t_p, pred_eps_p, logsnr_p) # for VAE loss, denosied latent
+
+ if 'bg_plane' in vae_out:
+ denoised_latent = {
+ 'latent_normalized_2Ddiffusion':
+ pred_x0[0:1, :12] * self.triplane_scaling_divider,
+ 'bg_plane':
+ pred_x0[0:1, 12:16] * self.triplane_scaling_divider,
+ }
+ else:
+ denoised_latent = {
+ 'latent_normalized_2Ddiffusion':
+ pred_x0[0:1] * self.triplane_scaling_divider,
+ }
+
+ # pred_xstart_3D
+ denoised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=denoised_latent,
+ # latent=pred_x0[0:1] * self.
+ # triplane_scaling_divider, # TODO, how to define the scale automatically?
+ behaviour=self.render_latent_behaviour)
+
+ pred_vis = th.cat(
+ [
+ self.pool_128(img) for img in (
+ pred_img[0:1],
+ noised_ae_pred['image_raw'][0:1],
+ denoised_ae_pred['image_raw'][0:1], # controlnet result
+ pred_depth[0:1].repeat_interleave(3, dim=1))
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = th.cat([gt_vis, pred_vis],
+ dim=-2)[0].permute(1, 2,
+ 0).cpu() # ! pred in range[-1, 1]
+
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}.jpg'
+ )
+
+ if self.cond_key == 'caption':
+ with open(
+ f'{logger.get_dir()}/{self.step+self.resume_step}caption_{t_p[0].item():3}.txt',
+ 'w') as f:
+ f.write(micro['caption'][0])
+
+ print(
+ 'log denoised vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}.jpg'
+ )
+
+ th.cuda.empty_cache()
+
+ @th.inference_mode()
+ def eval_cldm(self,
+ prompt="",
+ use_ddim=False,
+ unconditional_guidance_scale=1.0,
+ save_img=False,
+ use_train_trajectory=False,
+ export_mesh=False,
+ camera=None,
+ overwrite_diff_inp_size=None):
+ self.ddpm_model.eval()
+
+ args = dnnlib.EasyDict(
+ dict(
+ batch_size=self.batch_size,
+ image_size=self.diffusion_input_size,
+ denoise_in_channels=self.rec_model.decoder.triplane_decoder.
+ out_chans, # type: ignore
+ clip_denoised=False,
+ class_cond=False,
+ use_ddim=use_ddim))
+
+ model_kwargs = {}
+
+ if args.class_cond:
+ classes = th.randint(low=0,
+ high=NUM_CLASSES,
+ size=(args.batch_size, ),
+ device=dist_util.dev())
+ model_kwargs["y"] = classes
+
+ diffusion = self.diffusion
+ sample_fn = (diffusion.p_sample_loop
+ if not args.use_ddim else diffusion.ddim_sample_loop)
+ extra_kwargs = {}
+ if args.use_ddim:
+ extra_kwargs.update(
+ dict(
+ unconditional_guidance_scale=unconditional_guidance_scale))
+
+ # for i, batch in enumerate(tqdm(self.eval_data)):
+ # if use_train_trajectory:
+ # batch = next(iter(self.data))
+ # else:
+ # batch = next(iter(self.eval_data))
+
+ # st() # th.save(batch['c'].cpu(), 'assets/shapenet_eval_pose.pt')
+
+ assert camera is not None # for evaluation
+ batch = {'c': camera.clone()}
+ # st()
+
+ # use the first frame as the condition now
+ novel_view_cond = {
+ k:
+ v[0:1].to(dist_util.dev()) if isinstance(v, th.Tensor) else v[0:1]
+ # micro['img'].shape[0], 0)
+ for k, v in batch.items()
+ }
+ cond = self.get_c_input(novel_view_cond,
+ use_text=prompt != "",
+ prompt=prompt) # use specific prompt for debug
+
+ # broadcast to args.batch_size
+ cond = {
+ k: cond_v.repeat_interleave(args.batch_size, 0)
+ for k, cond_v in cond.items() if k == self.conditioning_key
+ }
+
+ for i in range(1):
+ # st()
+ noise_size = (
+ args.batch_size,
+ self.ddpm_model.in_channels,
+ self.diffusion_input_size if not overwrite_diff_inp_size else int(overwrite_diff_inp_size),
+ self.diffusion_input_size if not overwrite_diff_inp_size else int(overwrite_diff_inp_size)
+ )
+
+ triplane_sample = sample_fn(
+ self,
+ noise_size,
+ cond=cond,
+ clip_denoised=args.clip_denoised,
+ model_kwargs=model_kwargs,
+ mixing_normal=True, # !
+ device=dist_util.dev(),
+ **extra_kwargs)
+ # triplane_sample = th.zeros((args.batch_size, self.ddpm_model.in_channels, self.diffusion_input_size, self.diffusion_input_size), device=dist_util.dev())
+ th.cuda.empty_cache()
+
+ for sub_idx in range(triplane_sample.shape[0]):
+
+ self.render_video_given_triplane(
+ triplane_sample[sub_idx:sub_idx + 1],
+ self.rec_model, # compatible with join_model
+ name_prefix=f'{self.step + self.resume_step}_{i+sub_idx}',
+ save_img=save_img,
+ render_reference=batch,
+ # render_reference=None,
+ export_mesh=export_mesh,
+ render_all=True,
+ )
+
+ del triplane_sample
+ th.cuda.empty_cache()
+
+ self.ddpm_model.train()
+
+ @th.inference_mode()
+ # def eval_loop(self, c_list:list):
+ def eval_novelview_loop(self, rec_model):
+ # novel view synthesis given evaluation camera trajectory
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+
+ all_loss_dict = []
+ novel_view_micro = {}
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ # for i in range(0, 8, self.microbatch):
+ # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ if i == 0:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in batch.items()
+ }
+
+ torchvision.utils.save_image(
+ self.pool_128(novel_view_micro['img']),
+ logger.get_dir() + '/FID_Cals/gt.png',
+ normalize=True,
+ val_range=(0, 1),
+ padding=0)
+
+ else:
+ # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in novel_view_micro.items()
+ }
+
+ th.manual_seed(0) # avoid vae re-sampling changes results
+ pred = rec_model(img=novel_view_micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # ! move to other places, add tensorboard
+
+ # pred_vis = th.cat([
+ # pred['image_raw'],
+ # -pred['image_depth'].repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+
+ # ! save
+
+ pooled_depth = self.pool_128(pred_depth).repeat_interleave(3,
+ dim=1)
+ pred_vis = th.cat([
+ self.pool_128(micro['img']),
+ self.pool_128(pred['image_raw']),
+ pooled_depth,
+ ],
+ dim=-1) # B, 3, H, W
+
+ # ! save depth
+ name_prefix = i
+
+ torchvision.utils.save_image(self.pool_128(pred['image_raw']),
+ logger.get_dir() +
+ '/FID_Cals/{}.png'.format(i),
+ normalize=True,
+ val_range=(0, 1),
+ padding=0)
+
+ torchvision.utils.save_image(self.pool_128(pooled_depth),
+ logger.get_dir() +
+ '/FID_Cals/{}_depth.png'.format(i),
+ normalize=True,
+ val_range=(0, 1),
+ padding=0)
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ del video_out
+ # del pred_vis
+ # del pred
+
+ th.cuda.empty_cache()
diff --git a/nsr/lsgm/crossattn_cldm_objv.py b/nsr/lsgm/crossattn_cldm_objv.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba269f7a5c035fe043eca80ab9e5d1b14192331b
--- /dev/null
+++ b/nsr/lsgm/crossattn_cldm_objv.py
@@ -0,0 +1,1124 @@
+"""
+https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/diffusion/ddpm.py#L30
+"""
+import copy
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+from typing import Any
+from click import prompt
+import einops
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard.writer import SummaryWriter
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+# from .train_util import TrainLoop3DRec
+from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_loss_dict,
+ log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+from guided_diffusion.gaussian_diffusion import ModelMeanType
+
+from ldm.modules.encoders.modules import FrozenClipImageEmbedder, TextEmbedder, FrozenCLIPTextEmbedder, FrozenOpenCLIPImagePredictionEmbedder, FrozenOpenCLIPImageEmbedder
+
+import dnnlib
+from dnnlib.util import requires_grad
+from dnnlib.util import calculate_adaptive_weight
+
+from ..train_util_diffusion import TrainLoop3DDiffusion
+from ..cvD.nvsD_canoD import TrainLoop3DcvD_nvsD_canoD
+
+from guided_diffusion.continuous_diffusion_utils import get_mixed_prediction, different_p_q_objectives, kl_per_group_vada, kl_balancer
+# from .train_util_diffusion_lsgm_noD_joint import TrainLoop3DDiffusionLSGMJointnoD # joint diffusion and rec class
+# from .controlLDM import TrainLoop3DDiffusionLSGM_Control # joint diffusion and rec class
+from .train_util_diffusion_lsgm_noD_joint import TrainLoop3DDiffusionLSGMJointnoD # joint diffusion and rec class
+
+__conditioning_keys__ = {
+ 'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class TrainLoop3DDiffusionLSGM_crossattn(TrainLoop3DDiffusionLSGMJointnoD):
+
+ def __init__(self,
+ *,
+ rec_model,
+ denoise_model,
+ diffusion,
+ sde_diffusion,
+ control_model,
+ control_key,
+ only_mid_control,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ resume_cldm_checkpoint=None,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ schedule_sampler=None,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ ignore_resume_opt=False,
+ freeze_ae=False,
+ denoised_ae=True,
+ triplane_scaling_divider=10,
+ use_amp=False,
+ diffusion_input_size=224,
+ normalize_clip_encoding=False,
+ scale_clip_encoding=1.0,
+ cfg_dropout_prob=0.,
+ cond_key='img_sr',
+ use_eos_feature=False,
+ compile=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ denoise_model=denoise_model,
+ diffusion=diffusion,
+ sde_diffusion=sde_diffusion,
+ control_model=control_model,
+ control_key=control_key,
+ only_mid_control=only_mid_control,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ resume_cldm_checkpoint=resume_cldm_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ schedule_sampler=schedule_sampler,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ ignore_resume_opt=ignore_resume_opt,
+ freeze_ae=freeze_ae,
+ denoised_ae=denoised_ae,
+ triplane_scaling_divider=triplane_scaling_divider,
+ use_amp=use_amp,
+ diffusion_input_size=diffusion_input_size,
+ compile=compile,
+ **kwargs)
+ self.conditioning_key = 'c_crossattn'
+ self.cond_key = cond_key
+ self.instantiate_cond_stage(normalize_clip_encoding,
+ scale_clip_encoding, cfg_dropout_prob,
+ use_eos_feature)
+ requires_grad(self.rec_model, False)
+ self.rec_model.eval()
+
+ # self.normalize_clip_encoding = normalize_clip_encoding
+ # self.cfg_dropout_prob = cfg_dropout_prob
+
+ def instantiate_cond_stage(self, normalize_clip_encoding,
+ scale_clip_encoding, cfg_dropout_prob,
+ use_eos_feature):
+ # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/diffusion/ddpm.py#L509C1-L509C46
+ # self.cond_stage_model.train = disabled_train # type: ignore
+ if self.cond_key == 'caption':
+ self.cond_txt_model = TextEmbedder(dropout_prob=cfg_dropout_prob,
+ use_eos_feature=use_eos_feature)
+ elif self.cond_key == 'img':
+ self.cond_img_model = FrozenOpenCLIPImagePredictionEmbedder(
+ 1, 1,
+ FrozenOpenCLIPImageEmbedder(freeze=True,
+ device=dist_util.dev(),
+ init_device=dist_util.dev()))
+
+ else: # zero-shot Text to 3D using normalized clip latent
+ self.cond_stage_model = FrozenClipImageEmbedder(
+ 'ViT-L/14',
+ dropout_prob=cfg_dropout_prob,
+ normalize_encoding=normalize_clip_encoding,
+ scale_clip_encoding=scale_clip_encoding)
+ self.cond_stage_model.freeze()
+
+ self.cond_txt_model = FrozenCLIPTextEmbedder(
+ dropout_prob=cfg_dropout_prob,
+ scale_clip_encoding=scale_clip_encoding)
+ self.cond_txt_model.freeze()
+
+ @th.no_grad()
+ def get_c_input(self,
+ batch,
+ bs=None,
+ use_text=False,
+ prompt="",
+ force_drop_ids=None,
+ *args,
+ **kwargs):
+ if use_text:
+ cond_inp = prompt
+ else:
+ if 'caption' in self.cond_key: # support caption-img
+ cond_inp = batch['caption']
+ else:
+ cond_inp = batch[self.cond_key]
+ # if bs is not None:
+ # cond_inp = cond_inp[:bs]
+
+ # using clip to transform control to tokens for crossattn
+ control = None
+ if 'caption' in self.cond_key:
+ c = self.cond_txt_model(
+ cond_inp,
+ train=self.ddpm_model.training,
+ force_drop_ids=force_drop_ids,
+ ) # ! SD training text condition injection layer
+ if bs is None: # duplicated sample
+ if c.shape[0] != batch['c'].shape[0]:
+ c = th.repeat_interleave(c,
+ batch['c'].shape[0] // c.shape[0],
+ dim=0)
+ else:
+ assert c.shape[0] == bs
+
+ # st()
+ # if 'img' in self.cond_key:
+
+ # ! later
+ # if 'img' in batch:
+ # control = batch['img'] + 0.02 * th.randn_like(
+ # batch['img']) # follow SVD?
+
+ elif self.cond_key == 'img':
+ c = self.cond_img_model(cond_inp)
+ # control = batch['img']
+ control = batch['img'] + 0.02 * th.randn_like(
+ batch['img']) # follow SVD?
+
+ else: # zero shot
+ if use_text: # for test
+ assert prompt != ""
+ c = self.cond_txt_model.encode(prompt) # ! for test
+ else:
+ cond_inp = cond_inp.to(
+ memory_format=th.contiguous_format).float()
+ c = self.cond_stage_model(cond_inp) # BS 768
+
+ # if c.shape[0] < batch['img_to_encoder'].shape[0]:
+ # c = th.repeat_interleave(c, batch['img_to_encoder'].shape[0]//c.shape[0], dim=0)
+
+ # return dict(c_concat=[control])
+ # return dict(c_crossattn=c, c_concat=batch['img'])
+ # if self.cond_key == 'img':
+ # return dict(c_crossattn=c, c_concat=control)
+ return dict(c_crossattn=c)
+
+ # else:
+ # return dict(c_crossattn=c)
+
+ # return dict(__conditioning_keys__[self.cond_key]=)
+ # return {self.conditioning_key: [c], 'c_concat': [cond_inp]}
+ # return {self.conditioning_key: c, 'c_concat': [cond_inp]}
+
+ # TODO, merge the APIs
+ def apply_model_inference(self, x_noisy, t, c, model_kwargs={}):
+ pred_params = self.ddp_ddpm_model(x_noisy,
+ timesteps=t,
+ **{
+ **model_kwargs, 'context':
+ c['c_crossattn'],
+ 'hint':
+ c.get('c_concat', None)
+ })
+ return pred_params
+
+ def apply_model(self, p_sample_batch, cond, model_kwargs={}):
+ return super().apply_model(
+ p_sample_batch,
+ **{
+ **model_kwargs, 'context': cond['c_crossattn'],
+ 'hint': cond.get('c_concat', None)
+ # **cond,
+ })
+
+ def run_step(self, batch, step='ldm_step'):
+
+ # if step == 'diffusion_step_rec':
+
+ if step == 'ldm_step':
+ self.ldm_train_step(batch)
+
+ # if took_step_ddpm:
+ # self._update_cldm_ema()
+
+ self._anneal_lr()
+ self.log_step()
+
+ def run_loop(self):
+ # eval camera
+ camera = th.load('eval_pose.pt', map_location=dist_util.dev())
+
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ # dist_util.synchronize()
+
+ batch = next(self.data)
+ self.run_step(batch, step='ldm_step')
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ if self.step % self.eval_interval == 0 and self.step != 0:
+ # if self.step % self.eval_interval == 0:
+ # if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ # self.eval_ddpm_sample()
+ # self.eval_cldm(use_ddim=True, unconditional_guidance_scale=7.5, prompt="") # during training, use image as condition
+ if self.cond_key == 'caption':
+ self.eval_cldm(
+ use_ddim=False,
+ prompt="a voxelized dog",
+ use_train_trajectory=False,
+ camera=camera) # fix condition bug first
+ else:
+ pass # TODO
+ # self.eval_cldm(use_ddim=False,
+ # prompt="",
+ # use_train_trajectory=False,
+ # camera=camera) # fix condition bug first
+ # if self.sde_diffusion.args.train_vae:
+ # self.eval_loop()
+
+ th.cuda.empty_cache()
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0:
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ # if self.sde_diffusion.args.train_vae:
+ # self.save(self.mp_trainer_rec,
+ # self.mp_trainer_rec.model_name)
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save(self.mp_trainer,
+ self.mp_trainer.model_name) # rec and ddpm all fixed.
+ # st()
+ # self.save(self.mp_trainer_canonical_cvD, 'cvD')
+
+ # ddpm + rec loss
+ def ldm_train_step(self, batch, behaviour='cano', *args, **kwargs):
+ """
+ add sds grad to all ae predicted x_0
+ """
+
+ # ! enable the gradient of both models
+ requires_grad(self.ddpm_model, True)
+
+ self.mp_trainer.zero_grad() # !!!!
+
+ if 'img' in batch:
+ batch_size = batch['img'].shape[0]
+ else:
+ batch_size = len(batch['caption'])
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k:
+ v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp):
+
+ loss = th.tensor(0.).to(dist_util.dev())
+
+ if 'latent' in micro:
+ vae_out = {self.latent_name: micro['latent']}
+ else:
+ vae_out = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='encoder_vae',
+ ) # pred: (B, 3, 64, 64)
+
+ eps = vae_out[self.latent_name] / self.triplane_scaling_divider
+ # eps = vae_out.pop(self.latent_name)
+
+ if 'bg_plane' in vae_out:
+ eps = th.cat((eps, vae_out['bg_plane']),
+ dim=1) # include background, B 12+4 32 32
+
+ p_sample_batch = self.prepare_ddpm(eps)
+ cond = self.get_c_input(micro, bs=eps.shape[0])
+
+ # ! running diffusion forward
+ ddpm_ret = self.apply_model(p_sample_batch, cond)
+ if self.sde_diffusion.args.p_rendering_loss:
+
+ target = micro
+ pred = self.ddp_rec_model(
+ # latent=vae_out,
+ latent={
+ # **vae_out,
+ self.latent_name: ddpm_ret['pred_x0_p'],
+ 'latent_name': self.latent_name
+ },
+ c=micro['c'],
+ behaviour=self.render_latent_behaviour)
+
+ # vae reconstruction loss
+ with self.ddp_control_model.no_sync(): # type: ignore
+ p_vae_recon_loss, rec_loss_dict = self.loss_class(
+ pred, target, test_mode=False)
+ log_rec3d_loss_dict(rec_loss_dict)
+ # log_rec3d_loss_dict(
+ # dict(p_vae_recon_loss=p_vae_recon_loss, ))
+ loss = p_vae_recon_loss + ddpm_ret[
+ 'p_eps_objective'] # TODO, add obj_weight_t_p?
+ else:
+ loss = ddpm_ret['p_eps_objective'].mean()
+
+ # =====================================================================
+
+ self.mp_trainer.backward(loss) # joint gradient descent
+
+ # update ddpm accordingly
+ self.mp_trainer.optimize(self.opt)
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ self.log_control_images(vae_out, p_sample_batch, micro, ddpm_ret)
+
+ @th.inference_mode()
+ def log_control_images(self, vae_out, p_sample_batch, micro, ddpm_ret):
+
+ eps_t_p, t_p, logsnr_p = (p_sample_batch[k] for k in (
+ 'eps_t_p',
+ 't_p',
+ 'logsnr_p',
+ ))
+ pred_eps_p = ddpm_ret['pred_eps_p']
+
+ if 'posterior' in vae_out:
+ vae_out.pop('posterior') # for calculating kl loss
+ vae_out_for_pred = {
+ k: v[0:1].to(dist_util.dev()) if isinstance(v, th.Tensor) else v
+ for k, v in vae_out.items()
+ }
+
+ pred = self.ddp_rec_model(latent=vae_out_for_pred,
+ c=micro['c'][0:1],
+ behaviour=self.render_latent_behaviour)
+ assert isinstance(pred, dict)
+
+ pred_img = pred['image_raw']
+ if 'img' in micro:
+ gt_img = micro['img']
+ else:
+ gt_img = th.zeros_like(pred['image_raw'])
+
+ if 'depth' in micro:
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+ else:
+ gt_depth = th.zeros_like(gt_img[:, 0:1, ...])
+
+ if 'image_depth' in pred:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+
+ gt_img = self.pool_128(gt_img)
+ gt_depth = self.pool_128(gt_depth)
+ # cond = self.get_c_input(micro)
+ # hint = th.cat(cond['c_concat'], 1)
+
+ gt_vis = th.cat(
+ [
+ gt_img,
+ gt_img,
+ gt_img,
+ # self.pool_128(hint),
+ # gt_img,
+ gt_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1)[0:1] # TODO, fail to load depth. range [0, 1]
+
+ # eps_t_p_3D = eps_t_p.reshape(batch_size, eps_t_p.shape[1]//3, 3, -1) # B C 3 L
+
+ if 'bg_plane' in vae_out:
+ noised_latent = {
+ 'latent_normalized_2Ddiffusion':
+ eps_t_p[0:1, :12] * self.triplane_scaling_divider,
+ 'bg_plane':
+ eps_t_p[0:1, 12:16] * self.triplane_scaling_divider,
+ }
+ else:
+ noised_latent = {
+ 'latent_normalized_2Ddiffusion':
+ eps_t_p[0:1] * self.triplane_scaling_divider,
+ }
+
+ noised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=noised_latent,
+ # latent=eps_t_p[0:1] * self.
+ # triplane_scaling_divider, # TODO, how to define the scale automatically
+ behaviour=self.render_latent_behaviour)
+
+ pred_x0 = self.sde_diffusion._predict_x0_from_eps(
+ eps_t_p, pred_eps_p, logsnr_p) # for VAE loss, denosied latent
+
+ if 'bg_plane' in vae_out:
+ denoised_latent = {
+ 'latent_normalized_2Ddiffusion':
+ pred_x0[0:1, :12] * self.triplane_scaling_divider,
+ 'bg_plane':
+ pred_x0[0:1, 12:16] * self.triplane_scaling_divider,
+ }
+ else:
+ denoised_latent = {
+ 'latent_normalized_2Ddiffusion':
+ pred_x0[0:1] * self.triplane_scaling_divider,
+ }
+
+ # pred_xstart_3D
+ denoised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=denoised_latent,
+ # latent=pred_x0[0:1] * self.
+ # triplane_scaling_divider, # TODO, how to define the scale automatically?
+ behaviour=self.render_latent_behaviour)
+
+ pred_vis = th.cat(
+ [
+ self.pool_128(img) for img in (
+ pred_img[0:1],
+ noised_ae_pred['image_raw'][0:1],
+ denoised_ae_pred['image_raw'][0:1], # controlnet result
+ pred_depth[0:1].repeat_interleave(3, dim=1))
+ ],
+ dim=-1) # B, 3, H, W
+
+ if 'img' in micro:
+ vis = th.cat([gt_vis, pred_vis],
+ dim=-2)[0].permute(1, 2,
+ 0).cpu() # ! pred in range[-1, 1]
+ else:
+ vis = pred_vis[0].permute(1, 2, 0).cpu()
+
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}.jpg'
+ )
+
+ # if self.cond_key == 'caption':
+ # with open(f'{logger.get_dir()}/{self.step+self.resume_step}caption_{t_p[0].item():3}.txt', 'w') as f:
+ # f.write(micro['caption'][0])
+
+ print(
+ 'log denoised vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}.jpg'
+ )
+
+ th.cuda.empty_cache()
+
+ @th.inference_mode()
+ def eval_cldm(
+ self,
+ prompt="",
+ use_ddim=False,
+ unconditional_guidance_scale=1.0,
+ save_img=False,
+ use_train_trajectory=False,
+ camera=None,
+ num_samples=1,
+ num_instances=1,
+ export_mesh=False,
+ ):
+ self.ddpm_model.eval()
+
+ args = dnnlib.EasyDict(
+ dict(
+ # batch_size=1,
+ batch_size=self.batch_size,
+ image_size=self.diffusion_input_size,
+ denoise_in_channels=self.rec_model.decoder.triplane_decoder.
+ out_chans, # type: ignore
+ clip_denoised=False,
+ class_cond=False,
+ use_ddim=use_ddim))
+
+ model_kwargs = {}
+
+ if args.class_cond:
+ classes = th.randint(low=0,
+ high=NUM_CLASSES,
+ size=(args.batch_size, ),
+ device=dist_util.dev())
+ model_kwargs["y"] = classes
+
+ diffusion = self.diffusion
+ sample_fn = (diffusion.p_sample_loop
+ if not args.use_ddim else diffusion.ddim_sample_loop)
+ # for i, batch in enumerate(tqdm(self.eval_data)):
+
+ # use the first frame as the condition now
+ extra_kwargs = {}
+
+ uc = None
+ if args.use_ddim:
+ if unconditional_guidance_scale != 1.0:
+ uc = self.get_c_input(
+ {self.cond_key: 'None'},
+ use_text=True,
+ prompt="None",
+ bs=1, # TODO, support BS>1 later
+ force_drop_ids=np.array(
+ [ # ! make sure using dropped tokens
+ 1
+ ])) # use specific prompt for debug
+ extra_kwargs.update(
+ dict(
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc, # TODO
+ objv_inference=True,
+ # {
+ # k : unconditional_guidance_scale
+ # for k in cond.keys()
+ # }
+ ))
+
+ # hint = th.cat(cond['c_concat'], 1)
+
+ # record cond images
+ # broadcast to args.batch_size
+
+ for instance in range(num_instances):
+
+ if self.cond_key == 'caption':
+ if camera is not None:
+ batch = {'c': camera.clone()}
+ else:
+ if use_train_trajectory:
+ batch = next(iter(self.data))
+ else:
+ try:
+ batch = next(self.eval_data)
+ except Exception as e:
+ self.eval_data = iter(self.eval_data)
+ batch = next(self.eval_data)
+
+ if camera is not None:
+ batch['c'] = camera.clone()
+
+ # ! generate new samples
+
+ novel_view_cond = {
+ k:
+ v[0:1].to(dist_util.dev())
+ if isinstance(v, th.Tensor) else v[0:1]
+ # micro['img'].shape[0], 0)
+ for k, v in batch.items()
+ }
+
+ cond = self.get_c_input(
+ novel_view_cond, use_text=prompt != "",
+ prompt=prompt) # use specific prompt for debug
+
+ cond = {
+ k: cond_v.repeat_interleave(args.batch_size, 0)
+ for k, cond_v in cond.items()
+ # if k == self.conditioning_key
+ }
+
+ if self.cond_key == 'caption':
+ if prompt != '':
+ with open(
+ f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{instance}_caption.txt',
+ 'w') as f:
+ f.write(prompt)
+ else:
+ with open(
+ f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{instance}_caption.txt',
+ 'w') as f:
+ try:
+ f.write(novel_view_cond['caption'][0])
+ except Exception as e:
+ pass
+
+ elif self.cond_key == 'img':
+ torchvision.utils.save_image(
+ cond['c_concat'],
+ f'{logger.get_dir()}/{self.step + self.resume_step}_{instance}_cond.jpg',
+ normalize=True,
+ value_range=(-1, 1))
+
+ # continue
+
+ for i in range(num_samples):
+ triplane_sample = sample_fn(
+ self,
+ (
+ args.batch_size,
+ self.ddpm_model.in_channels
+ if not self.ddpm_model.roll_out else 3 *
+ self.ddpm_model.in_channels, # type: ignore
+ self.diffusion_input_size,
+ self.diffusion_input_size),
+ cond=cond,
+ clip_denoised=args.clip_denoised,
+ model_kwargs=model_kwargs,
+ # mixing_normal=True, # !
+ mixing_normal=self.ddpm_model.mixed_prediction, # !
+ device=dist_util.dev(),
+ **extra_kwargs)
+ th.cuda.empty_cache()
+
+ # render the generated samples
+ for sub_idx in range(triplane_sample.shape[0]):
+ self.render_video_given_triplane(
+ triplane_sample[sub_idx:sub_idx+1],
+ self.rec_model, # compatible with join_model
+ name_prefix=
+ f'{self.step + self.resume_step}_{instance}_{i+sub_idx}',
+ save_img=save_img,
+ render_reference=batch,
+ export_mesh=export_mesh)
+
+ # save gt
+ # video_out = imageio.get_writer(
+ # f'{logger.get_dir()}/triplane_{self.step + self.resume_step}_{i}_reference.mp4',
+ # mode='I',
+ # fps=15,
+ # codec='libx264')
+
+ # for j in range(batch['img'].shape[0]
+ # ): # ! currently only export one plane at a time
+ # cpu_gt = batch['img'][j].cpu().permute(1,2,0).numpy()
+ # cpu_gt = (cpu_gt*127.5)+127.5
+ # video_out.append_data(cpu_gt.astype(np.uint8))
+
+ # video_out.close()
+ # del video_out
+
+ # del triplane_sample
+ # th.cuda.empty_cache()
+
+ self.ddpm_model.train()
+
+
+class TrainLoop3DDiffusionLSGM_crossattn_controlNet(
+ TrainLoop3DDiffusionLSGM_crossattn):
+
+ def __init__(self,
+ *,
+ rec_model,
+ denoise_model,
+ diffusion,
+ sde_diffusion,
+ control_model,
+ control_key,
+ only_mid_control,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ resume_cldm_checkpoint=None,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ schedule_sampler=None,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ ignore_resume_opt=False,
+ freeze_ae=False,
+ denoised_ae=True,
+ triplane_scaling_divider=10,
+ use_amp=False,
+ diffusion_input_size=224,
+ normalize_clip_encoding=False,
+ scale_clip_encoding=1,
+ cfg_dropout_prob=0,
+ cond_key='img_sr',
+ use_eos_feature=False,
+ compile=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ denoise_model=denoise_model,
+ diffusion=diffusion,
+ sde_diffusion=sde_diffusion,
+ control_model=control_model,
+ control_key=control_key,
+ only_mid_control=only_mid_control,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ resume_cldm_checkpoint=resume_cldm_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ schedule_sampler=schedule_sampler,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ ignore_resume_opt=ignore_resume_opt,
+ freeze_ae=freeze_ae,
+ denoised_ae=denoised_ae,
+ triplane_scaling_divider=triplane_scaling_divider,
+ use_amp=use_amp,
+ diffusion_input_size=diffusion_input_size,
+ normalize_clip_encoding=normalize_clip_encoding,
+ scale_clip_encoding=scale_clip_encoding,
+ cfg_dropout_prob=cfg_dropout_prob,
+ cond_key=cond_key,
+ use_eos_feature=use_eos_feature,
+ compile=compile,
+ **kwargs)
+
+ # st()
+ self.control_model = control_model
+ self.control_key = control_key
+ self.only_mid_control = only_mid_control
+ self.control_scales = [1.0] * 13
+ self.sd_locked = True
+ self._setup_control_model()
+
+ def _setup_control_model(self):
+
+ requires_grad(self.rec_model, False)
+ requires_grad(self.ddpm_model, False)
+
+ self.mp_cldm_trainer = MixedPrecisionTrainer(
+ model=self.control_model,
+ use_fp16=self.use_fp16,
+ fp16_scale_growth=self.fp16_scale_growth,
+ use_amp=self.use_amp,
+ model_name='cldm')
+
+ self.ddp_control_model = DDP(
+ self.control_model,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ )
+
+ requires_grad(self.ddp_control_model, True)
+
+ # ! load trainable copy
+ # TODO
+ # st()
+ try:
+ logger.log(f"load pretrained controlnet, not trainable copy.")
+ self._load_and_sync_parameters(
+ model=self.control_model,
+ model_name='cldm',
+ resume_checkpoint=self.resume_cldm_checkpoint,
+ ) # if available
+ except:
+ logger.log(f"load trainable copy to controlnet")
+ model_state_dict = self.control_model.state_dict()
+ for k, v in self.ddpm_model.state_dict().items():
+ if k in model_state_dict.keys() and v.size(
+ ) == model_state_dict[k].size():
+ model_state_dict[k] = v
+
+ self.control_model.load_state_dict(model_state_dict)
+
+ # self._load_and_sync_parameters(
+ # model=self.control_model,
+ # model_name='ddpm') # load pre-trained SD
+
+ cldm_param = [{
+ 'name': 'cldm.parameters()',
+ 'params': self.control_model.parameters(),
+ }]
+ # if self.sde_diffusion.args.unfix_logit:
+ # self.ddpm_model.mixing_logit.requires_grad_(True)
+ # cldm_param.append({
+ # 'name': 'mixing_logit',
+ # 'params': self.ddpm_model.mixing_logit,
+ # })
+
+ self.opt_cldm = AdamW(cldm_param,
+ lr=self.lr,
+ weight_decay=self.weight_decay)
+ if self.sd_locked:
+ del self.opt
+ del self.mp_trainer
+
+ # add control during inference
+ def apply_model_inference(self, x_noisy, t, c, model_kwargs={}):
+
+ control = self.ddp_control_model(
+ x=x_noisy,
+ # hint=th.cat(c['c_concat'], 1),
+ hint=c['c_concat'],
+ timesteps=t,
+ context=None)
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
+ model_kwargs.update({'control': control})
+
+ return super().apply_model_inference(x_noisy, t, c, model_kwargs)
+
+ def apply_control_model(self, p_sample_batch, cond):
+ x_noisy, t, = (p_sample_batch[k] for k in ('eps_t_p', 't_p'))
+
+ control = self.ddp_control_model(
+ x=x_noisy,
+ # hint=th.cat(cond['c_concat'], 1),
+ hint=cond['c_concat'],
+ timesteps=t,
+ context=None)
+
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
+ return control
+
+ def apply_model(self, p_sample_batch, cond, model_kwargs={}):
+
+ control = self.apply_control_model(p_sample_batch,
+ cond) # len(control): 13
+ model_kwargs.update({'control': control})
+
+ return super().apply_model(p_sample_batch, cond, model_kwargs)
+
+ # cldm loss
+ def ldm_train_step(self, batch, behaviour='cano', *args, **kwargs):
+ """
+ add sds grad to all ae predicted x_0
+ """
+
+ # ! enable the gradient of both models
+ requires_grad(self.ddp_control_model, True)
+ self.mp_cldm_trainer.zero_grad() # !!!!
+
+ if 'img' in batch:
+ batch_size = batch['img'].shape[0]
+ else:
+ batch_size = len(batch['caption'])
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k:
+ v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_cldm_trainer.use_amp):
+
+ loss = th.tensor(0.).to(dist_util.dev())
+
+ if 'latent' in micro:
+ vae_out = {self.latent_name: micro['latent']}
+ else:
+ vae_out = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='encoder_vae',
+ ) # pred: (B, 3, 64, 64)
+
+ eps = vae_out[self.latent_name] / self.triplane_scaling_divider
+ # eps = vae_out.pop(self.latent_name)
+
+ if 'bg_plane' in vae_out:
+ eps = th.cat((eps, vae_out['bg_plane']),
+ dim=1) # include background, B 12+4 32 32
+
+ p_sample_batch = self.prepare_ddpm(eps)
+ cond = self.get_c_input(micro, bs=eps.shape[0])
+
+ # ! running diffusion forward
+ ddpm_ret = self.apply_model(p_sample_batch, cond)
+ if self.sde_diffusion.args.p_rendering_loss:
+
+ target = micro
+ pred = self.ddp_rec_model(
+ # latent=vae_out,
+ latent={
+ # **vae_out,
+ self.latent_name: ddpm_ret['pred_x0_p'],
+ 'latent_name': self.latent_name
+ },
+ c=micro['c'],
+ behaviour=self.render_latent_behaviour)
+
+ # vae reconstruction loss
+ with self.ddp_control_model.no_sync(): # type: ignore
+ p_vae_recon_loss, rec_loss_dict = self.loss_class(
+ pred, target, test_mode=False)
+ log_rec3d_loss_dict(rec_loss_dict)
+ # log_rec3d_loss_dict(
+ # dict(p_vae_recon_loss=p_vae_recon_loss, ))
+ loss = p_vae_recon_loss + ddpm_ret[
+ 'p_eps_objective'] # TODO, add obj_weight_t_p?
+ else:
+ loss = ddpm_ret['p_eps_objective'].mean()
+
+ # =====================================================================
+
+ self.mp_cldm_trainer.backward(loss) # joint gradient descent
+ # p self.control_model.input_hint_block[0].bias
+
+ # update ddpm accordingly
+ self.mp_cldm_trainer.optimize(self.opt_cldm)
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ self.log_control_images(vae_out, p_sample_batch, micro, ddpm_ret)
+
+ def run_loop(self):
+ # eval camera
+ camera = th.load('eval_pose.pt', map_location=dist_util.dev())
+
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ # dist_util.synchronize()
+
+ batch = next(self.data)
+ self.run_step(batch, step='ldm_step')
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ if self.step % self.eval_interval == 0 and self.step != 0:
+ # if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ # self.eval_ddpm_sample()
+ # self.eval_cldm(use_ddim=True, unconditional_guidance_scale=7.5, prompt="") # during training, use image as condition
+ if self.cond_key == 'caption':
+ self.eval_cldm(
+ use_ddim=False,
+ prompt="a voxelized dog",
+ use_train_trajectory=False,
+ camera=camera) # fix condition bug first
+ else:
+ pass # TODO
+ # self.eval_cldm(use_ddim=False,
+ # prompt="",
+ # use_train_trajectory=False,
+ # camera=camera) # fix condition bug first
+ # if self.sde_diffusion.args.train_vae:
+ # self.eval_loop()
+
+ th.cuda.empty_cache()
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0:
+ self.save(self.mp_cldm_trainer,
+ self.mp_cldm_trainer.model_name)
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ # if self.sde_diffusion.args.train_vae:
+ # self.save(self.mp_trainer_rec,
+ # self.mp_trainer_rec.model_name)
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ # self.save(self.mp_trainer,
+ # self.mp_trainer.model_name) # rec and ddpm all fixed.
+ # st()
+ # self.save(self.mp_trainer_canonical_cvD, 'cvD')
diff --git a/nsr/lsgm/train_util_diffusion_lsgm.py b/nsr/lsgm/train_util_diffusion_lsgm.py
new file mode 100644
index 0000000000000000000000000000000000000000..661146146b2b67cdaad65fe6f7baefab56f95ab6
--- /dev/null
+++ b/nsr/lsgm/train_util_diffusion_lsgm.py
@@ -0,0 +1,583 @@
+"""
+Modified from:
+https://github.com/NVlabs/LSGM/blob/main/training_obj_joint.py
+"""
+import copy
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+from typing import Any
+
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard.writer import SummaryWriter
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+# from .train_util import TrainLoop3DRec
+from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_loss_dict,
+ log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+from guided_diffusion.gaussian_diffusion import ModelMeanType
+
+import dnnlib
+from dnnlib.util import calculate_adaptive_weight
+
+from ..train_util_diffusion import TrainLoop3DDiffusion
+from ..cvD.nvsD_canoD import TrainLoop3DcvD_nvsD_canoD
+
+
+class TrainLoop3DDiffusionLSGM(TrainLoop3DDiffusion,TrainLoop3DcvD_nvsD_canoD):
+ def __init__(self, *, rec_model, denoise_model, diffusion, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, schedule_sampler=None, weight_decay=0, lr_anneal_steps=0, iterations=10001, ignore_resume_opt=False, freeze_ae=False, denoised_ae=True, triplane_scaling_divider=10, use_amp=False, diffusion_input_size=224, **kwargs):
+ super().__init__(rec_model=rec_model, denoise_model=denoise_model, diffusion=diffusion, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, schedule_sampler=schedule_sampler, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, ignore_resume_opt=ignore_resume_opt, freeze_ae=freeze_ae, denoised_ae=denoised_ae, triplane_scaling_divider=triplane_scaling_divider, use_amp=use_amp, diffusion_input_size=diffusion_input_size, **kwargs)
+
+ def run_step(self, batch, step='g_step'):
+
+ if step == 'diffusion_step_rec':
+ self.forward_diffusion(batch, behaviour='diffusion_step_rec')
+ _ = self.mp_trainer_rec.optimize(self.opt_rec) # TODO, update two groups of parameters
+ took_step_ddpm = self.mp_trainer.optimize(self.opt) # TODO, update two groups of parameters
+
+ if took_step_ddpm:
+ self._update_ema() # g_ema # TODO, ema only needs to track ddpm, remove ema tracking in rec
+
+ elif step == 'd_step_rec':
+ self.forward_D(batch, behaviour='rec')
+ # _ = self.mp_trainer_cvD.optimize(self.opt_cvD)
+ _ = self.mp_trainer_canonical_cvD.optimize(self.opt_cano_cvD)
+
+ elif step == 'diffusion_step_nvs':
+ self.forward_diffusion(batch, behaviour='diffusion_step_nvs')
+ _ = self.mp_trainer_rec.optimize(self.opt_rec) # TODO, update two groups of parameters
+ took_step_ddpm = self.mp_trainer.optimize(self.opt) # TODO, update two groups of parameters
+
+ if took_step_ddpm:
+ self._update_ema() # g_ema
+
+ elif step == 'd_step_nvs':
+ self.forward_D(batch, behaviour='nvs')
+ _ = self.mp_trainer_cvD.optimize(self.opt_cvD)
+
+ self._anneal_lr()
+ self.log_step()
+
+ def run_loop(self):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # batch, cond = next(self.data)
+ # if batch is None:
+ # batch = next(self.data)
+ # self.run_step(batch, 'g_step_rec')
+
+ batch = next(self.data)
+ self.run_step(batch, step='diffusion_step_rec')
+
+ batch = next(self.data)
+ self.run_step(batch, 'd_step_rec')
+
+ # batch = next(self.data)
+ # self.run_step(batch, 'g_step_nvs')
+
+ batch = next(self.data)
+ self.run_step(batch, step='diffusion_step_nvs')
+
+ batch = next(self.data)
+ self.run_step(batch, 'd_step_nvs')
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ # if self.step % self.eval_interval == 0 and self.step != 0:
+ if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ self.eval_loop()
+ # self.eval_novelview_loop()
+ # let all processes sync up before starting with a new epoch of training
+ th.cuda.empty_cache()
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0:
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ self.save(self.mp_trainer_rec, self.mp_trainer_rec.model_name)
+ self.save(self.mp_trainer_cvD, 'cvD')
+ self.save(self.mp_trainer_canonical_cvD, 'cano_cvD')
+
+ dist_util.synchronize()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ self.save(self.mp_trainer_rec, self.mp_trainer_rec.model_name)
+ self.save(self.mp_trainer_cvD, 'cvD')
+ self.save(self.mp_trainer_canonical_cvD, 'cano_cvD')
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+ self.save(self.mp_trainer_canonical_cvD, 'cvD')
+
+ def forward_diffusion(self, batch, behaviour='rec', *args, **kwargs):
+ """
+ add sds grad to all ae predicted x_0
+ """
+
+ self.ddp_cano_cvD.requires_grad_(False)
+ self.ddp_nvs_cvD.requires_grad_(False)
+
+ self.ddp_model.requires_grad_(True)
+ self.ddp_rec_model.requires_grad_(True)
+
+ # if behaviour != 'diff' and 'rec' in behaviour:
+ # if behaviour != 'diff' and 'rec' in behaviour: # pure diffusion step
+ # self.ddp_rec_model.requires_grad_(True)
+ for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters( # type: ignore
+ ): # type: ignore
+ param.requires_grad_(False) # ! disable triplane_decoder grad in each iteration indepenently;
+ # else:
+
+ self.mp_trainer_rec.zero_grad()
+ self.mp_trainer.zero_grad()
+
+ # ! no 'sds' step now, both add sds grad back to ViT
+
+ # assert behaviour != 'sds'
+ # if behaviour == 'sds':
+ # else:
+ # self.ddp_ddpm_model.requires_grad_(True)
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev())
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ vae_nelbo_loss = th.tensor(0.0).to(dist_util.dev())
+ vision_aided_loss = th.tensor(0.0).to(dist_util.dev())
+ denoise_loss = th.tensor(0.0).to(dist_util.dev())
+ d_weight = th.tensor(0.0).to(dist_util.dev())
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp
+ and not self.freeze_ae):
+
+ # apply vae
+ vae_out = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='enc_dec_wo_triplane') # pred: (B, 3, 64, 64)
+
+
+ if behaviour == 'diffusion_step_rec':
+ target = micro
+ pred = self.ddp_rec_model(latent=vae_out,
+ c=micro['c'],
+ behaviour='triplane_dec')
+
+ # vae reconstruction loss
+ if last_batch or not self.use_ddp:
+ vae_nelbo_loss, loss_dict = self.loss_class(pred,
+ target,
+ test_mode=False)
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ vae_nelbo_loss, loss_dict = self.loss_class(
+ pred, target, test_mode=False)
+
+ last_layer = self.ddp_rec_model.module.decoder.triplane_decoder.decoder.net[ # type: ignore
+ -1].weight # type: ignore
+
+ if 'image_sr' in pred:
+ vision_aided_loss = self.ddp_cano_cvD(
+ 0.5 * pred['image_sr'] +
+ 0.5 * th.nn.functional.interpolate(
+ pred['image_raw'],
+ size=pred['image_sr'].shape[2:],
+ mode='bilinear'),
+ for_G=True).mean() # [B, 1] shape
+ else:
+ vision_aided_loss = self.ddp_cano_cvD(
+ pred['image_raw'], for_G=True
+ ).mean(
+ ) # [B, 1] shape
+
+ d_weight = calculate_adaptive_weight(
+ vae_nelbo_loss,
+ vision_aided_loss,
+ last_layer,
+ # disc_weight_max=1) * 1
+ disc_weight_max=1) * self.loss_class.opt.rec_cvD_lambda
+ # d_weight = self.loss_class.opt.rec_cvD_lambda # since decoder is fixed here. set to 0.001
+
+ vision_aided_loss *= d_weight
+
+ # d_weight = self.loss_class.opt.rec_cvD_lambda
+ loss_dict.update({
+ 'vision_aided_loss/G_rec':
+ vision_aided_loss,
+ 'd_weight_G_rec':
+ d_weight,
+ })
+
+ log_rec3d_loss_dict(loss_dict)
+
+ elif behaviour == 'diffusion_step_nvs':
+
+ novel_view_c = th.cat([micro['c'][1:], micro['c'][:1]])
+
+ pred = self.ddp_rec_model(latent=vae_out,
+ c=novel_view_c,
+ behaviour='triplane_dec')
+
+ if 'image_sr' in pred:
+ vision_aided_loss = self.ddp_nvs_cvD(
+ # pred_for_rec['image_sr'],
+ 0.5 * pred['image_sr'] +
+ 0.5 * th.nn.functional.interpolate(
+ pred['image_raw'],
+ size=pred['image_sr'].shape[2:],
+ mode='bilinear'),
+ for_G=True).mean() # [B, 1] shape
+ else:
+ vision_aided_loss = self.ddp_nvs_cvD(
+ pred['image_raw'], for_G=True
+ ).mean(
+ ) # [B, 1] shape
+
+ d_weight = self.loss_class.opt.nvs_cvD_lambda
+ vision_aided_loss *= d_weight
+
+ log_rec3d_loss_dict({
+ 'vision_aided_loss/G_nvs':
+ vision_aided_loss,
+ })
+
+ # ae_loss = th.tensor(0.0).to(dist_util.dev())
+
+ # elif behaviour == 'diff':
+ # self.ddp_rec_model.requires_grad_(False)
+ # # assert self.ddp_rec_model.module.requires_grad == False, 'freeze ddpm_rec for pure diff step'
+ else:
+ raise NotImplementedError(behaviour)
+ # assert behaviour == 'sds'
+
+ # pred = None
+
+ # if behaviour != 'sds': # also train diffusion
+ # assert pred is not None
+
+ # TODO, train diff and sds together, available?
+ eps = vae_out[self.latent_name]
+
+ # if behaviour != 'sds':
+ # micro_to_denoise.detach_()
+ eps.requires_grad_(True) # single stage diffusion
+
+ t, weights = self.schedule_sampler.sample(
+ eps.shape[0], dist_util.dev())
+ noise = th.randn(size=vae_out.size(), device='cuda') # note that this noise value is currently shared!
+
+ model_kwargs = {}
+
+ # ?
+ # or directly use SSD NeRF version?
+ # get diffusion quantities for p (sgm prior) sampling scheme and reweighting for q (vae)
+
+ # ! handle the sampling
+
+ # get diffusion quantities for p (sgm prior) sampling scheme and reweighting for q (vae)
+ t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
+ diffusion.iw_quantities(args.batch_size, args.time_eps, args.iw_sample_p, args.iw_subvp_like_vp_sde)
+ eps_t_p = diffusion.sample_q(vae_out, noise, var_t_p, m_t_p)
+
+ # in case we want to train q (vae) with another batch using a different sampling scheme for times t
+ if args.iw_sample_q in ['ll_uniform', 'll_iw']:
+ t_q, var_t_q, m_t_q, obj_weight_t_q, _, g2_t_q = \
+ diffusion.iw_quantities(args.batch_size, args.time_eps, args.iw_sample_q, args.iw_subvp_like_vp_sde)
+ eps_t_q = diffusion.sample_q(vae_out, noise, var_t_q, m_t_q)
+
+ eps_t_p = eps_t_p.detach().requires_grad_(True)
+ eps_t = th.cat([eps_t_p, eps_t_q], dim=0)
+ var_t = th.cat([var_t_p, var_t_q], dim=0)
+ t = th.cat([t_p, t_q], dim=0)
+ noise = th.cat([noise, noise], dim=0)
+ else:
+ eps_t, m_t, var_t, t, g2_t = eps_t_p, m_t_p, var_t_p, t_p, g2_t_p
+
+ # run the diffusion
+
+ # mixing normal trick
+ # TODO, create a new partial training_losses function
+ mixing_component = diffusion.mixing_component(eps_t, var_t, t, enabled=dae.mixed_prediction) # TODO, which should I use?
+ params = utils.get_mixed_prediction(dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component)
+
+ # nelbo loss with kl balancing
+
+
+
+
+ # ! remainign parts of cross entropy in likelihook training
+
+ cross_entropy_per_var += diffusion.cross_entropy_const(args.time_eps)
+ cross_entropy = th.sum(cross_entropy_per_var, dim=[1, 2, 3])
+ cross_entropy += remaining_neg_log_p_total # for remaining scales if there is any
+ all_neg_log_p = vae.decompose_eps(cross_entropy_per_var)
+ all_neg_log_p.extend(remaining_neg_log_p_per_ver) # add the remaining neg_log_p
+ kl_all_list, kl_vals_per_group, kl_diag_list = utils.kl_per_group_vada(all_log_q, all_neg_log_p)
+
+
+ kl_coeff = 1.0
+
+ # ! calculate p/q loss;
+ # ? no spectral regularizer here
+ # ? try adding grid_clip and sn later on.
+ q_loss = th.mean(nelbo_loss)
+ p_loss = th.mean(p_objective)
+
+ # backpropagate q_loss for vae and update vae params, if trained
+ if args.train_vae:
+ grad_scalar.scale(q_loss).backward(retain_graph=utils.different_p_q_objectives(args.iw_sample_p, args.iw_sample_q))
+ utils.average_gradients(vae.parameters(), args.distributed)
+ if args.grad_clip_max_norm > 0.: # apply gradient clipping
+ grad_scalar.unscale_(vae_optimizer)
+ th.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=args.grad_clip_max_norm)
+ grad_scalar.step(vae_optimizer)
+
+ # if we use different p and q objectives or are not training the vae, discard gradients and backpropagate p_loss
+ if utils.different_p_q_objectives(args.iw_sample_p, args.iw_sample_q) or not args.train_vae:
+ if args.train_vae:
+ # discard current gradients computed by weighted loss for VAE
+ dae_optimizer.zero_grad()
+
+ # compute gradients with unweighted loss
+ grad_scalar.scale(p_loss).backward()
+
+ # update dae parameters
+ utils.average_gradients(dae.parameters(), args.distributed)
+ if args.grad_clip_max_norm > 0.: # apply gradient clipping
+ grad_scalar.unscale_(dae_optimizer)
+ th.nn.utils.clip_grad_norm_(dae.parameters(), max_norm=args.grad_clip_max_norm)
+ grad_scalar.step(dae_optimizer)
+
+
+ # unpack separate objectives, in case we want to train q (vae) using a different sampling scheme for times t
+ if args.iw_sample_q in ['ll_uniform', 'll_iw']:
+ l2_term_p, l2_term_q = th.chunk(l2_term, chunks=2, dim=0)
+ p_objective = th.sum(obj_weight_t_p * l2_term_p, dim=[1, 2, 3])
+ # cross_entropy_per_var = obj_weight_t_q * l2_term_q
+ else:
+ p_objective = th.sum(obj_weight_t_p * l2_term, dim=[1, 2, 3])
+ # cross_entropy_per_var = obj_weight_t_q * l2_term
+
+ # print(micro_to_denoise.min(), micro_to_denoise.max())
+ compute_losses = functools.partial(
+ self.diffusion.training_losses,
+ self.ddp_model,
+ eps, # x_start
+ t,
+ model_kwargs=model_kwargs,
+ return_detail=True)
+
+ # ! DDPM step
+ if last_batch or not self.use_ddp:
+ losses = compute_losses()
+ # denoised_out = denoised_fn()
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ losses = compute_losses()
+
+ if isinstance(self.schedule_sampler, LossAwareSampler):
+ self.schedule_sampler.update_with_local_losses(
+ t, losses["loss"].detach())
+
+ denoise_loss = (losses["loss"] * weights).mean()
+
+ x_t = losses.pop('x_t')
+ model_output = losses.pop('model_output')
+ diffusion_target = losses.pop('diffusion_target')
+ alpha_bar = losses.pop('alpha_bar')
+
+ log_loss_dict(self.diffusion, t,
+ {k: v * weights
+ for k, v in losses.items()})
+
+ # if behaviour == 'sds':
+ # ! calculate sds grad, and add to the grad of
+
+ # if 'rec' in behaviour and self.loss_class.opt.sds_lamdba > 0: # only enable sds along with rec step
+ # w = (
+ # 1 - alpha_bar**2
+ # ) / self.triplane_scaling_divider * self.loss_class.opt.sds_lamdba # https://github.com/ashawkey/stable-dreamfusion/issues/106
+ # sds_grad = denoise_loss.clone().detach(
+ # ) * w # * https://pytorch.org/docs/stable/generated/th.Tensor.detach.html. detach() returned Tensor share the same storage with previous one. add clone() here.
+
+ # # ae_loss = AddGradient.apply(latent[self.latent_name], sds_grad) # add sds_grad during backward
+
+ # def sds_hook(grad_to_add):
+
+ # def modify_grad(grad):
+ # return grad + grad_to_add # add the sds grad to the original grad for BP
+
+ # return modify_grad
+
+ # eps[self.latent_name].register_hook(
+ # sds_hook(sds_grad)) # merge sds grad with rec/nvs ae step
+
+ loss = vae_nelbo_loss + denoise_loss + vision_aided_loss # caluclate loss within AMP
+
+ # ! cvD loss
+
+ # exit AMP before backward
+ self.mp_trainer_rec.backward(loss)
+ self.mp_trainer.backward(loss)
+
+ # TODO, merge visualization with original AE
+ # =================================== denoised AE log part ===================================
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0 and behaviour != 'diff':
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ # st()
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ # if 'image_sr' in pred: # TODO
+ # pred_img = th.cat(
+ # [self.pool_512(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat(
+ # [self.pool_512(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # pred_depth = self.pool_512(pred_depth)
+ # gt_depth = self.pool_512(gt_depth)
+
+ gt_vis = th.cat(
+ [
+ gt_img, micro['img'], micro['img'],
+ gt_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1)[0:1] # TODO, fail to load depth. range [0, 1]
+
+ noised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=x_t[0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically
+ behaviour=self.render_latent_behaviour)
+
+ # if denoised_out is None:
+ # if not self.denoised_ae:
+ # denoised_out = denoised_fn()
+
+ if self.diffusion.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = model_output
+ else: # * used here
+ pred_xstart = self.diffusion._predict_xstart_from_eps(
+ x_t=x_t, t=t, eps=model_output)
+
+ denoised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=pred_xstart[0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically?
+ behaviour=self.render_latent_behaviour)
+
+ # denoised_out = denoised_ae_pred
+
+ # if not self.denoised_ae:
+ # denoised_ae_pred = self.ddp_rec_model(
+ # img=None,
+ # c=micro['c'][0:1],
+ # latent=denoised_out['pred_xstart'][0:1] * self.
+ # triplane_scaling_divider, # TODO, how to define the scale automatically
+ # behaviour=self.render_latent_behaviour)
+ # else:
+ # assert denoised_ae_pred is not None
+ # denoised_ae_pred['image_raw'] = denoised_ae_pred[
+ # 'image_raw'][0:1]
+
+ # print(pred_img.shape)
+ # print('denoised_ae:', self.denoised_ae)
+
+ pred_vis = th.cat([
+ pred_img[0:1], noised_ae_pred['image_raw'][0:1],
+ denoised_ae_pred['image_raw'][0:1],
+ pred_depth[0:1].repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+ # s
+
+ vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis = th.cat([
+ # self.pool_128(micro['img']), x_t[:, :3, ...],
+ # denoised_out['pred_xstart'][:, :3, ...]
+ # ],
+ # dim=-1)[0].permute(
+ # 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
+ )
+ print(
+ 'log denoised vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
+ )
+
+ th.cuda.empty_cache()
diff --git a/nsr/lsgm/train_util_diffusion_lsgm_noD_joint.py b/nsr/lsgm/train_util_diffusion_lsgm_noD_joint.py
new file mode 100644
index 0000000000000000000000000000000000000000..a48a8b23142003d77b4c788f3ac47d3047b6081a
--- /dev/null
+++ b/nsr/lsgm/train_util_diffusion_lsgm_noD_joint.py
@@ -0,0 +1,1366 @@
+"""
+Modified from:
+https://github.com/NVlabs/LSGM/blob/main/training_obj_joint.py
+"""
+import copy
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+from typing import Any
+
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard.writer import SummaryWriter
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+# from .train_util import TrainLoop3DRec
+from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_loss_dict,
+ log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+from guided_diffusion.gaussian_diffusion import ModelMeanType
+
+from dnnlib.util import requires_grad
+from dnnlib.util import calculate_adaptive_weight
+
+from ..train_util_diffusion import TrainLoop3DDiffusion, TrainLoopDiffusionWithRec
+from ..cvD.nvsD_canoD import TrainLoop3DcvD_nvsD_canoD
+
+from guided_diffusion.continuous_diffusion_utils import get_mixed_prediction, different_p_q_objectives, kl_per_group_vada, kl_balancer
+# import utils as lsgm_utils
+
+
+class JointDenoiseRecModel(th.nn.Module):
+
+ def __init__(self, ddpm_model, rec_model, diffusion_input_size) -> None:
+ super().__init__()
+ # del ddpm_model
+ # th.cuda.empty_cache()
+ # self.ddpm_model = th.nn.Identity()
+ self.ddpm_model = ddpm_model
+ self.rec_model = rec_model
+
+ self._setup_latent_stat(diffusion_input_size)
+
+ def _setup_latent_stat(self, diffusion_input_size): # for dynamic EMA tracking.
+ latent_size = (
+ 1,
+ self.ddpm_model.in_channels, # type: ignore
+ diffusion_input_size,
+ diffusion_input_size),
+
+ self.ddpm_model.register_buffer(
+ 'ema_latent_std',
+ th.ones(*latent_size).to(dist_util.dev()), persistent=True)
+ self.ddpm_model.register_buffer(
+ 'ema_latent_mean',
+ th.zeros(*latent_size).to(dist_util.dev()), persistent=True)
+
+ # TODO, lint api.
+ def forward(
+ self,
+ *args,
+ model_name='ddpm',
+ **kwargs,
+ ):
+ if model_name == 'ddpm':
+ return self.ddpm_model(*args, **kwargs)
+ elif model_name == 'rec':
+ return self.rec_model(*args, **kwargs)
+ else:
+ raise NotImplementedError(model_name)
+
+
+# TODO, merge with train_util_diffusion.py later
+class SDETrainLoopJoint(TrainLoopDiffusionWithRec):
+ """A dataclass with some required attribtues; copied from guided_diffusion TrainLoop
+ """
+
+ def __init__(
+ self,
+ rec_model,
+ denoise_model,
+ diffusion, # not used
+ sde_diffusion,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ triplane_scaling_divider=1,
+ use_amp=False,
+ diffusion_input_size=224,
+ **kwargs,
+ ) -> None:
+
+ joint_model = JointDenoiseRecModel(denoise_model, rec_model, diffusion_input_size)
+ super().__init__(
+ model=joint_model,
+ diffusion=diffusion, # just for sampling
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ eval_interval=eval_interval,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ use_amp=use_amp,
+ model_name='joint_denoise_rec_model',
+ iterations=iterations,
+ triplane_scaling_divider=triplane_scaling_divider,
+ diffusion_input_size=diffusion_input_size,
+ **kwargs)
+ self.sde_diffusion = sde_diffusion
+ # setup latent scaling factor
+
+ # ! integrate the init_params_group for rec model
+ def _setup_model(self):
+
+ super()._setup_model()
+ self.ddp_rec_model = functools.partial(self.model, model_name='rec')
+ self.ddp_ddpm_model = functools.partial(self.model, model_name='ddpm')
+
+ self.rec_model = self.ddp_model.module.rec_model
+ self.ddpm_model = self.ddp_model.module.ddpm_model # compatability
+
+ # TODO, required?
+ # for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters( # type: ignore
+ # ): # type: ignore
+ # param.requires_grad_(
+ # False
+ # ) # ! disable triplane_decoder grad in each iteration indepenently;
+
+ def _load_model(self):
+ # TODO, for currently compatability
+ if 'joint' in self.resume_checkpoint: # load joint directly
+ self._load_and_sync_parameters(model=self.model, model_name=self.model_name)
+ else: # from scratch
+ self._load_and_sync_parameters(model=self.rec_model, model_name='rec')
+ self._load_and_sync_parameters(model=self.ddpm_model,
+ model_name='ddpm')
+
+ def _setup_opt(self):
+ # TODO, two optims groups.
+ self.opt = AdamW([{
+ 'name': 'ddpm',
+ 'params': self.ddpm_model.parameters(),
+ }],
+ lr=self.lr,
+ weight_decay=self.weight_decay)
+
+ # for rec_param_group in self._init_optim_groups(self.rec_model):
+ # self.opt.add_param_group(rec_param_group)
+ print(self.opt)
+
+
+class TrainLoop3DDiffusionLSGMJointnoD(SDETrainLoopJoint):
+
+ def __init__(self,
+ *,
+ rec_model,
+ denoise_model,
+ sde_diffusion,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ resume_cldm_checkpoint=None,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ triplane_scaling_divider=1,
+ use_amp=False,
+ diffusion_input_size=224,
+ diffusion_ce_anneal=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ denoise_model=denoise_model,
+ sde_diffusion=sde_diffusion,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ triplane_scaling_divider=triplane_scaling_divider,
+ use_amp=use_amp,
+ diffusion_input_size=diffusion_input_size,
+ **kwargs)
+
+ sde_diffusion.args.batch_size = batch_size
+ self.latent_name = 'latent_normalized_2Ddiffusion' # normalized triplane latent
+ self.render_latent_behaviour = 'decode_after_vae' # directly render using triplane operations
+ self.diffusion_ce_anneal = diffusion_ce_anneal
+ # assert sde_diffusion.args.train_vae
+
+ def prepare_ddpm(self, eps, mode='p'):
+
+ log_rec3d_loss_dict({
+ f'eps_mean': eps.mean(),
+ f'eps_std': eps.std([1,2,3]).mean(0),
+ f'eps_max': eps.max()
+ })
+
+ args = self.sde_diffusion.args
+ # sample noise
+ noise = th.randn(size=eps.size(), device=eps.device
+ ) # note that this noise value is currently shared!
+
+ # get diffusion quantities for p (sgm prior) sampling scheme and reweighting for q (vae)
+ if mode == 'p':
+ t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
+ self.sde_diffusion.iw_quantities(args.iw_sample_p, noise.shape[0]) # TODO, q not used, fall back to original ddpm implementation
+ else:
+ assert mode == 'q'
+ # assert args.iw_sample_q in ['ll_uniform', 'll_iw']
+ t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
+ self.sde_diffusion.iw_quantities(args.iw_sample_q, noise.shape[0]) # TODO, q not used, fall back to original ddpm implementation
+ eps_t_p = self.sde_diffusion.sample_q(eps, noise, var_t_p, m_t_p)
+ # ! important
+ # eps_t_p = eps_t_p.detach().requires_grad_(True)
+ # logsnr_p = self.sde_diffusion.log_snr(m_t_p,
+ # var_t_p) # for p only
+ logsnr_p = self.sde_diffusion.log_snr(m_t_p, var_t_p) # for p only
+
+ return {
+ 'noise': noise,
+ 't_p': t_p,
+ 'eps_t_p': eps_t_p,
+ 'logsnr_p': logsnr_p,
+ 'obj_weight_t_p': obj_weight_t_p,
+ 'var_t_p': var_t_p,
+ 'm_t_p': m_t_p,
+ 'eps': eps,
+ 'mode': mode
+ }
+
+ # merged from noD.py
+
+ def ce_weight(self):
+ return self.loss_class.opt.ce_lambda
+
+ def apply_model(self, p_sample_batch, **model_kwargs):
+ args = self.sde_diffusion.args
+ # args = self.sde_diffusion.args
+ noise, eps_t_p, t_p, logsnr_p, obj_weight_t_p, var_t_p, m_t_p = (
+ p_sample_batch[k] for k in ('noise', 'eps_t_p', 't_p', 'logsnr_p',
+ 'obj_weight_t_p', 'var_t_p', 'm_t_p'))
+
+ pred_eps_p, pred_x0_p = self.ddpm_step(eps_t_p, t_p, logsnr_p, var_t_p, m_t_p,
+ **model_kwargs)
+
+ # ! eps loss equivalent to snr weighting of x0 loss, see "progressive distillation"
+ with self.ddp_model.no_sync(): # type: ignore
+ if args.loss_type == 'eps':
+ l2_term_p = th.square(pred_eps_p - noise) # ? weights
+ elif args.loss_type == 'x0':
+ # l2_term_p = th.square(pred_eps_p - p_sample_batch['eps']) # ? weights
+ l2_term_p = th.square(
+ pred_x0_p - p_sample_batch['eps'].detach()) # ? weights
+ # if args.loss_weight == 'snr':
+ # obj_weight_t_p = th.sigmoid(th.exp(logsnr_p))
+ else:
+ raise NotImplementedError(args.loss_type)
+
+ # p_eps_objective = th.mean(obj_weight_t_p * l2_term_p)
+ p_eps_objective = obj_weight_t_p * l2_term_p
+
+ if p_sample_batch['mode'] == 'q':
+ ce_weight = self.ce_weight()
+ p_eps_objective = p_eps_objective * ce_weight
+
+ log_rec3d_loss_dict({
+ 'ce_weight': ce_weight,
+ })
+
+
+ log_rec3d_loss_dict({
+ f"{p_sample_batch['mode']}_loss":
+ p_eps_objective.mean(),
+ 'mixing_logit':
+ self.ddp_ddpm_model(x=None,
+ timesteps=None,
+ get_attr='mixing_logit').detach(),
+ })
+
+ return {
+ 'pred_eps_p': pred_eps_p,
+ 'eps_t_p': eps_t_p,
+ 'p_eps_objective': p_eps_objective,
+ 'pred_x0_p': pred_x0_p,
+ 'logsnr_p': logsnr_p
+ }
+
+ def ddpm_step(self, eps_t, t, logsnr, var_t, m_t, **model_kwargs):
+ """helper function for ddpm predictions; returns predicted eps, x0 and logsnr.
+
+ args notes:
+ eps_t is x_noisy
+ """
+ args = self.sde_diffusion.args
+ pred_params = self.ddp_ddpm_model(x=eps_t, timesteps=t, **model_kwargs)
+ # logsnr = self.sde_diffusion.log_snr(m_t, var_t) # for p only
+ if args.pred_type in ['eps', 'v']:
+ if args.pred_type == 'v':
+ pred_eps = self.sde_diffusion._predict_eps_from_z_and_v(
+ pred_params, var_t, eps_t, m_t
+ )
+ # pred_x0 = self.sde_diffusion._predict_x0_from_z_and_v(
+ # pred_params, var_t, eps_t, m_t) # ! verified
+ else:
+ pred_eps = pred_params
+
+ # mixing normal trick
+ mixing_component = self.sde_diffusion.mixing_component(
+ eps_t, var_t, t, enabled=True) # z_t * sigma_t
+ pred_eps = get_mixed_prediction(
+ True, pred_eps,
+ self.ddp_ddpm_model(x=None,
+ timesteps=None,
+ get_attr='mixing_logit'), mixing_component)
+
+ pred_x0 = self.sde_diffusion._predict_x0_from_eps( eps_t, pred_eps, logsnr) # for VAE loss, denosied latent
+ # eps, pred_params, logsnr) # for VAE loss, denosied latent
+ elif args.pred_type == 'x0':
+ # ! pred_x0_mixed = alpha * pred_x0 + (1-alpha) * z_t * alpha_t
+ pred_x0 = pred_params # how to mix?
+
+ # mixing normal trick
+ mixing_component = self.sde_diffusion.mixing_component_x0(
+ eps_t, var_t, t, enabled=True) # z_t * alpha_t
+ pred_x0 = get_mixed_prediction(
+ True, pred_x0,
+ self.ddp_ddpm_model(x=None,
+ timesteps=None,
+ get_attr='mixing_logit'), mixing_component)
+
+ pred_eps = self.sde_diffusion._predict_eps_from_x0(
+ eps_t, pred_x0, logsnr)
+ else:
+ raise NotImplementedError(f'{args.pred_type} not implemented.')
+
+ log_rec3d_loss_dict({
+ f'pred_x0_mean': pred_x0.mean(),
+ f'pred_x0_std': pred_x0.std([1,2,3]).mean(0),
+ f'pred_x0_max': pred_x0.max(),
+ })
+
+ return pred_eps, pred_x0
+
+ def ddpm_loss(self, noise, pred_eps, last_batch):
+
+ # ! eps loss equivalent to snr weighting of x0 loss, see "progressive distillation"
+ if last_batch or not self.use_ddp:
+ l2_term = th.square(pred_eps - noise)
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ l2_term = th.square(pred_eps - noise) # ? weights
+ return l2_term
+
+ def run_step(self, batch, step='diffusion_step_rec'):
+
+ if step == 'ce_ddpm_step':
+ self.ce_ddpm_step(batch)
+ elif step == 'p_rendering_step':
+ self.p_rendering_step(batch)
+
+ elif step == 'eps_step':
+ self.eps_step(batch)
+
+ # ! both took ddpm step
+ self._update_ema()
+
+ self._anneal_lr()
+ self.log_step()
+
+ @th.inference_mode()
+ def _post_run_loop(self):
+
+ # if self.step % self.eval_interval =r 0 and self.step != 0:
+ # if self.step % self.eval_interval == 0:
+ # if dist_util.get_rank() == 0:
+ # self.eval_ddpm_sample(
+ # self.rec_model,
+ # # self.ddpm_model
+ # ) # ! only support single GPU inference now.
+ # if self.sde_diffusion.args.train_vae:
+ # self.eval_loop(self.ddp_rec_model)
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank() == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ # if self.step % self.eval_interval == 0 and self.step != 0:
+ if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ self.eval_ddpm_sample(self.ddp_rec_model)
+ if self.sde_diffusion.args.train_vae:
+ self.eval_loop(self.ddp_rec_model)
+
+ if self.step % self.save_interval == 0:
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ exit()
+
+ def run_loop(self):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ batch = next(self.data)
+ self.run_step(batch, step='ce_ddpm_step')
+
+ self._post_run_loop()
+
+ # batch = next(self.data)
+ # self.run_step(batch, step='p_rendering_step')
+
+ def ce_ddpm_step(self, batch, behaviour='rec', *args, **kwargs):
+ """
+ add sds grad to all ae predicted x_0
+ """
+ args = self.sde_diffusion.args
+ assert args.train_vae
+
+ requires_grad(self.rec_model, args.train_vae)
+ requires_grad(self.ddpm_model, True)
+
+ # TODO merge?
+ self.mp_trainer.zero_grad()
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k:
+ v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ q_vae_recon_loss = th.tensor(0.0).to(dist_util.dev())
+ # vision_aided_loss = th.tensor(0.0).to(dist_util.dev())
+ # denoise_loss = th.tensor(0.0).to(dist_util.dev())
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp):
+
+ # ! part 1: train vae with CE; ddpm fixed
+ # ! TODO, add KL_all_list? vae.decompose
+ with th.set_grad_enabled(args.train_vae):
+ # vae_out = self.ddp_rec_model(
+ # img=micro['img_to_encoder'],
+ # c=micro['c'],
+ # behaviour='encoder_vae',
+ # ) # pred: (B, 3, 64, 64)
+ # TODO, no need to render if not SSD; no need to do ViT decoder if only the latent is needed. update later
+ # if args.train_vae:
+ # if args.add_rendering_loss:
+ # if args.joint_train:
+ # with th.set_grad_enabled(args.train_vae):
+ pred = self.ddp_rec_model(
+ # latent=vae_out,
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ )
+ # behaviour=self.render_latent_behaviour)
+
+ # vae reconstruction loss
+ if last_batch or not self.use_ddp:
+ q_vae_recon_loss, loss_dict = self.loss_class(
+ pred, micro, test_mode=False)
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ q_vae_recon_loss, loss_dict = self.loss_class(
+ pred, micro, test_mode=False)
+
+ log_rec3d_loss_dict(loss_dict)
+ # '''
+
+ # ! calculate p/q loss;
+ # nelbo_loss = balanced_kl * self.loss_class.opt.ce_balanced_kl + q_vae_recon_loss
+ nelbo_loss = q_vae_recon_loss
+ q_loss = th.mean(nelbo_loss)
+
+ # st()
+
+ # all_log_q = [vae_out['log_q_2Ddiffusion']]
+ # eps = vae_out[self.latent_name]
+ # all_log_q = [pred['log_q_2Ddiffusion']]
+ eps = pred[self.latent_name]
+
+ if not args.train_vae:
+ eps.requires_grad_(True) # single stage diffusion
+
+ # sample noise
+ noise = th.randn(
+ size=eps.size(), device=eps.device
+ ) # note that this noise value is currently shared!
+
+ # in case we want to train q (vae) with another batch using a different sampling scheme for times t
+ '''
+ assert args.iw_sample_q in ['ll_uniform', 'll_iw']
+ t_q, var_t_q, m_t_q, obj_weight_t_q, _, g2_t_q = \
+ self.sde_diffusion.iw_quantities(args.iw_sample_q)
+ eps_t_q = self.sde_diffusion.sample_q(eps, noise, var_t_q,
+ m_t_q)
+
+ # eps_t = th.cat([eps_t_p, eps_t_q], dim=0)
+ # var_t = th.cat([var_t_p, var_t_q], dim=0)
+ # t = th.cat([t_p, t_q], dim=0)
+ # noise = th.cat([noise, noise], dim=0)
+
+ # run the diffusion model
+ if not args.train_vae:
+ eps_t_q.requires_grad_(True) # 2*BS, 12, 16, 16
+
+ # ! For CE guidance.
+ requires_grad(self.ddpm_model_module, False)
+ pred_eps_q, _, _ = self.ddpm_step(eps_t_q, t_q, m_t_q, var_t_q)
+
+ l2_term_q = self.ddpm_loss(noise, pred_eps_q, last_batch)
+
+ # pred_eps = th.cat([pred_eps_p, pred_eps_q], dim=0) # p then q
+
+ # ÇE: nelbo loss with kl balancing
+ assert args.iw_sample_q in ['ll_uniform', 'll_iw']
+ # l2_term_p, l2_term_q = th.chunk(l2_term, chunks=2, dim=0)
+ cross_entropy_per_var = obj_weight_t_q * l2_term_q
+
+ cross_entropy_per_var += self.sde_diffusion.cross_entropy_const(
+ args.sde_time_eps)
+ all_neg_log_p = [cross_entropy_per_var
+ ] # since only one vae group
+
+ kl_all_list, kl_vals_per_group, kl_diag_list = kl_per_group_vada(
+ all_log_q, all_neg_log_p) # return the mean of two terms
+
+ # nelbo loss with kl balancing
+ balanced_kl, kl_coeffs, kl_vals = kl_balancer(kl_all_list,
+ kl_coeff=1.0,
+ kl_balance=False,
+ alpha_i=None)
+ # st()
+
+ log_rec3d_loss_dict(
+ dict(
+ balanced_kl=balanced_kl,
+ l2_term_q=l2_term_q,
+ cross_entropy_per_var=cross_entropy_per_var.mean(),
+ all_log_q=all_log_q[0].mean(),
+ ))
+
+
+ '''
+ # ! update vae for CE
+ # ! single stage diffusion for rec side 1: bind vae prior and diffusion prior
+
+ # ! BP for CE and VAE; quit the AMP context.
+ # if args.train_vae:
+ # self.mp_trainer.backward(q_loss)
+ # _ = self.mp_trainer.optimize(self.opt)
+ # retain_graph=different_p_q_objectives(
+ # args.iw_sample_p,
+ # args.iw_sample_q))
+
+ log_rec3d_loss_dict(
+ dict(q_vae_recon_loss=q_vae_recon_loss,
+ # all_log_q=all_log_q[0].mean(),
+ ))
+
+ # ! adding p loss; enable ddpm gradient
+ # self.mp_trainer.zero_grad()
+ # requires_grad(self.rec_model_module,
+ # False) # could be removed since eps_t_p.detach()
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp):
+
+ # first get diffusion quantities for p (sgm prior) sampling scheme and reweighting for q (vae)
+ t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
+ self.sde_diffusion.iw_quantities(args.iw_sample_p)
+ eps_t_p = self.sde_diffusion.sample_q(eps, noise, var_t_p,
+ m_t_p)
+ eps_t_p = eps_t_p.detach(
+ ) # .requires_grad_(True) # ! update ddpm not rec module
+
+ pred_eps_p, _, = self.ddpm_step(eps_t_p, t_p, m_t_p, var_t_p)
+ l2_term_p = self.ddpm_loss(noise, pred_eps_p, last_batch)
+ p_loss = th.mean(obj_weight_t_p * l2_term_p)
+
+ # ! update ddpm
+ self.mp_trainer.backward(p_loss +
+ q_loss) # just backward for p_loss
+ _ = self.mp_trainer.optimize(self.opt)
+ # requires_grad(self.rec_model_module, True)
+
+ log_rec3d_loss_dict(
+ dict(
+ p_loss=p_loss,
+ mixing_logit=self.ddp_ddpm_model(
+ x=None, timesteps=None,
+ get_attr='mixing_logit').detach(),
+ ))
+
+ # TODO, merge visualization with original AE
+ # =================================== denoised AE log part ===================================
+
+ # ! todo, wrap in a single function
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+
+ with th.no_grad():
+
+ if not args.train_vae:
+ vae_out.pop('posterior') # for calculating kl loss
+ vae_out_for_pred = {
+ k:
+ v[0:1].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v
+ for k, v in vae_out.items()
+ }
+
+ pred = self.ddp_rec_model(
+ latent=vae_out_for_pred,
+ c=micro['c'][0:1],
+ behaviour=self.render_latent_behaviour)
+ assert isinstance(pred, dict)
+ assert pred is not None
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+
+ if 'image_depth' in pred:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ if 'image_sr' in pred:
+ if pred['image_sr'].shape[-1] == 512:
+ pred_img = th.cat(
+ [self.pool_512(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ elif pred['image_sr'].shape[-1] == 256:
+ pred_img = th.cat(
+ [self.pool_256(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_256(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_256(pred_depth)
+ gt_depth = self.pool_256(gt_depth)
+
+ else:
+ pred_img = th.cat(
+ [self.pool_128(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_128(micro['img']), micro['img_sr']],
+ dim=-1)
+ gt_depth = self.pool_128(gt_depth)
+ pred_depth = self.pool_128(pred_depth)
+ else:
+ gt_img = self.pool_64(gt_img)
+ gt_depth = self.pool_64(gt_depth)
+
+ gt_vis = th.cat(
+ [
+ gt_img,
+ # micro['img'],
+ gt_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1)[0:1] # TODO, fail to load depth. range [0, 1]
+
+ pred_vis = th.cat([
+ pred_img[0:1], pred_depth[0:1].repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ # f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
+ f'{logger.get_dir()}/{self.step+self.resume_step}_{behaviour}.jpg'
+ )
+ print(
+ 'log denoised vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}_{behaviour}.jpg'
+ )
+
+ th.cuda.empty_cache()
+
+ def eps_step(self, batch, behaviour='rec', *args, **kwargs):
+ """
+ add sds grad to all ae predicted x_0
+ """
+ args = self.sde_diffusion.args
+
+ requires_grad(self.ddpm_model_module, True)
+ requires_grad(self.rec_model_module, False)
+
+ # TODO?
+ # if args.train_vae:
+ # for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters( # type: ignore
+ # ): # type: ignore
+ # param.requires_grad_(
+ # False
+ # ) # ! disable triplane_decoder grad in each iteration indepenently;
+
+ self.mp_trainer.zero_grad()
+
+ # assert args.train_vae
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k:
+ v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp):
+ # and args.train_vae):
+
+ # ! part 1: train vae with CE; ddpm fixed
+ # ! TODO, add KL_all_list? vae.decompose
+
+ with th.set_grad_enabled(args.train_vae):
+ vae_out = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='encoder_vae',
+ ) # pred: (B, 3, 64, 64)
+ eps = vae_out[self.latent_name]
+
+ # sample noise
+ noise = th.randn(
+ size=eps.size(), device=eps.device
+ ) # note that this noise value is currently shared!
+
+ # get diffusion quantities for p (sgm prior) sampling scheme and reweighting for q (vae)
+ t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
+ self.sde_diffusion.iw_quantities(args.iw_sample_p)
+ eps_t_p = self.sde_diffusion.sample_q(eps, noise, var_t_p,
+ m_t_p)
+ logsnr_p = self.sde_diffusion.log_snr(m_t_p,
+ var_t_p) # for p only
+
+ pred_eps_p, pred_x0_p, logsnr_p = self.ddpm_step(
+ eps_t_p, t_p, m_t_p, var_t_p)
+
+ # ! batchify for mixing_component
+ # mixing normal trick
+ mixing_component = self.sde_diffusion.mixing_component(
+ eps_t_p, var_t_p, t_p,
+ enabled=True) # TODO, which should I use?
+ pred_eps_p = get_mixed_prediction(
+ True, pred_eps_p,
+ self.ddp_ddpm_model(x=None,
+ timesteps=None,
+ get_attr='mixing_logit'),
+ mixing_component)
+
+ # ! eps loss equivalent to snr weighting of x0 loss, see "progressive distillation"
+ if last_batch or not self.use_ddp:
+ l2_term_p = th.square(pred_eps_p - noise)
+ else:
+ with self.ddp_ddpm_model.no_sync(): # type: ignore
+ l2_term_p = th.square(pred_eps_p - noise) # ? weights
+
+ p_eps_objective = th.mean(
+ obj_weight_t_p *
+ l2_term_p) * self.loss_class.opt.p_eps_lambda
+
+ log_rec3d_loss_dict(
+ dict(mixing_logit=self.ddp_ddpm_model(
+ x=None, timesteps=None,
+ get_attr='mixing_logit').detach(), ))
+
+ # =====================================================================
+ # ! single stage diffusion for rec side 2: generative feature
+ # if args.p_rendering_loss:
+ # target = micro
+ # pred = self.ddp_rec_model(
+ # # latent=vae_out,
+ # latent={
+ # **vae_out, self.latent_name: pred_x0_p,
+ # 'latent_name': self.latent_name
+ # },
+ # c=micro['c'],
+ # behaviour=self.render_latent_behaviour)
+
+ # # vae reconstruction loss
+ # if last_batch or not self.use_ddp:
+ # p_vae_recon_loss, _ = self.loss_class(pred,
+ # target,
+ # test_mode=False)
+ # else:
+ # with self.ddp_model.no_sync(): # type: ignore
+ # p_vae_recon_loss, _ = self.loss_class(
+ # pred, target, test_mode=False)
+ # log_rec3d_loss_dict(
+ # dict(p_vae_recon_loss=p_vae_recon_loss, ))
+ # p_loss = p_eps_objective + p_vae_recon_loss
+ # else:
+ p_loss = p_eps_objective
+
+ log_rec3d_loss_dict(
+ dict(p_loss=p_loss, p_eps_objective=p_eps_objective))
+
+ # ! to arrange: update vae params
+
+ self.mp_trainer.backward(p_loss)
+
+ # update ddpm accordingly
+ _ = self.mp_trainer.optimize(
+ self.opt) # TODO, update two groups of parameters
+
+ # TODO, merge visualization with original AE
+ # ! todo, merge required
+ # =================================== denoised AE log part ===================================
+ if dist_util.get_rank(
+ ) == 0 and self.step % 500 == 0 and behaviour != 'diff':
+
+ with th.no_grad():
+
+ vae_out.pop('posterior') # for calculating kl loss
+ vae_out_for_pred = {
+ k:
+ v[0:1].to(dist_util.dev())
+ if isinstance(v, th.Tensor) else v
+ for k, v in vae_out.items()
+ }
+
+ pred = self.ddp_rec_model(
+ latent=vae_out_for_pred,
+ c=micro['c'][0:1],
+ behaviour=self.render_latent_behaviour)
+ assert isinstance(pred, dict)
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+
+ if 'image_depth' in pred:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ if 'image_sr' in pred:
+ if pred['image_sr'].shape[-1] == 512:
+ pred_img = th.cat(
+ [self.pool_512(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ elif pred['image_sr'].shape[-1] == 256:
+ pred_img = th.cat(
+ [self.pool_256(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_256(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_256(pred_depth)
+ gt_depth = self.pool_256(gt_depth)
+
+ else:
+ pred_img = th.cat(
+ [self.pool_128(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_128(micro['img']), micro['img_sr']],
+ dim=-1)
+ gt_depth = self.pool_128(gt_depth)
+ pred_depth = self.pool_128(pred_depth)
+ else:
+ gt_img = self.pool_64(gt_img)
+ gt_depth = self.pool_64(gt_depth)
+
+ gt_vis = th.cat(
+ [
+ gt_img, micro['img'], micro['img'],
+ gt_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1)[0:1] # TODO, fail to load depth. range [0, 1]
+
+ # eps_t_p_3D = eps_t_p.reshape(batch_size, eps_t_p.shape[1]//3, 3, -1) # B C 3 L
+
+ noised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=eps_t_p[0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically
+ behaviour=self.render_latent_behaviour)
+
+ pred_x0 = self.sde_diffusion._predict_x0_from_eps(
+ eps_t_p, pred_eps_p,
+ logsnr_p) # for VAE loss, denosied latent
+
+ # pred_xstart_3D
+ denoised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=pred_x0[0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically?
+ behaviour=self.render_latent_behaviour)
+
+ pred_vis = th.cat([
+ pred_img[0:1], noised_ae_pred['image_raw'][0:1],
+ denoised_ae_pred['image_raw'][0:1],
+ pred_depth[0:1].repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}_{behaviour}.jpg'
+ )
+ print(
+ 'log denoised vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}_{behaviour}.jpg'
+ )
+ del vis, pred_vis, pred_x0, pred_eps_p, micro, vae_out
+
+ th.cuda.empty_cache()
+
+ def p_rendering_step(self, batch, behaviour='rec', *args, **kwargs):
+ """
+ add sds grad to all ae predicted x_0
+ """
+ args = self.sde_diffusion.args
+
+ requires_grad(self.ddpm_model, True)
+ requires_grad(self.rec_model, args.train_vae)
+
+ # TODO?
+ # if args.train_vae:
+ # for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters( # type: ignore
+ # ): # type: ignore
+ # param.requires_grad_(
+ # False
+ # ) # ! disable triplane_decoder grad in each iteration indepenently;
+
+ self.mp_trainer.zero_grad()
+
+ assert args.train_vae
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k:
+ v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp):
+ # and args.train_vae):
+
+ # ! part 1: train vae with CE; ddpm fixed
+ # ! TODO, add KL_all_list? vae.decompose
+
+ with th.set_grad_enabled(args.train_vae):
+ vae_out = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='encoder_vae',
+ ) # pred: (B, 3, 64, 64)
+ eps = vae_out[self.latent_name]
+
+ # sample noise
+ noise = th.randn(
+ size=eps.size(), device=eps.device
+ ) # note that this noise value is currently shared!
+
+ # get diffusion quantities for p (sgm prior) sampling scheme and reweighting for q (vae)
+ t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
+ self.sde_diffusion.iw_quantities(args.iw_sample_p)
+ eps_t_p = self.sde_diffusion.sample_q(eps, noise, var_t_p,
+ m_t_p)
+ logsnr_p = self.sde_diffusion.log_snr(m_t_p,
+ var_t_p) # for p only
+
+ # pred_eps_p, pred_x0_p, logsnr_p = self.ddpm_step(
+ pred_eps_p, pred_x0_p = self.ddpm_step(eps_t_p, t_p, logsnr_p,
+ var_t_p)
+ # eps_t_p, t_p, m_t_p, var_t_p)
+
+ # ! batchify for mixing_component
+ # mixing normal trick
+ # mixing_component = self.sde_diffusion.mixing_component(
+ # eps_t_p, var_t_p, t_p,
+ # enabled=True) # TODO, which should I use?
+ # pred_eps_p = get_mixed_prediction(
+ # True, pred_eps_p,
+ # self.ddp_ddpm_model(x=None,
+ # timesteps=None,
+ # get_attr='mixing_logit'),
+ # mixing_component)
+
+ # ! eps loss equivalent to snr weighting of x0 loss, see "progressive distillation"
+ if last_batch or not self.use_ddp:
+ l2_term_p = th.square(pred_eps_p - noise)
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ l2_term_p = th.square(pred_eps_p - noise) # ? weights
+
+ p_eps_objective = th.mean(obj_weight_t_p * l2_term_p)
+ # st()
+
+ log_rec3d_loss_dict(
+ dict(mixing_logit=self.ddp_ddpm_model(
+ x=None, timesteps=None,
+ get_attr='mixing_logit').detach(), ))
+
+ # =====================================================================
+ # ! single stage diffusion for rec side 2: generative feature
+ if args.p_rendering_loss:
+ target = micro
+ pred = self.ddp_rec_model( # re-render
+ latent={
+ **vae_out, self.latent_name: pred_x0_p,
+ 'latent_name': self.latent_name
+ },
+ c=micro['c'],
+ behaviour=self.render_latent_behaviour)
+
+ # vae reconstruction loss
+ if last_batch or not self.use_ddp:
+ pred[self.latent_name] = vae_out[self.latent_name]
+ pred[
+ 'latent_name'] = self.latent_name # just for stats
+ p_vae_recon_loss, rec_loss_dict = self.loss_class(
+ pred, target, test_mode=False)
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ p_vae_recon_loss, rec_loss_dict = self.loss_class(
+ pred, target, test_mode=False)
+ log_rec3d_loss_dict(
+ dict(p_vae_recon_loss=p_vae_recon_loss, ))
+
+ for key in rec_loss_dict.keys():
+ if 'latent' in key:
+ log_rec3d_loss_dict({key: rec_loss_dict[key]})
+
+ p_loss = p_eps_objective + p_vae_recon_loss
+ else:
+ p_loss = p_eps_objective
+
+ log_rec3d_loss_dict(
+ dict(p_loss=p_loss, p_eps_objective=p_eps_objective))
+
+ # ! to arrange: update vae params
+
+ self.mp_trainer.backward(p_loss)
+
+ # update ddpm accordingly
+ _ = self.mp_trainer.optimize(
+ self.opt) # TODO, update two groups of parameters
+
+ # TODO, merge visualization with original AE
+ # ! todo, merge required
+ # =================================== denoised AE log part ===================================
+ if dist_util.get_rank(
+ ) == 0 and self.step % 500 == 0 and behaviour != 'diff':
+
+ with th.no_grad():
+
+ vae_out.pop('posterior') # for calculating kl loss
+ vae_out_for_pred = {
+ k:
+ v[0:1].to(dist_util.dev())
+ if isinstance(v, th.Tensor) else v
+ for k, v in vae_out.items()
+ }
+
+ pred = self.ddp_rec_model(
+ latent=vae_out_for_pred,
+ c=micro['c'][0:1],
+ behaviour=self.render_latent_behaviour)
+ assert isinstance(pred, dict)
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+
+ if 'image_depth' in pred:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ if 'image_sr' in pred:
+ if pred['image_sr'].shape[-1] == 512:
+ pred_img = th.cat(
+ [self.pool_512(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ elif pred['image_sr'].shape[-1] == 256:
+ pred_img = th.cat(
+ [self.pool_256(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_256(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_256(pred_depth)
+ gt_depth = self.pool_256(gt_depth)
+
+ else:
+ pred_img = th.cat(
+ [self.pool_128(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_128(micro['img']), micro['img_sr']],
+ dim=-1)
+ gt_depth = self.pool_128(gt_depth)
+ pred_depth = self.pool_128(pred_depth)
+ else:
+ gt_img = self.pool_64(gt_img)
+ gt_depth = self.pool_64(gt_depth)
+
+ gt_vis = th.cat(
+ [
+ gt_img, micro['img'], micro['img'],
+ gt_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1)[0:1] # TODO, fail to load depth. range [0, 1]
+
+ # eps_t_p_3D = eps_t_p.reshape(batch_size, eps_t_p.shape[1]//3, 3, -1) # B C 3 L
+
+ noised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=eps_t_p[0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically
+ behaviour=self.render_latent_behaviour)
+
+ pred_x0 = self.sde_diffusion._predict_x0_from_eps(
+ eps_t_p, pred_eps_p,
+ logsnr_p) # for VAE loss, denosied latent
+
+ # pred_xstart_3D
+ denoised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=pred_x0[0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically?
+ behaviour=self.render_latent_behaviour)
+
+ pred_vis = th.cat([
+ pred_img[0:1], noised_ae_pred['image_raw'][0:1],
+ denoised_ae_pred['image_raw'][0:1],
+ pred_depth[0:1].repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}_{behaviour}.jpg'
+ )
+ print(
+ 'log denoised vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}_{behaviour}.jpg'
+ )
+ del vis, pred_vis, pred_x0, pred_eps_p, micro, vae_out
+
+ th.cuda.empty_cache()
+
+
+class TrainLoop3DDiffusionLSGMJointnoD_ponly(TrainLoop3DDiffusionLSGMJointnoD):
+
+ def __init__(self,
+ *,
+ rec_model,
+ denoise_model,
+ sde_diffusion,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ triplane_scaling_divider=1,
+ use_amp=False,
+ diffusion_input_size=224,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ denoise_model=denoise_model,
+ sde_diffusion=sde_diffusion,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ triplane_scaling_divider=triplane_scaling_divider,
+ use_amp=use_amp,
+ diffusion_input_size=diffusion_input_size,
+ **kwargs)
+
+ def run_loop(self):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ self._post_run_loop()
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # batch = next(self.data)
+ # self.run_step(batch, step='ce_ddpm_step')
+
+ batch = next(self.data)
+ self.run_step(batch, step='p_rendering_step')
+ # self.run_step(batch, step='eps_step')
diff --git a/nsr/lsgm/train_util_diffusion_vpsde.py b/nsr/lsgm/train_util_diffusion_vpsde.py
new file mode 100644
index 0000000000000000000000000000000000000000..bea22ff78941b8ce09eae3eb0a71b642ebd036de
--- /dev/null
+++ b/nsr/lsgm/train_util_diffusion_vpsde.py
@@ -0,0 +1,583 @@
+"""
+Modified from:
+https://github.com/NVlabs/LSGM/blob/main/training_obj_joint.py
+"""
+import copy
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+from typing import Any
+
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard.writer import SummaryWriter
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+# from .train_util import TrainLoop3DRec
+from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_loss_dict,
+ log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+from guided_diffusion.gaussian_diffusion import ModelMeanType
+
+import dnnlib
+from dnnlib.util import calculate_adaptive_weight
+
+from ..train_util_diffusion import TrainLoop3DDiffusion
+from ..cvD.nvsD_canoD import TrainLoop3DcvD_nvsD_canoD
+
+
+class TrainLoop3DDiffusion_vpsde(TrainLoop3DDiffusion,TrainLoop3DcvD_nvsD_canoD):
+ def __init__(self, *, rec_model, denoise_model, diffusion, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, schedule_sampler=None, weight_decay=0, lr_anneal_steps=0, iterations=10001, ignore_resume_opt=False, freeze_ae=False, denoised_ae=True, triplane_scaling_divider=10, use_amp=False, diffusion_input_size=224, **kwargs):
+ super().__init__(rec_model=rec_model, denoise_model=denoise_model, diffusion=diffusion, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, schedule_sampler=schedule_sampler, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, ignore_resume_opt=ignore_resume_opt, freeze_ae=freeze_ae, denoised_ae=denoised_ae, triplane_scaling_divider=triplane_scaling_divider, use_amp=use_amp, diffusion_input_size=diffusion_input_size, **kwargs)
+
+ def run_step(self, batch, step='g_step'):
+
+ if step == 'diffusion_step_rec':
+ self.forward_diffusion(batch, behaviour='diffusion_step_rec')
+ _ = self.mp_trainer_rec.optimize(self.opt_rec) # TODO, update two groups of parameters
+ took_step_ddpm = self.mp_trainer.optimize(self.opt) # TODO, update two groups of parameters
+
+ if took_step_ddpm:
+ self._update_ema() # g_ema # TODO, ema only needs to track ddpm, remove ema tracking in rec
+
+ elif step == 'd_step_rec':
+ self.forward_D(batch, behaviour='rec')
+ # _ = self.mp_trainer_cvD.optimize(self.opt_cvD)
+ _ = self.mp_trainer_canonical_cvD.optimize(self.opt_cano_cvD)
+
+ elif step == 'diffusion_step_nvs':
+ self.forward_diffusion(batch, behaviour='diffusion_step_nvs')
+ _ = self.mp_trainer_rec.optimize(self.opt_rec) # TODO, update two groups of parameters
+ took_step_ddpm = self.mp_trainer.optimize(self.opt) # TODO, update two groups of parameters
+
+ if took_step_ddpm:
+ self._update_ema() # g_ema
+
+ elif step == 'd_step_nvs':
+ self.forward_D(batch, behaviour='nvs')
+ _ = self.mp_trainer_cvD.optimize(self.opt_cvD)
+
+ self._anneal_lr()
+ self.log_step()
+
+ def run_loop(self):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # batch, cond = next(self.data)
+ # if batch is None:
+ # batch = next(self.data)
+ # self.run_step(batch, 'g_step_rec')
+
+ batch = next(self.data)
+ self.run_step(batch, step='diffusion_step_rec')
+
+ batch = next(self.data)
+ self.run_step(batch, 'd_step_rec')
+
+ # batch = next(self.data)
+ # self.run_step(batch, 'g_step_nvs')
+
+ batch = next(self.data)
+ self.run_step(batch, step='diffusion_step_nvs')
+
+ batch = next(self.data)
+ self.run_step(batch, 'd_step_nvs')
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ # if self.step % self.eval_interval == 0 and self.step != 0:
+ if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ self.eval_loop()
+ # self.eval_novelview_loop()
+ # let all processes sync up before starting with a new epoch of training
+ th.cuda.empty_cache()
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0:
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ self.save(self.mp_trainer_rec, self.mp_trainer_rec.model_name)
+ self.save(self.mp_trainer_cvD, 'cvD')
+ self.save(self.mp_trainer_canonical_cvD, 'cano_cvD')
+
+ dist_util.synchronize()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+
+ self.save(self.mp_trainer, self.mp_trainer.model_name)
+ self.save(self.mp_trainer_rec, self.mp_trainer_rec.model_name)
+ self.save(self.mp_trainer_cvD, 'cvD')
+ self.save(self.mp_trainer_canonical_cvD, 'cano_cvD')
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+ self.save(self.mp_trainer_canonical_cvD, 'cvD')
+
+ def forward_diffusion(self, batch, behaviour='rec', *args, **kwargs):
+ """
+ add sds grad to all ae predicted x_0
+ """
+
+ self.ddp_cano_cvD.requires_grad_(False)
+ self.ddp_nvs_cvD.requires_grad_(False)
+
+ self.ddp_model.requires_grad_(True)
+ self.ddp_rec_model.requires_grad_(True)
+
+ # if behaviour != 'diff' and 'rec' in behaviour:
+ # if behaviour != 'diff' and 'rec' in behaviour: # pure diffusion step
+ # self.ddp_rec_model.requires_grad_(True)
+ for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters( # type: ignore
+ ): # type: ignore
+ param.requires_grad_(False) # ! disable triplane_decoder grad in each iteration indepenently;
+ # else:
+
+ self.mp_trainer_rec.zero_grad()
+ self.mp_trainer.zero_grad()
+
+ # ! no 'sds' step now, both add sds grad back to ViT
+
+ # assert behaviour != 'sds'
+ # if behaviour == 'sds':
+ # else:
+ # self.ddp_ddpm_model.requires_grad_(True)
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev())
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ vae_nelbo_loss = th.tensor(0.0).to(dist_util.dev())
+ vision_aided_loss = th.tensor(0.0).to(dist_util.dev())
+ denoise_loss = th.tensor(0.0).to(dist_util.dev())
+ d_weight = th.tensor(0.0).to(dist_util.dev())
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp
+ and not self.freeze_ae):
+
+ # apply vae
+ vae_out = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='enc_dec_wo_triplane') # pred: (B, 3, 64, 64)
+
+
+ if behaviour == 'diffusion_step_rec':
+ target = micro
+ pred = self.ddp_rec_model(latent=vae_out,
+ c=micro['c'],
+ behaviour='triplane_dec')
+
+ # vae reconstruction loss
+ if last_batch or not self.use_ddp:
+ vae_nelbo_loss, loss_dict = self.loss_class(pred,
+ target,
+ test_mode=False)
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ vae_nelbo_loss, loss_dict = self.loss_class(
+ pred, target, test_mode=False)
+
+ last_layer = self.ddp_rec_model.module.decoder.triplane_decoder.decoder.net[ # type: ignore
+ -1].weight # type: ignore
+
+ if 'image_sr' in pred:
+ vision_aided_loss = self.ddp_cano_cvD(
+ 0.5 * pred['image_sr'] +
+ 0.5 * th.nn.functional.interpolate(
+ pred['image_raw'],
+ size=pred['image_sr'].shape[2:],
+ mode='bilinear'),
+ for_G=True).mean() # [B, 1] shape
+ else:
+ vision_aided_loss = self.ddp_cano_cvD(
+ pred['image_raw'], for_G=True
+ ).mean(
+ ) # [B, 1] shape
+
+ d_weight = calculate_adaptive_weight(
+ vae_nelbo_loss,
+ vision_aided_loss,
+ last_layer,
+ # disc_weight_max=1) * 1
+ disc_weight_max=1) * self.loss_class.opt.rec_cvD_lambda
+ # d_weight = self.loss_class.opt.rec_cvD_lambda # since decoder is fixed here. set to 0.001
+
+ vision_aided_loss *= d_weight
+
+ # d_weight = self.loss_class.opt.rec_cvD_lambda
+ loss_dict.update({
+ 'vision_aided_loss/G_rec':
+ vision_aided_loss,
+ 'd_weight_G_rec':
+ d_weight,
+ })
+
+ log_rec3d_loss_dict(loss_dict)
+
+ elif behaviour == 'diffusion_step_nvs':
+
+ novel_view_c = th.cat([micro['c'][1:], micro['c'][:1]])
+
+ pred = self.ddp_rec_model(latent=vae_out,
+ c=novel_view_c,
+ behaviour='triplane_dec')
+
+ if 'image_sr' in pred:
+ vision_aided_loss = self.ddp_nvs_cvD(
+ # pred_for_rec['image_sr'],
+ 0.5 * pred['image_sr'] +
+ 0.5 * th.nn.functional.interpolate(
+ pred['image_raw'],
+ size=pred['image_sr'].shape[2:],
+ mode='bilinear'),
+ for_G=True).mean() # [B, 1] shape
+ else:
+ vision_aided_loss = self.ddp_nvs_cvD(
+ pred['image_raw'], for_G=True
+ ).mean(
+ ) # [B, 1] shape
+
+ d_weight = self.loss_class.opt.nvs_cvD_lambda
+ vision_aided_loss *= d_weight
+
+ log_rec3d_loss_dict({
+ 'vision_aided_loss/G_nvs':
+ vision_aided_loss,
+ })
+
+ # ae_loss = th.tensor(0.0).to(dist_util.dev())
+
+ # elif behaviour == 'diff':
+ # self.ddp_rec_model.requires_grad_(False)
+ # # assert self.ddp_rec_model.module.requires_grad == False, 'freeze ddpm_rec for pure diff step'
+ else:
+ raise NotImplementedError(behaviour)
+ # assert behaviour == 'sds'
+
+ # pred = None
+
+ # if behaviour != 'sds': # also train diffusion
+ # assert pred is not None
+
+ # TODO, train diff and sds together, available?
+ eps = vae_out[self.latent_name]
+
+ # if behaviour != 'sds':
+ # micro_to_denoise.detach_()
+ eps.requires_grad_(True) # single stage diffusion
+
+ t, weights = self.schedule_sampler.sample(
+ eps.shape[0], dist_util.dev())
+ noise = th.randn(size=vae_out.size(), device='cuda') # note that this noise value is currently shared!
+
+ model_kwargs = {}
+
+ # ?
+ # or directly use SSD NeRF version?
+ # get diffusion quantities for p (sgm prior) sampling scheme and reweighting for q (vae)
+
+ # ! handle the sampling
+
+ # get diffusion quantities for p (sgm prior) sampling scheme and reweighting for q (vae)
+ t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
+ diffusion.iw_quantities(args.batch_size, args.time_eps, args.iw_sample_p, args.iw_subvp_like_vp_sde)
+ eps_t_p = diffusion.sample_q(vae_out, noise, var_t_p, m_t_p)
+
+ # in case we want to train q (vae) with another batch using a different sampling scheme for times t
+ if args.iw_sample_q in ['ll_uniform', 'll_iw']:
+ t_q, var_t_q, m_t_q, obj_weight_t_q, _, g2_t_q = \
+ diffusion.iw_quantities(args.batch_size, args.time_eps, args.iw_sample_q, args.iw_subvp_like_vp_sde)
+ eps_t_q = diffusion.sample_q(vae_out, noise, var_t_q, m_t_q)
+
+ eps_t_p = eps_t_p.detach().requires_grad_(True)
+ eps_t = th.cat([eps_t_p, eps_t_q], dim=0)
+ var_t = th.cat([var_t_p, var_t_q], dim=0)
+ t = th.cat([t_p, t_q], dim=0)
+ noise = th.cat([noise, noise], dim=0)
+ else:
+ eps_t, m_t, var_t, t, g2_t = eps_t_p, m_t_p, var_t_p, t_p, g2_t_p
+
+ # run the diffusion
+
+ # mixing normal trick
+ # TODO, create a new partial training_losses function
+ mixing_component = diffusion.mixing_component(eps_t, var_t, t, enabled=dae.mixed_prediction) # TODO, which should I use?
+ params = utils.get_mixed_prediction(dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component)
+
+ # nelbo loss with kl balancing
+
+
+
+
+ # ! remainign parts of cross entropy in likelihook training
+
+ cross_entropy_per_var += diffusion.cross_entropy_const(args.time_eps)
+ cross_entropy = th.sum(cross_entropy_per_var, dim=[1, 2, 3])
+ cross_entropy += remaining_neg_log_p_total # for remaining scales if there is any
+ all_neg_log_p = vae.decompose_eps(cross_entropy_per_var)
+ all_neg_log_p.extend(remaining_neg_log_p_per_ver) # add the remaining neg_log_p
+ kl_all_list, kl_vals_per_group, kl_diag_list = utils.kl_per_group_vada(all_log_q, all_neg_log_p)
+
+
+ kl_coeff = 1.0
+
+ # ! calculate p/q loss;
+ # ? no spectral regularizer here
+ # ? try adding grid_clip and sn later on.
+ q_loss = th.mean(nelbo_loss)
+ p_loss = th.mean(p_objective)
+
+ # backpropagate q_loss for vae and update vae params, if trained
+ if args.train_vae:
+ grad_scalar.scale(q_loss).backward(retain_graph=utils.different_p_q_objectives(args.iw_sample_p, args.iw_sample_q))
+ utils.average_gradients(vae.parameters(), args.distributed)
+ if args.grad_clip_max_norm > 0.: # apply gradient clipping
+ grad_scalar.unscale_(vae_optimizer)
+ th.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=args.grad_clip_max_norm)
+ grad_scalar.step(vae_optimizer)
+
+ # if we use different p and q objectives or are not training the vae, discard gradients and backpropagate p_loss
+ if utils.different_p_q_objectives(args.iw_sample_p, args.iw_sample_q) or not args.train_vae:
+ if args.train_vae:
+ # discard current gradients computed by weighted loss for VAE
+ dae_optimizer.zero_grad()
+
+ # compute gradients with unweighted loss
+ grad_scalar.scale(p_loss).backward()
+
+ # update dae parameters
+ utils.average_gradients(dae.parameters(), args.distributed)
+ if args.grad_clip_max_norm > 0.: # apply gradient clipping
+ grad_scalar.unscale_(dae_optimizer)
+ th.nn.utils.clip_grad_norm_(dae.parameters(), max_norm=args.grad_clip_max_norm)
+ grad_scalar.step(dae_optimizer)
+
+
+ # unpack separate objectives, in case we want to train q (vae) using a different sampling scheme for times t
+ if args.iw_sample_q in ['ll_uniform', 'll_iw']:
+ l2_term_p, l2_term_q = th.chunk(l2_term, chunks=2, dim=0)
+ p_objective = th.sum(obj_weight_t_p * l2_term_p, dim=[1, 2, 3])
+ # cross_entropy_per_var = obj_weight_t_q * l2_term_q
+ else:
+ p_objective = th.sum(obj_weight_t_p * l2_term, dim=[1, 2, 3])
+ # cross_entropy_per_var = obj_weight_t_q * l2_term
+
+ # print(micro_to_denoise.min(), micro_to_denoise.max())
+ compute_losses = functools.partial(
+ self.diffusion.training_losses,
+ self.ddp_model,
+ eps, # x_start
+ t,
+ model_kwargs=model_kwargs,
+ return_detail=True)
+
+ # ! DDPM step
+ if last_batch or not self.use_ddp:
+ losses = compute_losses()
+ # denoised_out = denoised_fn()
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ losses = compute_losses()
+
+ if isinstance(self.schedule_sampler, LossAwareSampler):
+ self.schedule_sampler.update_with_local_losses(
+ t, losses["loss"].detach())
+
+ denoise_loss = (losses["loss"] * weights).mean()
+
+ x_t = losses.pop('x_t')
+ model_output = losses.pop('model_output')
+ diffusion_target = losses.pop('diffusion_target')
+ alpha_bar = losses.pop('alpha_bar')
+
+ log_loss_dict(self.diffusion, t,
+ {k: v * weights
+ for k, v in losses.items()})
+
+ # if behaviour == 'sds':
+ # ! calculate sds grad, and add to the grad of
+
+ # if 'rec' in behaviour and self.loss_class.opt.sds_lamdba > 0: # only enable sds along with rec step
+ # w = (
+ # 1 - alpha_bar**2
+ # ) / self.triplane_scaling_divider * self.loss_class.opt.sds_lamdba # https://github.com/ashawkey/stable-dreamfusion/issues/106
+ # sds_grad = denoise_loss.clone().detach(
+ # ) * w # * https://pytorch.org/docs/stable/generated/th.Tensor.detach.html. detach() returned Tensor share the same storage with previous one. add clone() here.
+
+ # # ae_loss = AddGradient.apply(latent[self.latent_name], sds_grad) # add sds_grad during backward
+
+ # def sds_hook(grad_to_add):
+
+ # def modify_grad(grad):
+ # return grad + grad_to_add # add the sds grad to the original grad for BP
+
+ # return modify_grad
+
+ # eps[self.latent_name].register_hook(
+ # sds_hook(sds_grad)) # merge sds grad with rec/nvs ae step
+
+ loss = vae_nelbo_loss + denoise_loss + vision_aided_loss # caluclate loss within AMP
+
+ # ! cvD loss
+
+ # exit AMP before backward
+ self.mp_trainer_rec.backward(loss)
+ self.mp_trainer.backward(loss)
+
+ # TODO, merge visualization with original AE
+ # =================================== denoised AE log part ===================================
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0 and behaviour != 'diff':
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ # st()
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ # if 'image_sr' in pred: # TODO
+ # pred_img = th.cat(
+ # [self.pool_512(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat(
+ # [self.pool_512(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # pred_depth = self.pool_512(pred_depth)
+ # gt_depth = self.pool_512(gt_depth)
+
+ gt_vis = th.cat(
+ [
+ gt_img, micro['img'], micro['img'],
+ gt_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1)[0:1] # TODO, fail to load depth. range [0, 1]
+
+ noised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=x_t[0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically
+ behaviour=self.render_latent_behaviour)
+
+ # if denoised_out is None:
+ # if not self.denoised_ae:
+ # denoised_out = denoised_fn()
+
+ if self.diffusion.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = model_output
+ else: # * used here
+ pred_xstart = self.diffusion._predict_xstart_from_eps(
+ x_t=x_t, t=t, eps=model_output)
+
+ denoised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent=pred_xstart[0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically?
+ behaviour=self.render_latent_behaviour)
+
+ # denoised_out = denoised_ae_pred
+
+ # if not self.denoised_ae:
+ # denoised_ae_pred = self.ddp_rec_model(
+ # img=None,
+ # c=micro['c'][0:1],
+ # latent=denoised_out['pred_xstart'][0:1] * self.
+ # triplane_scaling_divider, # TODO, how to define the scale automatically
+ # behaviour=self.render_latent_behaviour)
+ # else:
+ # assert denoised_ae_pred is not None
+ # denoised_ae_pred['image_raw'] = denoised_ae_pred[
+ # 'image_raw'][0:1]
+
+ # print(pred_img.shape)
+ # print('denoised_ae:', self.denoised_ae)
+
+ pred_vis = th.cat([
+ pred_img[0:1], noised_ae_pred['image_raw'][0:1],
+ denoised_ae_pred['image_raw'][0:1],
+ pred_depth[0:1].repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+ # s
+
+ vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis = th.cat([
+ # self.pool_128(micro['img']), x_t[:, :3, ...],
+ # denoised_out['pred_xstart'][:, :3, ...]
+ # ],
+ # dim=-1)[0].permute(
+ # 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
+ )
+ print(
+ 'log denoised vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
+ )
+
+ th.cuda.empty_cache()
diff --git a/nsr/networks_stylegan2.py b/nsr/networks_stylegan2.py
new file mode 100644
index 0000000000000000000000000000000000000000..27b24236f8abe4ec3b93f96ef388957aeac28f1a
--- /dev/null
+++ b/nsr/networks_stylegan2.py
@@ -0,0 +1,1093 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Network architectures from the paper
+"Analyzing and Improving the Image Quality of StyleGAN".
+Matches the original implementation of configs E-F by Karras et al. at
+https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py"""
+
+import numpy as np
+import torch
+from utils.torch_utils import misc
+from utils.torch_utils import persistence
+from utils.torch_utils.ops import conv2d_resample
+from utils.torch_utils.ops import upfirdn2d
+from utils.torch_utils.ops import bias_act
+from utils.torch_utils.ops import fma
+from pdb import set_trace as st
+
+from pdb import set_trace as st
+
+#----------------------------------------------------------------------------
+
+
+@misc.profiled_function
+def normalize_2nd_moment(x, dim=1, eps=1e-8):
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
+
+
+#----------------------------------------------------------------------------
+
+
+@misc.profiled_function
+# @torch.autocast(device_type='cuda')
+def modulated_conv2d(
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
+ noise=None, # Optional noise tensor to add to the output activations.
+ up=1, # Integer upsampling factor.
+ down=1, # Integer downsampling factor.
+ padding=0, # Padding with respect to the upsampled image.
+ resample_filter=None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
+ demodulate=True, # Apply weight demodulation?
+ flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
+ fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation?
+):
+ batch_size = x.shape[0]
+ out_channels, in_channels, kh, kw = weight.shape
+ misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs to avoid FP16 overflow.
+ if x.dtype == torch.float16 and demodulate:
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(
+ float('inf'), dim=[1, 2, 3], keepdim=True)) # max_Ikk
+ styles = styles / styles.norm(float('inf'), dim=1,
+ keepdim=True) # max_I
+
+ # Calculate per-sample weights and demodulation coefficients.
+ w = None
+ dcoefs = None
+ if demodulate or fused_modconv:
+ w = weight.unsqueeze(0) # [NOIkk]
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
+ if demodulate and fused_modconv:
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
+
+ # Execute by scaling the activations before and after the convolution.
+ if not fused_modconv:
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ x = conv2d_resample.conv2d_resample(x=x,
+ w=weight.to(x.dtype),
+ f=resample_filter,
+ up=up,
+ down=down,
+ padding=padding,
+ flip_weight=flip_weight)
+ if demodulate and noise is not None:
+ x = fma.fma(x,
+ dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1),
+ noise.to(x.dtype))
+ elif demodulate:
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ elif noise is not None:
+ x = x.add_(noise.to(x.dtype))
+ return x
+
+ # Execute as one fused op using grouped convolution.
+ with misc.suppress_tracer_warnings(
+ ): # this value will be treated as a constant
+ batch_size = int(batch_size)
+ misc.assert_shape(x, [batch_size, in_channels, None, None])
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_resample.conv2d_resample(x=x,
+ w=w.to(x.dtype),
+ f=resample_filter,
+ up=up,
+ down=down,
+ padding=padding,
+ groups=batch_size,
+ flip_weight=flip_weight)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ if noise is not None:
+ x = x.add_(noise)
+ return x
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(
+ self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias=True, # Apply additive bias before the activation function?
+ activation='linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier=1, # Learning rate multiplier.
+ bias_init=0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ self.weight = torch.nn.Parameter(
+ torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(
+ torch.full([out_features],
+ np.float32(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class Conv2dLayer(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ kernel_size, # Width and height of the convolution kernel.
+ bias=True, # Apply additive bias before the activation function?
+ activation='linear', # Activation function: 'relu', 'lrelu', etc.
+ up=1, # Integer upsampling factor.
+ down=1, # Integer downsampling factor.
+ resample_filter=[
+ 1, 3, 3, 1
+ ], # Low-pass filter to apply when resampling activations.
+ conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
+ channels_last=False, # Expect the input to have memory_format=channels_last?
+ trainable=True, # Update the weights of this layer during training?
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.activation = activation
+ self.up = up
+ self.down = down
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ weight = torch.randn(
+ [out_channels, in_channels, kernel_size,
+ kernel_size]).to(memory_format=memory_format)
+ bias = torch.zeros([out_channels]) if bias else None
+ if trainable:
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
+ else:
+ self.register_buffer('weight', weight)
+ if bias is not None:
+ self.register_buffer('bias', bias)
+ else:
+ self.bias = None
+
+ # @torch.autocast(device_type='cuda')
+ def forward(self, x, gain=1):
+ w = self.weight * self.weight_gain # w dtype is fp32
+ b = self.bias.to(x.dtype) if self.bias is not None else None
+
+ flip_weight = (self.up == 1) # slightly faster
+ x = conv2d_resample.conv2d_resample(x=x,
+ w=w.to(x.dtype),
+ f=self.resample_filter,
+ up=self.up,
+ down=self.down,
+ padding=self.padding,
+ flip_weight=flip_weight)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x,
+ b,
+ act=self.activation,
+ gain=act_gain,
+ clamp=act_clamp)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},',
+ f'up={self.up}, down={self.down}'
+ ])
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class MappingNetwork(torch.nn.Module):
+ def __init__(
+ self,
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
+ num_layers=8, # Number of mapping layers.
+ embed_features=None, # Label embedding dimensionality, None = same as w_dim.
+ layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim.
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta=0.998, # Decay for tracking the moving average of W during training, None = do not track.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ if embed_features is None:
+ embed_features = w_dim
+ if c_dim == 0:
+ embed_features = 0
+ if layer_features is None:
+ layer_features = w_dim
+ features_list = [z_dim + embed_features
+ ] + [layer_features] * (num_layers - 1) + [w_dim]
+
+ if c_dim > 0:
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
+ for idx in range(num_layers):
+ in_features = features_list[idx]
+ out_features = features_list[idx + 1]
+ layer = FullyConnectedLayer(in_features,
+ out_features,
+ activation=activation,
+ lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+
+ if num_ws is not None and w_avg_beta is not None:
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self,
+ z,
+ c,
+ truncation_psi=1,
+ truncation_cutoff=None,
+ update_emas=False):
+ # Embed, normalize, and concat inputs.
+ x = None
+ with torch.autograd.profiler.record_function('input'):
+ if self.z_dim > 0:
+ misc.assert_shape(z, [None, self.z_dim])
+ x = normalize_2nd_moment(z.to(torch.float32))
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Main layers.
+ for idx in range(self.num_layers):
+ layer = getattr(self, f'fc{idx}')
+ x = layer(x)
+
+ # Update moving average of W.
+ if update_emas and self.w_avg_beta is not None:
+ with torch.autograd.profiler.record_function('update_w_avg'):
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(
+ self.w_avg, self.w_avg_beta))
+
+ # Broadcast.
+ if self.num_ws is not None:
+ with torch.autograd.profiler.record_function('broadcast'):
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+
+ # Apply truncation.
+ if truncation_psi != 1:
+ with torch.autograd.profiler.record_function('truncate'):
+ assert self.w_avg_beta is not None
+ if self.num_ws is None or truncation_cutoff is None:
+ x = self.w_avg.lerp(x, truncation_psi)
+ else:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(
+ x[:, :truncation_cutoff], truncation_psi)
+ return x
+
+ def extra_repr(self):
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class SynthesisLayer(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this layer.
+ kernel_size=3, # Convolution kernel size.
+ up=1, # Integer upsampling factor.
+ use_noise=True, # Enable noise input?
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter=[
+ 1, 3, 3, 1
+ ], # Low-pass filter to apply when resampling activations.
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ channels_last=False, # Use channels_last format for the weights?
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.up = up
+ self.use_noise = use_noise
+ self.activation = activation
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(
+ torch.randn([out_channels, in_channels, kernel_size,
+ kernel_size]).to(memory_format=memory_format))
+ if use_noise:
+ self.register_buffer('noise_const',
+ torch.randn([resolution, resolution]))
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+
+ # def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
+ def forward(self, x, w, noise_mode='const', fused_modconv=True, gain=1):
+ assert noise_mode in ['random', 'const', 'none']
+ in_resolution = self.resolution // self.up
+ misc.assert_shape(
+ x, [None, self.in_channels, in_resolution, in_resolution])
+ styles = self.affine(w)
+
+ noise = None
+ if self.use_noise and noise_mode == 'random':
+ noise = torch.randn(
+ [x.shape[0], 1, self.resolution, self.resolution],
+ device=x.device) * self.noise_strength
+ if self.use_noise and noise_mode == 'const':
+ noise = self.noise_const * self.noise_strength
+
+ flip_weight = (self.up == 1) # slightly faster
+ x = modulated_conv2d(x=x,
+ weight=self.weight,
+ styles=styles,
+ noise=noise,
+ up=self.up,
+ padding=self.padding,
+ resample_filter=self.resample_filter,
+ flip_weight=flip_weight,
+ fused_modconv=fused_modconv)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x,
+ self.bias.to(x.dtype),
+ act=self.activation,
+ gain=act_gain,
+ clamp=act_clamp)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
+ f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'
+ ])
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class ToRGBLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ w_dim,
+ kernel_size=1,
+ conv_clamp=None,
+ channels_last=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.w_dim = w_dim
+ self.conv_clamp = conv_clamp
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(
+ torch.randn([out_channels, in_channels, kernel_size,
+ kernel_size]).to(memory_format=memory_format))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
+
+ def forward(self, x, w, fused_modconv=True):
+ styles = self.affine(w) * self.weight_gain
+ x = modulated_conv2d(x=x,
+ weight=self.weight,
+ styles=styles,
+ demodulate=False,
+ fused_modconv=fused_modconv)
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
+ return x
+
+ def extra_repr(self):
+ return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class SynthesisBlock(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter=[
+ 1, 3, 3, 1
+ ], # Low-pass filter to apply when resampling activations.
+ conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16=False, # Use FP16 for this block?
+ fp16_channels_last=False, # Use channels-last memory format with FP16?
+ fused_modconv_default=True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.fused_modconv_default = fused_modconv_default
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(
+ torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels,
+ out_channels,
+ w_dim=w_dim,
+ resolution=resolution,
+ up=2,
+ resample_filter=resample_filter,
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last,
+ **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels,
+ out_channels,
+ w_dim=w_dim,
+ resolution=resolution,
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last,
+ **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels,
+ img_channels,
+ w_dim=w_dim,
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False,
+ up=2,
+ resample_filter=resample_filter,
+ channels_last=self.channels_last)
+
+ def forward(self,
+ x,
+ img,
+ ws,
+ force_fp32=False,
+ fused_modconv=None,
+ update_emas=False,
+ **layer_kwargs):
+ _ = update_emas # unused
+ misc.assert_shape(ws,
+ [None, self.num_conv + self.num_torgb, self.w_dim])
+ w_iter = iter(ws.unbind(dim=1))
+ if ws.device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ fused_modconv = self.fused_modconv_default
+ if fused_modconv == 'inference_only':
+ fused_modconv = (not self.training)
+
+ # Input.
+ if self.in_channels == 0:
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+ else:
+ misc.assert_shape(x, [
+ None, self.in_channels, self.resolution // 2,
+ self.resolution // 2
+ ])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ **layer_kwargs)
+ x = self.conv1(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ gain=np.sqrt(0.5),
+ **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ **layer_kwargs)
+ x = self.conv1(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ **layer_kwargs)
+
+ # ToRGB.
+ if img is not None:
+ misc.assert_shape(img, [
+ None, self.img_channels, self.resolution // 2,
+ self.resolution // 2
+ ])
+ img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32,
+ memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ # assert x.dtype == dtype
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(
+ self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base=32768, # Overall multiplier for the number of channels.
+ channel_max=512, # Maximum number of channels in any layer.
+ num_fp16_res=4, # Use FP16 for the N highest resolutions.
+ **block_kwargs, # Arguments for SynthesisBlock.
+ ):
+ assert img_resolution >= 4 and img_resolution & (img_resolution -
+ 1) == 0
+ super().__init__()
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.num_fp16_res = num_fp16_res
+ self.block_resolutions = [
+ 2**i for i in range(2, self.img_resolution_log2 + 1)
+ ]
+ channels_dict = {
+ res: min(channel_base // res, channel_max)
+ for res in self.block_resolutions
+ }
+ fp16_resolution = max(2**(self.img_resolution_log2 + 1 - num_fp16_res),
+ 8)
+
+ self.num_ws = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res // 2] if res > 4 else 0
+ out_channels = channels_dict[res]
+ use_fp16 = (res >= fp16_resolution)
+ is_last = (res == self.img_resolution)
+ block = SynthesisBlock(in_channels,
+ out_channels,
+ w_dim=w_dim,
+ resolution=res,
+ img_channels=img_channels,
+ is_last=is_last,
+ use_fp16=use_fp16,
+ **block_kwargs)
+ self.num_ws += block.num_conv
+ if is_last:
+ self.num_ws += block.num_torgb
+ setattr(self, f'b{res}', block)
+
+ def forward(self, ws, **block_kwargs):
+ block_ws = []
+ with torch.autograd.profiler.record_function('split_ws'):
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
+ ws = ws.to(torch.float32)
+ w_idx = 0
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ block_ws.append(
+ ws.narrow(1, w_idx, block.num_conv +
+ block.num_torgb)) # dim start length
+ w_idx += block.num_conv
+ # print(f'synthesisNetwork : b{res}, device={block.conv1.weight.device}')
+
+ x = img = None
+ for res, cur_ws in zip(self.block_resolutions, block_ws):
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, cur_ws, **block_kwargs)
+ return img
+
+ def extra_repr(self):
+ return ' '.join([
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'num_fp16_res={self.num_fp16_res:d}'
+ ])
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class Generator(torch.nn.Module):
+ def __init__(
+ self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ mapping_kwargs={}, # Arguments for MappingNetwork.
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.synthesis = SynthesisNetwork(w_dim=w_dim,
+ img_resolution=img_resolution,
+ img_channels=img_channels,
+ **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim,
+ c_dim=c_dim,
+ w_dim=w_dim,
+ num_ws=self.num_ws,
+ **mapping_kwargs)
+
+ def forward(self,
+ z,
+ c,
+ truncation_psi=1,
+ truncation_cutoff=None,
+ update_emas=False,
+ **synthesis_kwargs):
+ ws = self.mapping(z,
+ c,
+ truncation_psi=truncation_psi,
+ truncation_cutoff=truncation_cutoff,
+ update_emas=update_emas)
+ img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ return img
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class DiscriminatorBlock(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels, # Number of input channels, 0 = first block.
+ tmp_channels, # Number of intermediate channels.
+ out_channels, # Number of output channels.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ first_layer_idx, # Index of the first layer.
+ architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter=[
+ 1, 3, 3, 1
+ ], # Low-pass filter to apply when resampling activations.
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16=False, # Use FP16 for this block?
+ fp16_channels_last=False, # Use channels-last memory format with FP16?
+ freeze_layers=0, # Freeze-D: Number of layers to freeze.
+ ):
+ assert in_channels in [0, tmp_channels]
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.first_layer_idx = first_layer_idx
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter(resample_filter))
+
+ self.num_layers = 0
+
+ def trainable_gen():
+ while True:
+ layer_idx = self.first_layer_idx + self.num_layers
+ trainable = (layer_idx >= freeze_layers)
+ self.num_layers += 1
+ yield trainable
+
+ trainable_iter = trainable_gen()
+
+ if in_channels == 0 or architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels,
+ tmp_channels,
+ kernel_size=1,
+ activation=activation,
+ trainable=next(trainable_iter),
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last)
+
+ self.conv0 = Conv2dLayer(tmp_channels,
+ tmp_channels,
+ kernel_size=3,
+ activation=activation,
+ trainable=next(trainable_iter),
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last)
+
+ self.conv1 = Conv2dLayer(tmp_channels,
+ out_channels,
+ kernel_size=3,
+ activation=activation,
+ down=2,
+ trainable=next(trainable_iter),
+ resample_filter=resample_filter,
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last)
+
+ if architecture == 'resnet':
+ self.skip = Conv2dLayer(tmp_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False,
+ down=2,
+ trainable=next(trainable_iter),
+ resample_filter=resample_filter,
+ channels_last=self.channels_last)
+
+ def forward(self, x, img, force_fp32=False):
+ if (x if x is not None else img).device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ # dtype = img.dtype
+ # dtype = x.dtype
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+
+ # Input.
+ if x is not None:
+ misc.assert_shape(
+ x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # FromRGB.
+ if self.in_channels == 0 or self.architecture == 'skip':
+ misc.assert_shape(
+ img,
+ [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ y = self.fromrgb(img)
+ x = x + y if x is not None else y
+ img = upfirdn2d.downsample2d(
+ img,
+ self.resample_filter) if self.architecture == 'skip' else None
+
+ # Main layers.
+ if self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x)
+ x = self.conv1(x, gain=np.sqrt(0.5))
+ x = y.add_(x)
+ else:
+ x = self.conv0(x)
+ x = self.conv1(x)
+
+ assert x.dtype == dtype
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class MinibatchStdLayer(torch.nn.Module):
+ def __init__(self, group_size, num_channels=1):
+ super().__init__()
+ self.group_size = group_size
+ self.num_channels = num_channels
+
+ def forward(self, x):
+ N, C, H, W = x.shape
+ with misc.suppress_tracer_warnings(
+ ): # as_tensor results are registered as constants
+ G = torch.min(
+ torch.as_tensor(self.group_size),
+ torch.as_tensor(N)) if self.group_size is not None else N
+ F = self.num_channels
+ c = C // F
+
+ y = x.reshape(
+ G, -1, F, c, H, W
+ ) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
+ y = y.mean(dim=[2, 3,
+ 4]) # [nF] Take average over channels and pixels.
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
+ x = torch.cat([x, y],
+ dim=1) # [NCHW] Append to input as new channels.
+ return x
+
+ def extra_repr(self):
+ return f'group_size={self.group_size}, num_channels={self.num_channels:d}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class DiscriminatorEpilogue(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels, # Number of input channels.
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.cmap_dim = cmap_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.architecture = architecture
+
+ if architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels,
+ in_channels,
+ kernel_size=1,
+ activation=activation)
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
+ num_channels=mbstd_num_channels
+ ) if mbstd_num_channels > 0 else None
+ self.conv = Conv2dLayer(in_channels + mbstd_num_channels,
+ in_channels,
+ kernel_size=3,
+ activation=activation,
+ conv_clamp=conv_clamp)
+ self.fc = FullyConnectedLayer(in_channels * (resolution**2),
+ in_channels,
+ activation=activation)
+ self.out = FullyConnectedLayer(in_channels,
+ 1 if cmap_dim == 0 else cmap_dim)
+
+ def forward(self, x, img, cmap, force_fp32=False):
+ misc.assert_shape(
+ x, [None, self.in_channels, self.resolution, self.resolution
+ ]) # [NCHW]
+ _ = force_fp32 # unused
+ # dtype = torch.float32
+ dtype = x.dtype
+ memory_format = torch.contiguous_format
+
+ # FromRGB.
+ x = x.to(dtype=dtype, memory_format=memory_format)
+ if self.architecture == 'skip':
+ misc.assert_shape(
+ img,
+ [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ x = x + self.fromrgb(img)
+
+ # Main layers.
+ if self.mbstd is not None:
+ x = self.mbstd(x)
+ x = self.conv(x)
+ x = self.fc(x.flatten(1))
+ x = self.out(x)
+
+ # Conditioning.
+ if self.cmap_dim > 0:
+ misc.assert_shape(cmap, [None, self.cmap_dim])
+ x = (x * cmap).sum(dim=1,
+ keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+ assert x.dtype == dtype
+ return x
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class Discriminator(torch.nn.Module):
+ def __init__(
+ self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base=32768, # Overall multiplier for the number of channels.
+ channel_max=512, # Maximum number of channels in any layer.
+ num_fp16_res=4, # Use FP16 for the N highest resolutions.
+ conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim=None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs={}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs={}, # Arguments for MappingNetwork.
+ epilogue_kwargs={}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [
+ 2**i for i in range(self.img_resolution_log2, 2, -1)
+ ]
+ channels_dict = {
+ res: min(channel_base // res, channel_max)
+ for res in self.block_resolutions + [4]
+ }
+ fp16_resolution = max(2**(self.img_resolution_log2 + 1 - num_fp16_res),
+ 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels,
+ architecture=architecture,
+ conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels,
+ tmp_channels,
+ out_channels,
+ resolution=res,
+ first_layer_idx=cur_layer_idx,
+ use_fp16=use_fp16,
+ **block_kwargs,
+ **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0,
+ c_dim=c_dim,
+ w_dim=cmap_dim,
+ num_ws=None,
+ w_avg_beta=None,
+ **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4],
+ cmap_dim=cmap_dim,
+ resolution=4,
+ **epilogue_kwargs,
+ **common_kwargs)
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+
+#----------------------------------------------------------------------------
diff --git a/nsr/networks_stylegan3.py b/nsr/networks_stylegan3.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba4f88ac6f17e07dfb4a8e2e08091367b66e0d4
--- /dev/null
+++ b/nsr/networks_stylegan3.py
@@ -0,0 +1,679 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Generator architecture from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import numpy as np
+import scipy.signal
+import scipy.optimize
+import torch
+from utils.torch_utils import misc
+from utils.torch_utils import persistence
+from utils.torch_utils.ops import conv2d_gradfix
+from utils.torch_utils.ops import filtered_lrelu
+from utils.torch_utils.ops import bias_act
+
+#----------------------------------------------------------------------------
+# from pdb import set_trace as st
+
+
+@misc.profiled_function
+def modulated_conv2d(
+ x, # Input tensor: [batch_size, in_channels, in_height, in_width]
+ w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
+ s, # Style tensor: [batch_size, in_channels]
+ demodulate=True, # Apply weight demodulation?
+ padding=0, # Padding: int or [padH, padW]
+ input_gain=None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
+):
+ with misc.suppress_tracer_warnings(
+ ): # this value will be treated as a constant
+ batch_size = int(x.shape[0])
+ out_channels, in_channels, kh, kw = w.shape
+ misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(s, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs.
+ if demodulate:
+ w = w * w.square().mean([1, 2, 3], keepdim=True).rsqrt()
+ s = s * s.square().mean().rsqrt()
+
+ # Modulate weights.
+ w = w.unsqueeze(0) # [NOIkk]
+ w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Demodulate weights.
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO]
+ w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Apply input scaling.
+ if input_gain is not None:
+ input_gain = input_gain.expand(batch_size, in_channels) # [NI]
+ w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Execute as one fused op using grouped convolution.
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_gradfix.conv2d(input=x,
+ weight=w.to(x.dtype),
+ padding=padding,
+ groups=batch_size)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ return x
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(
+ self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ activation='linear', # Activation function: 'relu', 'lrelu', etc.
+ bias=True, # Apply additive bias before the activation function?
+ lr_multiplier=1, # Learning rate multiplier.
+ weight_init=1, # Initial standard deviation of the weight tensor.
+ bias_init=0, # Initial value of the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ self.weight = torch.nn.Parameter(
+ torch.randn([out_features, in_features]) *
+ (weight_init / lr_multiplier))
+ bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32),
+ [out_features])
+ self.bias = torch.nn.Parameter(
+ torch.from_numpy(bias_init / lr_multiplier)) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class MappingNetwork(torch.nn.Module):
+ def __init__(
+ self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output.
+ num_layers=2, # Number of mapping layers.
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta=0.998, # Decay for tracking the moving average of W during training.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ # Construct layers.
+ self.embed = FullyConnectedLayer(
+ self.c_dim, self.w_dim) if self.c_dim > 0 else None
+ features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)
+ ] + [self.w_dim] * self.num_layers
+ for idx, in_features, out_features in zip(range(num_layers),
+ features[:-1], features[1:]):
+ layer = FullyConnectedLayer(in_features,
+ out_features,
+ activation='lrelu',
+ lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self,
+ z,
+ c,
+ truncation_psi=1,
+ truncation_cutoff=None,
+ update_emas=False):
+ misc.assert_shape(z, [None, self.z_dim])
+ if truncation_cutoff is None:
+ truncation_cutoff = self.num_ws
+
+ # Embed, normalize, and concatenate inputs.
+ x = z.to(torch.float32)
+ x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = self.embed(c.to(torch.float32))
+ y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Execute layers.
+ for idx in range(self.num_layers):
+ x = getattr(self, f'fc{idx}')(x)
+
+ # Update moving average of W.
+ if update_emas:
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(
+ self.w_avg, self.w_avg_beta))
+
+ # Broadcast and apply truncation.
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+ if truncation_psi != 1:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(
+ x[:, :truncation_cutoff], truncation_psi)
+ return x
+
+ def extra_repr(self):
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class SynthesisInput(torch.nn.Module):
+ def __init__(
+ self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ channels, # Number of output channels.
+ size, # Output spatial size: int or [width, height].
+ sampling_rate, # Output sampling rate.
+ bandwidth, # Output bandwidth.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.channels = channels
+ self.size = np.broadcast_to(np.asarray(size), [2])
+ self.sampling_rate = sampling_rate
+ self.bandwidth = bandwidth
+
+ # Draw random frequencies from uniform 2D disc.
+ freqs = torch.randn([self.channels, 2])
+ radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
+ freqs /= radii * radii.square().exp().pow(0.25)
+ freqs *= bandwidth
+ phases = torch.rand([self.channels]) - 0.5
+
+ # Setup parameters and buffers.
+ self.weight = torch.nn.Parameter(
+ torch.randn([self.channels, self.channels]))
+ self.affine = FullyConnectedLayer(w_dim,
+ 4,
+ weight_init=0,
+ bias_init=[1, 0, 0, 0])
+ self.register_buffer('transform', torch.eye(
+ 3, 3)) # User-specified inverse transform wrt. resulting image.
+ self.register_buffer('freqs', freqs)
+ self.register_buffer('phases', phases)
+
+ def forward(self, w):
+ # Introduce batch dimension.
+ transforms = self.transform.unsqueeze(0) # [batch, row, col]
+ freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
+ phases = self.phases.unsqueeze(0) # [batch, channel]
+
+ # Apply learned transformation.
+ t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
+ t = t / t[:, :2].norm(dim=1,
+ keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
+ m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat(
+ [w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
+ m_r[:, 0, 0] = t[:, 0] # r'_c
+ m_r[:, 0, 1] = -t[:, 1] # r'_s
+ m_r[:, 1, 0] = t[:, 1] # r'_s
+ m_r[:, 1, 1] = t[:, 0] # r'_c
+ m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat(
+ [w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
+ m_t[:, 0, 2] = -t[:, 2] # t'_x
+ m_t[:, 1, 2] = -t[:, 3] # t'_y
+ transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.
+
+ # Transform frequencies.
+ phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
+ freqs = freqs @ transforms[:, :2, :2]
+
+ # Dampen out-of-band frequencies that may occur due to the user-specified transform.
+ amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) /
+ (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)
+
+ # Construct sampling grid.
+ theta = torch.eye(2, 3, device=w.device)
+ theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
+ theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
+ grids = torch.nn.functional.affine_grid(
+ theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]],
+ align_corners=False)
+
+ # Compute Fourier features.
+ x = (grids.unsqueeze(3) @ freqs.permute(
+ 0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(
+ 3) # [batch, height, width, channel]
+ x = x + phases.unsqueeze(1).unsqueeze(2)
+ x = torch.sin(x * (np.pi * 2))
+ x = x * amplitudes.unsqueeze(1).unsqueeze(2)
+
+ # Apply trainable mapping.
+ weight = self.weight / np.sqrt(self.channels)
+ x = x @ weight.t()
+
+ # Ensure correct shape.
+ x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
+ misc.assert_shape(
+ x,
+ [w.shape[0], self.channels,
+ int(self.size[1]),
+ int(self.size[0])])
+ return x
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
+ f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'
+ ])
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class SynthesisLayer(torch.nn.Module):
+ def __init__(
+ self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ is_torgb, # Is this the final ToRGB layer?
+ is_critically_sampled, # Does this layer use critical sampling?
+ use_fp16, # Does this layer use FP16?
+
+ # Input & output specifications.
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ in_size, # Input spatial size: int or [width, height].
+ out_size, # Output spatial size: int or [width, height].
+ in_sampling_rate, # Input sampling rate (s).
+ out_sampling_rate, # Output sampling rate (s).
+ in_cutoff, # Input cutoff frequency (f_c).
+ out_cutoff, # Output cutoff frequency (f_c).
+ in_half_width, # Input transition band half-width (f_h).
+ out_half_width, # Output Transition band half-width (f_h).
+
+ # Hyperparameters.
+ conv_kernel=3, # Convolution kernel size. Ignored for final the ToRGB layer.
+ filter_size=6, # Low-pass filter size relative to the lower resolution when up/downsampling.
+ lrelu_upsampling=2, # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer.
+ use_radial_filters=False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
+ conv_clamp=256, # Clamp the output to [-X, +X], None = disable clamping.
+ magnitude_ema_beta=0.999, # Decay rate for the moving average of input magnitudes.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.is_torgb = is_torgb
+ self.is_critically_sampled = is_critically_sampled
+ self.use_fp16 = use_fp16
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.in_size = np.broadcast_to(np.asarray(in_size), [2])
+ self.out_size = np.broadcast_to(np.asarray(out_size), [2])
+ self.in_sampling_rate = in_sampling_rate
+ self.out_sampling_rate = out_sampling_rate
+ self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (
+ 1 if is_torgb else lrelu_upsampling)
+ self.in_cutoff = in_cutoff
+ self.out_cutoff = out_cutoff
+ self.in_half_width = in_half_width
+ self.out_half_width = out_half_width
+ self.conv_kernel = 1 if is_torgb else conv_kernel
+ self.conv_clamp = conv_clamp
+ self.magnitude_ema_beta = magnitude_ema_beta
+
+ # Setup parameters and buffers.
+ self.affine = FullyConnectedLayer(self.w_dim,
+ self.in_channels,
+ bias_init=1)
+ self.weight = torch.nn.Parameter(
+ torch.randn([
+ self.out_channels, self.in_channels, self.conv_kernel,
+ self.conv_kernel
+ ]))
+ self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
+ self.register_buffer('magnitude_ema', torch.ones([]))
+
+ # Design upsampling filter.
+ self.up_factor = int(
+ np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
+ assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
+ self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
+ self.register_buffer(
+ 'up_filter',
+ self.design_lowpass_filter(numtaps=self.up_taps,
+ cutoff=self.in_cutoff,
+ width=self.in_half_width * 2,
+ fs=self.tmp_sampling_rate))
+
+ # Design downsampling filter.
+ self.down_factor = int(
+ np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
+ assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
+ self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
+ self.down_radial = use_radial_filters and not self.is_critically_sampled
+ self.register_buffer(
+ 'down_filter',
+ self.design_lowpass_filter(numtaps=self.down_taps,
+ cutoff=self.out_cutoff,
+ width=self.out_half_width * 2,
+ fs=self.tmp_sampling_rate,
+ radial=self.down_radial))
+
+ # Compute padding.
+ pad_total = (
+ self.out_size - 1
+ ) * self.down_factor + 1 # Desired output size before downsampling.
+ pad_total -= (self.in_size + self.conv_kernel -
+ 1) * self.up_factor # Input size after upsampling.
+ pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
+ pad_lo = (
+ pad_total + self.up_factor
+ ) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
+ pad_hi = pad_total - pad_lo
+ self.padding = [
+ int(pad_lo[0]),
+ int(pad_hi[0]),
+ int(pad_lo[1]),
+ int(pad_hi[1])
+ ]
+
+ def forward(self,
+ x,
+ w,
+ noise_mode='random',
+ force_fp32=False,
+ update_emas=False):
+ assert noise_mode in ['random', 'const', 'none'] # unused
+ misc.assert_shape(x, [
+ None, self.in_channels,
+ int(self.in_size[1]),
+ int(self.in_size[0])
+ ])
+ misc.assert_shape(w, [x.shape[0], self.w_dim])
+
+ # Track input magnitude.
+ if update_emas:
+ with torch.autograd.profiler.record_function(
+ 'update_magnitude_ema'):
+ magnitude_cur = x.detach().to(torch.float32).square().mean()
+ self.magnitude_ema.copy_(
+ magnitude_cur.lerp(self.magnitude_ema,
+ self.magnitude_ema_beta))
+ input_gain = self.magnitude_ema.rsqrt()
+
+ # Execute affine layer.
+ styles = self.affine(w)
+ if self.is_torgb:
+ weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel**2))
+ styles = styles * weight_gain
+
+ # Execute modulated conv2d.
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and
+ x.device.type == 'cuda') else torch.float32
+ x = modulated_conv2d(x=x.to(dtype),
+ w=self.weight,
+ s=styles,
+ padding=self.conv_kernel - 1,
+ demodulate=(not self.is_torgb),
+ input_gain=input_gain)
+
+ # Execute bias, filtered leaky ReLU, and clamping.
+ gain = 1 if self.is_torgb else np.sqrt(2)
+ slope = 1 if self.is_torgb else 0.2
+ x = filtered_lrelu.filtered_lrelu(x=x,
+ fu=self.up_filter,
+ fd=self.down_filter,
+ b=self.bias.to(x.dtype),
+ up=self.up_factor,
+ down=self.down_factor,
+ padding=self.padding,
+ gain=gain,
+ slope=slope,
+ clamp=self.conv_clamp)
+
+ # Ensure correct shape and dtype.
+ misc.assert_shape(x, [
+ None, self.out_channels,
+ int(self.out_size[1]),
+ int(self.out_size[0])
+ ])
+ assert x.dtype == dtype
+ return x
+
+ @staticmethod
+ def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
+ assert numtaps >= 1
+
+ # Identity filter.
+ if numtaps == 1:
+ return None
+
+ # Separable Kaiser low-pass filter.
+ if not radial:
+ f = scipy.signal.firwin(numtaps=numtaps,
+ cutoff=cutoff,
+ width=width,
+ fs=fs)
+ return torch.as_tensor(f, dtype=torch.float32)
+
+ # Radially symmetric jinc-based filter.
+ x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
+ r = np.hypot(*np.meshgrid(x, x))
+ f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
+ beta = scipy.signal.kaiser_beta(
+ scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
+ w = np.kaiser(numtaps, beta)
+ f *= np.outer(w, w)
+ f /= np.sum(f)
+ return torch.as_tensor(f, dtype=torch.float32)
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},',
+ f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},',
+ f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},',
+ f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},',
+ f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},',
+ f'in_size={list(self.in_size)}, out_size={list(self.out_size)},',
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}'
+ ])
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(
+ self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base=32768, # Overall multiplier for the number of channels.
+ channel_max=512, # Maximum number of channels in any layer.
+ num_layers=14, # Total number of layers, excluding Fourier features and ToRGB.
+ num_critical=2, # Number of critically sampled layers at the end.
+ first_cutoff=2, # Cutoff frequency of the first layer (f_{c,0}).
+ first_stopband=2**
+ 2.1, # Minimum stopband of the first layer (f_{t,0}).
+ last_stopband_rel=2**
+ 0.3, # Minimum stopband of the last layer, expressed relative to the cutoff.
+ margin_size=10, # Number of additional pixels outside the image.
+ output_scale=0.25, # Scale factor for the output image.
+ num_fp16_res=4, # Use FP16 for the N highest resolutions.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.num_ws = num_layers + 2
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.num_layers = num_layers
+ self.num_critical = num_critical
+ self.margin_size = margin_size
+ self.output_scale = output_scale
+ self.num_fp16_res = num_fp16_res
+
+ # Geometric progression of layer cutoffs and min. stopbands.
+ last_cutoff = self.img_resolution / 2 # f_{c,N}
+ last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
+ exponents = np.minimum(
+ np.arange(self.num_layers + 1) /
+ (self.num_layers - self.num_critical), 1)
+ cutoffs = first_cutoff * (last_cutoff /
+ first_cutoff)**exponents # f_c[i]
+ stopbands = first_stopband * (last_stopband /
+ first_stopband)**exponents # f_t[i]
+
+ # Compute remaining layer parameters.
+ sampling_rates = np.exp2(
+ np.ceil(np.log2(np.minimum(stopbands * 2,
+ self.img_resolution)))) # s[i]
+ half_widths = np.maximum(stopbands,
+ sampling_rates / 2) - cutoffs # f_h[i]
+ sizes = sampling_rates + self.margin_size * 2
+ sizes[-2:] = self.img_resolution
+ channels = np.rint(
+ np.minimum((channel_base / 2) / cutoffs, channel_max))
+ channels[-1] = self.img_channels
+
+ # Construct layers.
+ self.input = SynthesisInput(w_dim=self.w_dim,
+ channels=int(channels[0]),
+ size=int(sizes[0]),
+ sampling_rate=sampling_rates[0],
+ bandwidth=cutoffs[0])
+ self.layer_names = []
+ for idx in range(self.num_layers + 1):
+ prev = max(idx - 1, 0)
+ is_torgb = (idx == self.num_layers)
+ is_critically_sampled = (idx >=
+ self.num_layers - self.num_critical)
+ use_fp16 = (sampling_rates[idx] *
+ (2**self.num_fp16_res) > self.img_resolution)
+ layer = SynthesisLayer(w_dim=self.w_dim,
+ is_torgb=is_torgb,
+ is_critically_sampled=is_critically_sampled,
+ use_fp16=use_fp16,
+ in_channels=int(channels[prev]),
+ out_channels=int(channels[idx]),
+ in_size=int(sizes[prev]),
+ out_size=int(sizes[idx]),
+ in_sampling_rate=int(sampling_rates[prev]),
+ out_sampling_rate=int(sampling_rates[idx]),
+ in_cutoff=cutoffs[prev],
+ out_cutoff=cutoffs[idx],
+ in_half_width=half_widths[prev],
+ out_half_width=half_widths[idx],
+ **layer_kwargs)
+ name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
+ setattr(self, name, layer)
+ self.layer_names.append(name)
+
+ def forward(self, ws, **layer_kwargs):
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
+ ws = ws.to(torch.float32).unbind(dim=1)
+
+ # Execute layers.
+ x = self.input(ws[0])
+ for name, w in zip(self.layer_names, ws[1:]):
+ x = getattr(self, name)(x, w, **layer_kwargs)
+ if self.output_scale != 1:
+ x = x * self.output_scale
+
+ # Ensure correct shape and dtype.
+ misc.assert_shape(x, [
+ None, self.img_channels, self.img_resolution, self.img_resolution
+ ])
+ x = x.to(torch.float32)
+ return x
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},',
+ f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}'
+ ])
+
+
+#----------------------------------------------------------------------------
+
+
+@persistence.persistent_class
+class Generator(torch.nn.Module):
+ def __init__(
+ self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ mapping_kwargs={}, # Arguments for MappingNetwork.
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.synthesis = SynthesisNetwork(w_dim=w_dim,
+ img_resolution=img_resolution,
+ img_channels=img_channels,
+ **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim,
+ c_dim=c_dim,
+ w_dim=w_dim,
+ num_ws=self.num_ws,
+ **mapping_kwargs)
+
+ def forward(self,
+ z,
+ c,
+ truncation_psi=1,
+ truncation_cutoff=None,
+ update_emas=False,
+ **synthesis_kwargs):
+ ws = self.mapping(z,
+ c,
+ truncation_psi=truncation_psi,
+ truncation_cutoff=truncation_cutoff,
+ update_emas=update_emas)
+ img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ return img
+
+
+#----------------------------------------------------------------------------
diff --git a/nsr/script_util.py b/nsr/script_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7506aa850988bbdc1d40837b50b02801e39cabef
--- /dev/null
+++ b/nsr/script_util.py
@@ -0,0 +1,1489 @@
+import torch
+from torch import nn
+from nsr.triplane import Triplane_fg_bg_plane
+# import timm
+from vit.vit_triplane import Triplane, ViTTriplaneDecomposed
+import argparse
+import inspect
+import dnnlib
+from guided_diffusion import dist_util
+
+from pdb import set_trace as st
+
+import vit.vision_transformer as vits
+from guided_diffusion import logger
+from .confnet import ConfNet
+
+from ldm.modules.diffusionmodules.model import Encoder, MVEncoder, MVEncoderGS
+from ldm.modules.diffusionmodules.mv_unet import MVUNet, LGM_MVEncoder
+
+# from ldm.modules.diffusionmodules.openaimodel import MultiViewUNetModel_Encoder
+
+# * create pre-trained encoder & triplane / other nsr decoder
+
+
+class AE(torch.nn.Module):
+
+ def __init__(self,
+ encoder,
+ decoder,
+ img_size,
+ encoder_cls_token,
+ decoder_cls_token,
+ preprocess,
+ use_clip,
+ dino_version='v1',
+ clip_dtype=None,
+ no_dim_up_mlp=False,
+ dim_up_mlp_as_func=False,
+ uvit_skip_encoder=False,
+ confnet=None) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.img_size = img_size
+ self.encoder_cls_token = encoder_cls_token
+ self.decoder_cls_token = decoder_cls_token
+ self.use_clip = use_clip
+ self.dino_version = dino_version
+ self.confnet = confnet
+
+ if self.dino_version == 'v2':
+ self.encoder.mask_token = None
+ self.decoder.vit_decoder.mask_token = None
+
+ if 'sd' not in self.dino_version:
+
+ self.uvit_skip_encoder = uvit_skip_encoder
+ if uvit_skip_encoder:
+ logger.log(
+ f'enables uvit: length of vit_encoder.blocks: {len(self.encoder.blocks)}'
+ )
+ for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
+ blk.skip_linear = nn.Linear(2 * self.encoder.embed_dim,
+ self.encoder.embed_dim)
+
+ # trunc_normal_(blk.skip_linear.weight, std=.02)
+ nn.init.constant_(blk.skip_linear.weight, 0)
+ if isinstance(
+ blk.skip_linear,
+ nn.Linear) and blk.skip_linear.bias is not None:
+ nn.init.constant_(blk.skip_linear.bias, 0)
+ else:
+ logger.log(f'disable uvit')
+ else:
+ if 'dit' not in self.dino_version: # dino vit, not dit
+ self.decoder.vit_decoder.cls_token = None
+ self.decoder.vit_decoder.patch_embed.proj = nn.Identity()
+ self.decoder.triplane_decoder.planes = None
+ self.decoder.vit_decoder.mask_token = None
+
+ if self.use_clip:
+ self.clip_dtype = clip_dtype # torch.float16
+
+ else:
+
+ if not no_dim_up_mlp and self.encoder.embed_dim != self.decoder.vit_decoder.embed_dim:
+ self.dim_up_mlp = nn.Linear(
+ self.encoder.embed_dim,
+ self.decoder.vit_decoder.embed_dim)
+ logger.log(
+ f"dim_up_mlp: {self.encoder.embed_dim} -> {self.decoder.vit_decoder.embed_dim}, as_func: {self.dim_up_mlp_as_func}"
+ )
+ else:
+ logger.log('ignore dim_up_mlp: ', no_dim_up_mlp)
+
+ self.preprocess = preprocess
+
+ self.dim_up_mlp = None # CLIP/B-16
+ self.dim_up_mlp_as_func = dim_up_mlp_as_func
+
+ # * remove certain components to make sure no unused parameters during DDP
+ # self.decoder.vit_decoder.cls_token = nn.Identity()
+ torch.cuda.empty_cache()
+ # self.decoder.vit_decoder.patch_embed.proj.bias = nn.Identity()
+ # self.decoder.vit_decoder.patch_embed.proj.weight = nn.Identity()
+ # self.decoder.vit_decoder.patch_embed.proj.bias = nn.Identity()
+
+ def encode(self, *args, **kwargs):
+ if not self.use_clip:
+ if self.dino_version == 'v1':
+ latent = self.encode_dinov1(*args, **kwargs)
+ elif self.dino_version == 'v2':
+ if self.uvit_skip_encoder:
+ latent = self.encode_dinov2_uvit(*args, **kwargs)
+ else:
+ latent = self.encode_dinov2(*args, **kwargs)
+ else:
+ latent = self.encoder(*args)
+
+ else:
+ latent = self.encode_clip(*args, **kwargs)
+
+ return latent
+
+ def encode_dinov1(self, x):
+ # return self.encoder(img)
+ x = self.encoder.prepare_tokens(x)
+ for blk in self.encoder.blocks:
+ x = blk(x)
+ x = self.encoder.norm(x)
+ if not self.encoder_cls_token:
+ return x[:, 1:]
+
+ return x
+
+ def encode_dinov2(self, x):
+ # return self.encoder(img)
+ x = self.encoder.prepare_tokens_with_masks(x, masks=None)
+ for blk in self.encoder.blocks:
+ x = blk(x)
+ x_norm = self.encoder.norm(x)
+
+ if not self.encoder_cls_token:
+ return x_norm[:, 1:]
+ # else:
+ # return x_norm[:, :1]
+
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_patchtokens": x_norm[:, 1:],
+ # }
+
+ return x_norm
+
+ def encode_dinov2_uvit(self, x):
+ # return self.encoder(img)
+ x = self.encoder.prepare_tokens_with_masks(x, masks=None)
+
+ # for blk in self.encoder.blocks:
+ # x = blk(x)
+
+ skips = [x]
+
+ # in blks
+ for blk in self.encoder.blocks[0:len(self.encoder.blocks) // 2 - 1]:
+ x = blk(x) # B 3 N C
+ skips.append(x)
+
+ # mid blks
+ for blk in self.encoder.blocks[len(self.encoder.blocks) // 2 -
+ 1:len(self.encoder.blocks) // 2]:
+ x = blk(x) # B 3 N C
+
+ # out blks
+ for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
+ x = x + blk.skip_linear(torch.cat(
+ [x, skips.pop()], dim=-1)) # long skip connections in uvit
+ x = blk(x) # B 3 N C
+
+ x_norm = self.encoder.norm(x)
+
+ if not self.decoder_cls_token:
+ return x_norm[:, 1:]
+
+ return x_norm
+
+ def encode_clip(self, x):
+ # * replace with CLIP encoding pipeline
+ # return self.encoder(img)
+ # x = x.dtype(self.clip_dtype)
+ x = self.encoder.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1],
+ -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([
+ self.encoder.class_embedding.to(x.dtype) + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
+ ],
+ dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.encoder.positional_embedding.to(x.dtype)
+ x = self.encoder.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.encoder.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.encoder.ln_post(x[:, 1:, :]) # * return the spatial tokens
+
+ return x
+
+ # x = self.ln_post(x[:, 0, :]) # * return the spatial tokens
+
+ # if self.proj is not None:
+ # x = x @ self.proj
+
+ # return x
+
+ def decode_wo_triplane(self, latent, c=None, img_size=None):
+ if img_size is None:
+ img_size = self.img_size
+
+ if self.dim_up_mlp is not None:
+ if not self.dim_up_mlp_as_func:
+ latent = self.dim_up_mlp(latent)
+ # return self.decoder.vit_decode(latent, img_size)
+ else:
+ return self.decoder.vit_decode(
+ latent, img_size,
+ dim_up_mlp=self.dim_up_mlp) # used in vae-ldm
+
+ return self.decoder.vit_decode(latent, img_size, c=c)
+
+ def decode(self, latent, c, img_size=None, return_raw_only=False):
+ # if img_size is None:
+ # img_size = self.img_size
+
+ # if self.dim_up_mlp is not None:
+ # latent = self.dim_up_mlp(latent)
+
+ latent = self.decode_wo_triplane(latent, img_size=img_size, c=c)
+ # return self.decoder.triplane_decode(latent, c, return_raw_only=return_raw_only)
+ return self.decoder.triplane_decode(latent, c)
+
+ def decode_after_vae_no_render(
+ self,
+ ret_dict,
+ img_size=None,
+ ):
+
+ if img_size is None:
+ img_size = self.img_size
+
+ assert self.dim_up_mlp is None
+ # if not self.dim_up_mlp_as_func:
+ # latent = self.dim_up_mlp(latent)
+ # return self.decoder.vit_decode(latent, img_size)
+
+ latent = self.decoder.vit_decode_backbone(ret_dict, img_size)
+ ret_dict = self.decoder.vit_decode_postprocess(latent, ret_dict)
+ return ret_dict
+
+ def decode_after_vae(
+ self,
+ # latent,
+ ret_dict, # vae_dict
+ c,
+ img_size=None,
+ return_raw_only=False):
+ ret_dict = self.decode_after_vae_no_render(ret_dict, img_size)
+ return self.decoder.triplane_decode(ret_dict, c)
+
+ def decode_confmap(self, img):
+ assert self.confnet is not None
+ # https://github.com/elliottwu/unsup3d/blob/dc961410d61684561f19525c2f7e9ee6f4dacb91/unsup3d/model.py#L152
+ # conf_sigma_l1 = self.confnet(img) # Bx2xHxW
+ return self.confnet(img) # Bx1xHxW
+
+ def encode_decode(self, img, c, return_raw_only=False):
+ latent = self.encode(img)
+ pred = self.decode(latent, c, return_raw_only=return_raw_only)
+ if self.confnet is not None:
+ pred.update({
+ 'conf_sigma': self.decode_confmap(img) # 224x224
+ })
+
+ return pred
+
+ def forward(self,
+ img=None,
+ c=None,
+ latent=None,
+ behaviour='enc_dec',
+ coordinates=None,
+ directions=None,
+ return_raw_only=False,
+ *args,
+ **kwargs):
+ """wrap all operations inside forward() for DDP use.
+ """
+
+ if behaviour == 'enc_dec':
+ pred = self.encode_decode(img, c, return_raw_only=return_raw_only)
+ return pred
+
+ elif behaviour == 'enc':
+ latent = self.encode(img)
+ return latent
+
+ elif behaviour == 'dec':
+ assert latent is not None
+ pred: dict = self.decode(latent,
+ c,
+ self.img_size,
+ return_raw_only=return_raw_only)
+ return pred
+
+ elif behaviour == 'dec_wo_triplane':
+ assert latent is not None
+ pred: dict = self.decode_wo_triplane(latent, self.img_size)
+ return pred
+
+ elif behaviour == 'enc_dec_wo_triplane':
+ latent = self.encode(img)
+ pred: dict = self.decode_wo_triplane(latent, img_size=self.img_size, c=c)
+ return pred
+
+ elif behaviour == 'encoder_vae':
+ latent = self.encode(img)
+ ret_dict = self.decoder.vae_reparameterization(latent, True)
+ return ret_dict
+
+ elif behaviour == 'decode_after_vae_no_render':
+ pred: dict = self.decode_after_vae_no_render(latent, self.img_size)
+ return pred
+
+ elif behaviour == 'decode_after_vae':
+ pred: dict = self.decode_after_vae(latent, c, self.img_size)
+ return pred
+
+ # elif behaviour == 'gaussian_dec':
+ # assert latent is not None
+ # pred: dict = self.decoder.triplane_decode(
+ # latent, c, return_raw_only=return_raw_only, **kwargs)
+ # # pred: dict = self.decoder.triplane_decode(latent, c)
+
+ elif behaviour == 'triplane_dec':
+ assert latent is not None
+ pred: dict = self.decoder.triplane_decode(
+ latent, c, return_raw_only=return_raw_only, **kwargs)
+ # pred: dict = self.decoder.triplane_decode(latent, c)
+
+ elif behaviour == 'triplane_decode_grid':
+ assert latent is not None
+ pred: dict = self.decoder.triplane_decode_grid(
+ latent, **kwargs)
+ # pred: dict = self.decoder.triplane_decode(latent, c)
+
+ elif behaviour == 'vit_postprocess_triplane_dec':
+ assert latent is not None
+ latent = self.decoder.vit_decode_postprocess(
+ latent) # translate spatial token from vit-decoder into 2D
+ pred: dict = self.decoder.triplane_decode(
+ latent, c) # render with triplane
+
+ elif behaviour == 'triplane_renderer':
+ assert latent is not None
+ pred: dict = self.decoder.triplane_renderer(
+ latent, coordinates, directions)
+
+ # elif behaviour == 'triplane_SR':
+ # assert latent is not None
+ # pred: dict = self.decoder.triplane_renderer(
+ # latent, coordinates, directions)
+
+ elif behaviour == 'get_rendering_kwargs':
+ pred = self.decoder.triplane_decoder.rendering_kwargs
+
+ return pred
+
+
+class AE_CLIPEncoder(AE):
+
+ def __init__(self, encoder, decoder, img_size, cls_token) -> None:
+ super().__init__(encoder, decoder, img_size, cls_token)
+
+
+class AE_with_Diffusion(torch.nn.Module):
+
+ def __init__(self, auto_encoder, denoise_model) -> None:
+ super().__init__()
+ self.auto_encoder = auto_encoder
+ self.denoise_model = denoise_model # simply for easy MPTrainer manipulation
+
+ def forward(self,
+ img,
+ c,
+ behaviour='enc_dec',
+ latent=None,
+ *args,
+ **kwargs):
+ # wrap auto_encoder and denoising model inside a single forward function to use DDP (only forward supported) and MPTrainer (single model) easier
+ if behaviour == 'enc_dec':
+ pred = self.auto_encoder(img, c)
+ return pred
+ elif behaviour == 'enc':
+ latent = self.auto_encoder.encode(img)
+ if self.auto_encoder.dim_up_mlp is not None:
+ latent = self.auto_encoder.dim_up_mlp(latent)
+ return latent
+ elif behaviour == 'dec':
+ assert latent is not None
+ pred: dict = self.auto_encoder.decode(latent, c, self.img_size)
+ return pred
+ elif behaviour == 'denoise':
+ assert latent is not None
+ pred: dict = self.denoise_model(*args, **kwargs)
+ return pred
+
+
+def eg3d_options_default():
+
+ opts = dnnlib.EasyDict(
+ dict(
+ cbase=32768,
+ cmax=512,
+ map_depth=2,
+ g_class_name='nsr.triplane.TriPlaneGenerator', # TODO
+ g_num_fp16_res=0,
+ ))
+
+ return opts
+
+
+def rendering_options_defaults(opts):
+
+ rendering_options = {
+ # 'image_resolution': c.training_set_kwargs.resolution,
+ 'image_resolution': 256,
+ 'disparity_space_sampling': False,
+ 'clamp_mode': 'softplus',
+ 'c_gen_conditioning_zero':
+ True, # if true, fill generator pose conditioning label with dummy zero vector
+ # 'gpc_reg_prob': opts.gpc_reg_prob if opts.gen_pose_cond else None,
+ 'c_scale':
+ opts.c_scale, # mutliplier for generator pose conditioning label
+ 'superresolution_noise_mode': 'none',
+ 'density_reg': opts.density_reg, # strength of density regularization
+ 'density_reg_p_dist': opts.
+ density_reg_p_dist, # distance at which to sample perturbed points for density regularization
+ 'reg_type': opts.
+ reg_type, # for experimenting with variations on density regularization
+ 'decoder_lr_mul': 1,
+ # opts.decoder_lr_mul, # learning rate multiplier for decoder
+ 'decoder_activation': 'sigmoid',
+ 'sr_antialias': True,
+ 'return_triplane_features': False, # for DDF supervision
+ 'return_sampling_details_flag': False,
+
+ # * shape default sr
+
+ # 'superresolution_module': 'nsr.superresolution.SuperresolutionHybrid4X',
+ # 'superresolution_module':
+ # 'utils.torch_utils.components.PixelUnshuffleUpsample',
+ 'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
+ }
+
+ if opts.cfg == 'ffhq':
+ rendering_options.update({
+ 'superresolution_module':
+ 'nsr.superresolution.SuperresolutionHybrid8XDC',
+ 'focal': 2985.29 / 700,
+ 'depth_resolution':
+ 48 - 0, # number of uniform samples to take per ray.
+ 'depth_resolution_importance':
+ 48 - 0, # number of importance samples to take per ray.
+ 'bg_depth_resolution':
+ 16, # 4/14 in stylenerf, https://github.com/facebookresearch/StyleNeRF/blob/7f5610a058f27fcc360c6b972181983d7df794cb/conf/model/stylenerf_ffhq.yaml#L48
+ 'ray_start':
+ 2.25, # near point along each ray to start taking samples.
+ 'ray_end':
+ 3.3, # far point along each ray to stop taking samples.
+ 'box_warp':
+ 1, # the side-length of the bounding box spanned by the tri-planes; box_warp=1 means [-0.5, -0.5, -0.5] -> [0.5, 0.5, 0.5].
+ 'avg_camera_radius':
+ 2.7, # used only in the visualizer to specify camera orbit radius.
+ 'avg_camera_pivot': [
+ 0, 0, 0.2
+ ], # used only in the visualizer to control center of camera rotation.
+ 'superresolution_noise_mode': 'random',
+ })
+ elif opts.cfg == 'afhq':
+ rendering_options.update({
+ 'superresolution_module':
+ 'nsr.superresolution.SuperresolutionHybrid8X',
+ 'superresolution_noise_mode': 'random',
+ 'focal': 4.2647,
+ 'depth_resolution': 48,
+ 'depth_resolution_importance': 48,
+ 'ray_start': 2.25,
+ 'ray_end': 3.3,
+ 'box_warp': 1,
+ 'avg_camera_radius': 2.7,
+ 'avg_camera_pivot': [0, 0, -0.06],
+ })
+ elif opts.cfg == 'shapenet': # TODO, lies in a sphere
+ rendering_options.update({
+ 'depth_resolution': 64,
+ 'depth_resolution_importance': 64,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start': 0.2,
+ 'ray_end': 2.2,
+ # 'ray_start': opts.ray_start,
+ # 'ray_end': opts.ray_end,
+ 'box_warp': 2, # TODO, how to set this value?
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'eg3d_shapenet_aug_resolution':
+ rendering_options.update({
+ 'depth_resolution': 80,
+ 'depth_resolution_importance': 80,
+ 'ray_start': 0.1,
+ 'ray_end': 1.9, # 2.6/1.7*1.2
+ 'box_warp': 1.1,
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair':
+ rendering_options.update({
+ 'depth_resolution': 96,
+ 'depth_resolution_importance': 96,
+ 'ray_start': 0.1,
+ 'ray_end': 1.9, # 2.6/1.7*1.2
+ 'box_warp': 1.1,
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128':
+ rendering_options.update({
+ 'depth_resolution': 128,
+ 'depth_resolution_importance': 128,
+ 'ray_start': 0.1,
+ 'ray_end': 1.9, # 2.6/1.7*1.2
+ 'box_warp': 1.1,
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_64':
+ rendering_options.update({
+ 'depth_resolution': 64,
+ 'depth_resolution_importance': 64,
+ 'ray_start': 0.1,
+ 'ray_end': 1.9, # 2.6/1.7*1.2
+ 'box_warp': 1.1,
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'srn_shapenet_aug_resolution_chair_128':
+ rendering_options.update({
+ 'depth_resolution': 128,
+ 'depth_resolution_importance': 128,
+ 'ray_start': 1.25,
+ 'ray_end': 2.75,
+ 'box_warp': 1.5,
+ 'white_back': True,
+ 'avg_camera_radius': 2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128_residualSR':
+ rendering_options.update({
+ 'depth_resolution':
+ 128,
+ 'depth_resolution_importance':
+ 128,
+ 'ray_start':
+ 0.1,
+ 'ray_end':
+ 1.9, # 2.6/1.7*1.2
+ 'box_warp':
+ 1.1,
+ 'white_back':
+ True,
+ 'avg_camera_radius':
+ 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ 'superresolution_module':
+ 'utils.torch_utils.components.NearestConvSR_Residual',
+ })
+
+ elif opts.cfg == 'shapenet_tuneray': # TODO, lies in a sphere
+ rendering_options.update({
+ 'depth_resolution': 64,
+ 'depth_resolution_importance': 64,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start': opts.ray_start,
+ 'ray_end': opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution': 80,
+ 'depth_resolution_importance': 80,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start': opts.ray_start,
+ 'ray_end': opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution_64': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution': 128,
+ 'depth_resolution_importance': 128,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start': opts.ray_start,
+ 'ray_end': opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution': 96,
+ 'depth_resolution_importance': 96,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start': opts.ray_start,
+ 'ray_end': opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+ # ! default version
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestSR': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution':
+ 96,
+ 'depth_resolution_importance':
+ 96,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start':
+ opts.ray_start,
+ 'ray_end':
+ opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back':
+ True,
+ 'avg_camera_radius':
+ 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ 'superresolution_module':
+ 'utils.torch_utils.components.NearestConvSR',
+ })
+
+ # ! 64+64, since ssdnerf adopts this setting
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution':
+ 64,
+ 'depth_resolution_importance':
+ 64,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start':
+ opts.ray_start,
+ 'ray_end':
+ opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back':
+ True,
+ 'avg_camera_radius':
+ 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ 'superresolution_module':
+ 'utils.torch_utils.components.NearestConvSR',
+ })
+
+ # ! 64+64+patch, since ssdnerf adopts this setting
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR_patch': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution':
+ 64,
+ 'depth_resolution_importance':
+ 64,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start':
+ opts.ray_start,
+ 'ray_end':
+ opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back':
+ True,
+ 'avg_camera_radius':
+ 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ 'superresolution_module':
+ 'utils.torch_utils.components.NearestConvSR',
+ # patch configs
+ 'PatchRaySampler':
+ True,
+ # 'patch_rendering_resolution': 32,
+ # 'patch_rendering_resolution': 48,
+ 'patch_rendering_resolution':
+ opts.patch_rendering_resolution,
+ })
+
+ elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_nearestSR': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution':
+ 64,
+ 'depth_resolution_importance':
+ 64,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start':
+ opts.ray_start,
+ # 'auto',
+ 'ray_end':
+ opts.ray_end,
+ # 'auto',
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ # 2,
+ 'white_back':
+ True,
+ 'avg_camera_radius':
+ 1.946, # ?
+ 'avg_camera_pivot': [0, 0, 0],
+ 'superresolution_module':
+ 'utils.torch_utils.components.NearestConvSR',
+ # patch configs
+ # 'PatchRaySampler': False,
+ # 'patch_rendering_resolution': 32,
+ # 'patch_rendering_resolution': 48,
+ # 'patch_rendering_resolution': opts.patch_rendering_resolution,
+ })
+
+ elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_auto': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution':
+ 64,
+ 'depth_resolution_importance':
+ 64,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start':
+ 'auto',
+ 'ray_end':
+ 'auto',
+ 'box_warp':
+ 0.9,
+ 'white_back':
+ True,
+ 'radius_range': [1.5,2],
+ # 'z_near': 1.5-0.45, # radius in [1.5, 2], https://github.com/modelscope/richdreamer/issues/12#issuecomment-1897734616
+ # 'z_far': 2.0+0.45,
+ 'sampler_bbox_min':
+ -0.45,
+ 'sampler_bbox_max':
+ 0.45,
+ # 'avg_camera_pivot': [0, 0, 0], # not used
+ 'filter_out_of_bbox':
+ True,
+ # 'superresolution_module':
+ # 'utils.torch_utils.components.NearestConvSR',
+ # patch configs
+ 'PatchRaySampler':
+ True,
+ # 'patch_rendering_resolution': 32,
+ # 'patch_rendering_resolution': 48,
+ 'patch_rendering_resolution':
+ opts.patch_rendering_resolution,
+ })
+ rendering_options['z_near'] = rendering_options['radius_range'][0]+rendering_options['sampler_bbox_min']
+ rendering_options['z_far'] = rendering_options['radius_range'][1]+rendering_options['sampler_bbox_max']
+
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestResidualSR': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution':
+ 96,
+ 'depth_resolution_importance':
+ 96,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start':
+ opts.ray_start,
+ 'ray_end':
+ opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back':
+ True,
+ 'avg_camera_radius':
+ 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ 'superresolution_module':
+ 'utils.torch_utils.components.NearestConvSR_Residual',
+ })
+
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestResidualSR': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution':
+ 64,
+ 'depth_resolution_importance':
+ 64,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start':
+ opts.ray_start,
+ 'ray_end':
+ opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back':
+ True,
+ 'avg_camera_radius':
+ 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ 'superresolution_module':
+ 'utils.torch_utils.components.NearestConvSR_Residual',
+ })
+
+ elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_104': # to differentiate hwc
+ rendering_options.update({
+ 'depth_resolution': 104,
+ 'depth_resolution_importance': 104,
+ # * radius 1.2 setting, newly rendered images
+ 'ray_start': opts.ray_start,
+ 'ray_end': opts.ray_end,
+ 'box_warp':
+ opts.ray_end - opts.ray_start, # TODO, how to set this value?
+ 'white_back': True,
+ 'avg_camera_radius': 1.2,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+
+ rendering_options.update({'return_sampling_details_flag': True})
+ rendering_options.update({'return_sampling_details_flag': True})
+
+ return rendering_options
+
+
+def model_encoder_defaults():
+
+ return dict(
+ use_clip=False,
+ arch_encoder="vits",
+ arch_decoder="vits",
+ load_pretrain_encoder=False,
+ encoder_lr=1e-5,
+ encoder_weight_decay=
+ 0.001, # https://github.com/google-research/vision_transformer
+ no_dim_up_mlp=False,
+ dim_up_mlp_as_func=False,
+ decoder_load_pretrained=True,
+ uvit_skip_encoder=False,
+ # vae ldm
+ vae_p=1,
+ ldm_z_channels=4,
+ ldm_embed_dim=4,
+ use_conf_map=False,
+ # sd E, lite version by default
+ sd_E_ch=64,
+ z_channels=3*4,
+ sd_E_num_res_blocks=1,
+ # vit_decoder
+ arch_dit_decoder='DiT2-B/2',
+ return_all_dit_layers=False,
+ # sd D
+ # sd_D_ch=32,
+ # sd_D_res_blocks=1,
+ # sd_D_res_blocks=1,
+ lrm_decoder=False,
+ gs_rendering=False,
+ )
+
+
+def triplane_decoder_defaults():
+ opts = dict(
+ triplane_fg_bg=False,
+ cfg='shapenet',
+ density_reg=0.25,
+ density_reg_p_dist=0.004,
+ reg_type='l1',
+ triplane_decoder_lr=0.0025, # follow eg3d G lr
+ super_resolution_lr=0.0025,
+ # triplane_decoder_wd=0.1,
+ c_scale=1,
+ nsr_lr=0.02,
+ triplane_size=224,
+ decoder_in_chans=32,
+ triplane_in_chans=-1,
+ decoder_output_dim=3,
+ out_chans=96,
+ c_dim=25, # Conditioning label (C) dimensionality.
+ # ray_start=0.2,
+ # ray_end=2.2,
+ ray_start=0.6, # shapenet default
+ ray_end=1.8,
+ rendering_kwargs={},
+ sr_training=False,
+ bcg_synthesis=False, # from panohead
+ bcg_synthesis_kwargs={}, # G_kwargs.copy()
+ #
+ image_size=128, # raw 3D rendering output resolution.
+ patch_rendering_resolution=45,
+ )
+
+ # else:
+ # assert False, "Need to specify config"
+
+ # opts = dict(opts)
+ # opts.pop('cfg')
+
+ return opts
+
+
+def vit_decoder_defaults():
+ res = dict(
+ vit_decoder_lr=1e-5, # follow eg3d G lr
+ vit_decoder_wd=0.001,
+ )
+ return res
+
+
+def nsr_decoder_defaults():
+ res = {
+ 'decomposed': False,
+ } # TODO, add defaults for all nsr
+ res.update(triplane_decoder_defaults()) # triplane by default now
+ res.update(vit_decoder_defaults()) # type: ignore
+ return res
+
+
+def loss_defaults():
+ opt = dict(
+ color_criterion='mse',
+ l2_lambda=1.0,
+ lpips_lambda=0.,
+ lpips_delay_iter=0,
+ sr_delay_iter=0,
+ # kl_anneal=0,
+ kl_anneal=False,
+ latent_lambda=0.,
+ latent_criterion='mse',
+ kl_lambda=0.0,
+ # kl_anneal=False,
+ ssim_lambda=0.,
+ l1_lambda=0.,
+ id_lambda=0.0,
+ depth_lambda=0.0, # TODO
+ alpha_lambda=0.0, # TODO
+ fg_mse=False,
+ bg_lamdba=0.0,
+ density_reg=0.0, # tvloss in eg3d
+ density_reg_p_dist=0.004, # 'density regularization strength.'
+ density_reg_every=4, # lazy density reg
+
+ # 3D supervision, ffhq/afhq eg3d warm up
+ shape_uniform_lambda=0.005,
+ shape_importance_lambda=0.01,
+ shape_depth_lambda=0.,
+
+ # gan loss
+ rec_cvD_lambda=0.01,
+ nvs_cvD_lambda=0.025,
+ patchgan_disc_factor=0.01,
+ patchgan_disc_g_weight=0.2, #
+ r1_gamma=1.0, # ffhq default value for eg3d
+ sds_lamdba=1.0,
+ nvs_D_lr_mul=1, # compared with 1e-4
+ cano_D_lr_mul=1, # compared with 1e-4
+
+ # lsgm loss
+ ce_balanced_kl=1.,
+ p_eps_lambda=1,
+ # symmetric loss
+ symmetry_loss=False,
+ depth_smoothness_lambda=0.0,
+ ce_lambda=1.0,
+ negative_entropy_lambda=1.0,
+ grad_clip=False,
+ online_mask=False, # in unsup3d
+ )
+ return opt
+
+
+def dataset_defaults():
+ res = dict(
+ use_lmdb=False,
+ use_wds=False,
+ use_lmdb_compressed=True,
+ compile=False,
+ interval=1,
+ objv_dataset=False,
+ decode_encode_img_only=False,
+ load_wds_diff=False,
+ load_wds_latent=False,
+ eval_load_wds_instance=True,
+ shards_lst="",
+ eval_shards_lst="",
+ mv_input=False,
+ duplicate_sample=True,
+ orthog_duplicate=False,
+ split_chunk_input=False, # split=8 per chunk
+ load_real=False,
+ four_view_for_latent=False,
+ single_view_for_i23d=False,
+ shuffle_across_cls=False,
+ load_extra_36_view=False,
+ mv_latent_dir='',
+ append_depth=False,
+ plucker_embedding=False,
+ gs_cam_format=False,
+ )
+ return res
+
+
+def encoder_and_nsr_defaults():
+ """
+ Defaults for image training.
+ """
+ # ViT configs
+ res = dict(
+ dino_version='v1',
+ encoder_in_channels=3,
+ img_size=[224],
+ patch_size=16, # ViT-S/16
+ in_chans=384,
+ num_classes=0,
+ embed_dim=384, # Check ViT encoder dim
+ depth=6,
+ num_heads=16,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.1,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer='nn.LayerNorm',
+ # img_resolution=128, # Output resolution.
+ cls_token=False,
+ # image_size=128, # rendered output resolution.
+ # img_channels=3, # Number of output color channels.
+ encoder_cls_token=False,
+ decoder_cls_token=False,
+ sr_kwargs={},
+ sr_ratio=2,
+ # sd configs
+ )
+ # Triplane configs
+ res.update(model_encoder_defaults())
+ res.update(nsr_decoder_defaults())
+ res.update(
+ ae_classname='vit.vit_triplane.ViTTriplaneDecomposed') # if add SR
+ return res
+
+
+def create_3DAE_model(
+ arch_encoder,
+ arch_decoder,
+ dino_version='v1',
+ img_size=[224],
+ patch_size=16,
+ in_chans=384,
+ num_classes=0,
+ embed_dim=1024, # Check ViT encoder dim
+ depth=6,
+ num_heads=16,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.1,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ # norm_layer=nn.LayerNorm,
+ norm_layer='nn.LayerNorm',
+ out_chans=96,
+ decoder_in_chans=32,
+ triplane_in_chans=-1,
+ decoder_output_dim=32,
+ encoder_cls_token=False,
+ decoder_cls_token=False,
+ c_dim=25, # Conditioning label (C) dimensionality.
+ image_size=128, # Output resolution.
+ img_channels=3, # Number of output color channels.
+ rendering_kwargs={},
+ load_pretrain_encoder=False,
+ decomposed=True,
+ triplane_size=224,
+ ae_classname='ViTTriplaneDecomposed',
+ use_clip=False,
+ sr_kwargs={},
+ sr_ratio=2,
+ no_dim_up_mlp=False,
+ dim_up_mlp_as_func=False,
+ decoder_load_pretrained=True,
+ uvit_skip_encoder=False,
+ bcg_synthesis_kwargs={},
+ # decoder params
+ vae_p=1,
+ ldm_z_channels=4,
+ ldm_embed_dim=4,
+ use_conf_map=False,
+ triplane_fg_bg=False,
+ encoder_in_channels=3,
+ sd_E_ch=64,
+ z_channels=3*4,
+ sd_E_num_res_blocks=1,
+ arch_dit_decoder='DiT2-B/2',
+ lrm_decoder=False,
+ gs_rendering=False,
+ return_all_dit_layers=False,
+ *args,
+ **kwargs):
+
+ # TODO, check pre-trained ViT encoder cfgs
+
+ preprocess = None
+ clip_dtype = None
+ if load_pretrain_encoder:
+ if not use_clip:
+ if dino_version == 'v1':
+ encoder = torch.hub.load(
+ 'facebookresearch/dino:main',
+ 'dino_{}{}'.format(arch_encoder, patch_size))
+ logger.log(
+ f'loaded pre-trained dino v1 ViT-S{patch_size} encoder ckpt'
+ )
+ elif dino_version == 'v2':
+ encoder = torch.hub.load(
+ 'facebookresearch/dinov2',
+ 'dinov2_{}{}'.format(arch_encoder, patch_size))
+ logger.log(
+ f'loaded pre-trained dino v2 {arch_encoder}{patch_size} encoder ckpt'
+ )
+ elif 'sd' in dino_version: # just for compat
+
+ if 'mv' in dino_version:
+ if 'lgm' in dino_version:
+ encoder_cls = MVUNet(
+ input_size=256,
+ up_channels=(1024, 1024, 512, 256,
+ 128), # one more decoder
+ up_attention=(True, True, True, False, False),
+ splat_size=128,
+ output_size=
+ 512, # render & supervise Gaussians at a higher resolution.
+ batch_size=8,
+ num_views=8,
+ gradient_accumulation_steps=1,
+ # mixed_precision='bf16',
+ )
+ elif 'gs' in dino_version:
+ encoder_cls = MVEncoder
+ else:
+ encoder_cls = MVEncoder
+
+ else:
+ encoder_cls = Encoder
+
+ encoder = encoder_cls( # mono input
+ double_z=True,
+ resolution=256,
+ in_channels=encoder_in_channels,
+ # ch=128,
+ ch=64, # ! fit in the memory
+ # ch_mult=[1,2,4,4],
+ # num_res_blocks=2,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=1,
+ dropout=0.0,
+ attn_resolutions=[],
+ out_ch=3, # unused
+ z_channels=4 * 3,
+ ) # stable diffusion encoder
+ else:
+ raise NotImplementedError()
+
+ else:
+ import clip
+ model, preprocess = clip.load("ViT-B/16", device=dist_util.dev())
+ model.float() # convert weight to float32
+ clip_dtype = model.dtype
+ encoder = getattr(
+ model, 'visual') # only use the CLIP visual encoder here
+ encoder.requires_grad_(False)
+ logger.log(
+ f'loaded pre-trained CLIP ViT-B{patch_size} encoder, fixed.')
+
+ elif 'sd' in dino_version:
+ attn_kwargs = {}
+ if 'mv' in dino_version:
+ if 'lgm' in dino_version:
+ encoder = LGM_MVEncoder(
+ in_channels=9,
+ # input_size=256,
+ up_channels=(1024, 1024, 512, 256,
+ 128), # one more decoder
+ up_attention=(True, True, True, False, False),
+ # splat_size=128,
+ # output_size=
+ # 512, # render & supervise Gaussians at a higher resolution.
+ # batch_size=8,
+ # num_views=8,
+ # gradient_accumulation_steps=1,
+ # mixed_precision='bf16',
+ )
+
+ elif 'gs' in dino_version:
+ encoder_cls = MVEncoderGS
+ attn_kwargs = {
+ 'n_heads': 8,
+ 'd_head': 64,
+ }
+
+ else:
+ encoder_cls = MVEncoder
+ attn_kwargs = {
+ 'n_heads': 8,
+ 'd_head': 64,
+ }
+
+ else:
+ encoder_cls = Encoder
+
+ if 'lgm' not in dino_version: # TODO, for compat now
+ # st()
+ encoder = encoder_cls(
+ double_z=True,
+ resolution=256,
+ in_channels=encoder_in_channels,
+ # ch=128,
+ # ch=64, # ! fit in the memory
+ ch=sd_E_ch,
+ # ch_mult=[1,2,4,4],
+ # num_res_blocks=2,
+ ch_mult=[1, 2, 4, 4],
+ # num_res_blocks=1,
+ num_res_blocks=sd_E_num_res_blocks,
+ dropout=0.0,
+ attn_resolutions=[],
+ out_ch=3, # unused
+ z_channels=z_channels, # 4 * 3
+ attn_kwargs=attn_kwargs,
+ ) # stable diffusion encoder
+
+ else:
+ encoder = vits.__dict__[arch_encoder](
+ patch_size=patch_size,
+ drop_path_rate=drop_path_rate, # stochastic depth
+ img_size=img_size)
+
+ # assert decomposed
+ # if decomposed:
+ if triplane_in_chans == -1:
+ triplane_in_chans = decoder_in_chans
+
+ # if triplane_fg_bg:
+ # triplane_renderer_cls = Triplane_fg_bg_plane
+ # else:
+ triplane_renderer_cls = Triplane
+
+ # triplane_decoder = Triplane(
+ triplane_decoder = triplane_renderer_cls(
+ c_dim, # Conditioning label (C) dimensionality.
+ image_size, # Output resolution.
+ img_channels, # Number of output color channels.
+ rendering_kwargs=rendering_kwargs,
+ out_chans=out_chans,
+ # create_triplane=True, # compatability, remove later
+ triplane_size=triplane_size,
+ decoder_in_chans=triplane_in_chans,
+ decoder_output_dim=decoder_output_dim,
+ sr_kwargs=sr_kwargs,
+ bcg_synthesis_kwargs=bcg_synthesis_kwargs,
+ lrm_decoder=lrm_decoder)
+
+ if load_pretrain_encoder:
+
+ if dino_version == 'v1':
+ vit_decoder = torch.hub.load(
+ 'facebookresearch/dino:main',
+ 'dino_{}{}'.format(arch_decoder, patch_size))
+ logger.log(
+ 'loaded pre-trained decoder',
+ "facebookresearch/dino:main', 'dino_{}{}".format(
+ arch_decoder, patch_size))
+ else:
+
+ vit_decoder = torch.hub.load(
+ 'facebookresearch/dinov2',
+ # 'dinov2_{}{}'.format(arch_decoder, patch_size))
+ 'dinov2_{}{}'.format(arch_decoder, patch_size),
+ pretrained=decoder_load_pretrained)
+ logger.log(
+ 'loaded pre-trained decoder',
+ "facebookresearch/dinov2', 'dinov2_{}{}".format(
+ arch_decoder,
+ patch_size), 'pretrianed=', decoder_load_pretrained)
+
+ elif 'dit' in dino_version:
+ from dit.dit_decoder import DiT2_models
+
+ vit_decoder = DiT2_models[arch_dit_decoder](
+ input_size=16,
+ num_classes=0,
+ learn_sigma=False,
+ in_channels=embed_dim,
+ mixed_prediction=False,
+ context_dim=None, # add CLIP text embedding
+ roll_out=True, plane_n=4 if
+ 'gs' in dino_version else 3,
+ return_all_layers=return_all_dit_layers,
+ )
+
+ else: # has bug on global token, to fix
+ vit_decoder = vits.__dict__[arch_decoder](
+ patch_size=patch_size,
+ drop_path_rate=drop_path_rate, # stochastic depth
+ img_size=img_size)
+
+ # decoder = ViTTriplaneDecomposed(vit_decoder, triplane_decoder)
+ # if True:
+ decoder_kwargs = dict(
+ class_name=ae_classname,
+ vit_decoder=vit_decoder,
+ triplane_decoder=triplane_decoder,
+ # encoder_cls_token=encoder_cls_token,
+ cls_token=decoder_cls_token,
+ sr_ratio=sr_ratio,
+ vae_p=vae_p,
+ ldm_z_channels=ldm_z_channels,
+ ldm_embed_dim=ldm_embed_dim,
+ )
+ decoder = dnnlib.util.construct_class_by_name(**decoder_kwargs)
+
+
+ # if return_encoder_decoder:
+ # return encoder, decoder, img_size[0], cls_token
+ # else:
+
+ if use_conf_map:
+ confnet = ConfNet(cin=3, cout=1, nf=64, zdim=128)
+ else:
+ confnet = None
+
+ auto_encoder = AE(
+ encoder,
+ decoder,
+ img_size[0],
+ encoder_cls_token,
+ decoder_cls_token,
+ preprocess,
+ use_clip,
+ dino_version,
+ clip_dtype,
+ no_dim_up_mlp=no_dim_up_mlp,
+ dim_up_mlp_as_func=dim_up_mlp_as_func,
+ uvit_skip_encoder=uvit_skip_encoder,
+ confnet=confnet,
+ )
+
+ logger.log(auto_encoder)
+ torch.cuda.empty_cache()
+
+ return auto_encoder
+
+
+# def create_3DAE_Diffusion_model(
+# arch_encoder,
+# arch_decoder,
+# img_size=[224],
+# patch_size=16,
+# in_chans=384,
+# num_classes=0,
+# embed_dim=1024, # Check ViT encoder dim
+# depth=6,
+# num_heads=16,
+# mlp_ratio=4.,
+# qkv_bias=False,
+# qk_scale=None,
+# drop_rate=0.1,
+# attn_drop_rate=0.,
+# drop_path_rate=0.,
+# # norm_layer=nn.LayerNorm,
+# norm_layer='nn.LayerNorm',
+# out_chans=96,
+# decoder_in_chans=32,
+# decoder_output_dim=32,
+# cls_token=False,
+# c_dim=25, # Conditioning label (C) dimensionality.
+# img_resolution=128, # Output resolution.
+# img_channels=3, # Number of output color channels.
+# rendering_kwargs={},
+# load_pretrain_encoder=False,
+# decomposed=True,
+# triplane_size=224,
+# ae_classname='ViTTriplaneDecomposed',
+# # return_encoder_decoder=False,
+# *args,
+# **kwargs
+# ):
+
+# # TODO, check pre-trained ViT encoder cfgs
+
+# encoder, decoder, img_size, cls_token = create_3DAE_model(
+# arch_encoder,
+# arch_decoder,
+# img_size,
+# patch_size,
+# in_chans,
+# num_classes,
+# embed_dim, # Check ViT encoder dim
+# depth,
+# num_heads,
+# mlp_ratio,
+# qkv_bias,
+# qk_scale,
+# drop_rate,
+# attn_drop_rate,
+# drop_path_rate,
+# # norm_layer=nn.LayerNorm,
+# norm_layer,
+# out_chans=96,
+# decoder_in_chans=32,
+# decoder_output_dim=32,
+# cls_token=False,
+# c_dim=25, # Conditioning label (C) dimensionality.
+# img_resolution=128, # Output resolution.
+# img_channels=3, # Number of output color channels.
+# rendering_kwargs={},
+# load_pretrain_encoder=False,
+# decomposed=True,
+# triplane_size=224,
+# ae_classname='ViTTriplaneDecomposed',
+# return_encoder_decoder=False,
+# *args,
+# **kwargs
+# ) # type: ignore
+
+
+def create_Triplane(
+ c_dim=25, # Conditioning label (C) dimensionality.
+ img_resolution=128, # Output resolution.
+ img_channels=3, # Number of output color channels.
+ rendering_kwargs={},
+ decoder_output_dim=32,
+ *args,
+ **kwargs):
+
+ decoder = Triplane(
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ # TODO, replace with c
+ rendering_kwargs=rendering_kwargs,
+ create_triplane=True,
+ decoder_output_dim=decoder_output_dim)
+ return decoder
+
+
+def DiT_defaults():
+ return {
+ 'dit_model': "DiT-B/16",
+ 'vae': "ema"
+ # dit_model="DiT-XL/2",
+ # dit_patch_size=8,
+ }
diff --git a/nsr/superresolution.py b/nsr/superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9f47429f2f1ab1643a3d2b9b52ec502e4e0b228
--- /dev/null
+++ b/nsr/superresolution.py
@@ -0,0 +1,446 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Superresolution network architectures from the paper
+"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
+
+import torch
+from nsr.networks_stylegan2 import Conv2dLayer, SynthesisLayer, ToRGBLayer
+from utils.torch_utils.ops import upfirdn2d
+from utils.torch_utils import persistence
+from utils.torch_utils import misc
+
+from nsr.networks_stylegan2 import SynthesisBlock
+import numpy as np
+from pdb import set_trace as st
+
+
+@persistence.persistent_class
+class SynthesisBlockNoUp(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter=[
+ 1, 3, 3, 1
+ ], # Low-pass filter to apply when resampling activations.
+ conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16=False, # Use FP16 for this block?
+ fp16_channels_last=False, # Use channels-last memory format with FP16?
+ fused_modconv_default=True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.fused_modconv_default = fused_modconv_default
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(
+ torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels,
+ out_channels,
+ w_dim=w_dim,
+ resolution=resolution,
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last,
+ **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels,
+ out_channels,
+ w_dim=w_dim,
+ resolution=resolution,
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last,
+ **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels,
+ img_channels,
+ w_dim=w_dim,
+ conv_clamp=conv_clamp,
+ channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False,
+ up=2,
+ resample_filter=resample_filter,
+ channels_last=self.channels_last)
+
+ def forward(self,
+ x,
+ img,
+ ws,
+ force_fp32=False,
+ fused_modconv=None,
+ update_emas=False,
+ **layer_kwargs):
+ _ = update_emas # unused
+ misc.assert_shape(ws,
+ [None, self.num_conv + self.num_torgb, self.w_dim])
+ w_iter = iter(ws.unbind(dim=1))
+ if ws.device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ fused_modconv = self.fused_modconv_default
+ if fused_modconv == 'inference_only':
+ fused_modconv = (not self.training)
+
+ # Input.
+ if self.in_channels == 0:
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+ else:
+ misc.assert_shape(
+ x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ **layer_kwargs)
+ x = self.conv1(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ gain=np.sqrt(0.5),
+ **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ **layer_kwargs)
+ x = self.conv1(x,
+ next(w_iter),
+ fused_modconv=fused_modconv,
+ **layer_kwargs)
+
+ # ToRGB.
+ # if img is not None:
+ # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+ # img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32,
+ memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ # assert x.dtype == dtype # support AMP in this library
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+
+#----------------------------------------------------------------------------
+
+
+# for 512x512 generation
+@persistence.persistent_class
+class SuperresolutionHybrid8X(torch.nn.Module):
+ def __init__(
+ self,
+ channels,
+ img_resolution,
+ sr_num_fp16_res,
+ sr_antialias,
+ num_fp16_res=4,
+ conv_clamp=None,
+ channel_base=None,
+ channel_max=None, # IGNORE
+ **block_kwargs):
+ super().__init__()
+ # assert img_resolution == 512
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 128
+ self.sr_antialias = sr_antialias
+ self.block0 = SynthesisBlock(channels,
+ 128,
+ w_dim=512,
+ resolution=256,
+ img_channels=3,
+ is_last=False,
+ use_fp16=use_fp16,
+ conv_clamp=(256 if use_fp16 else None),
+ **block_kwargs)
+ self.block1 = SynthesisBlock(128,
+ 64,
+ w_dim=512,
+ resolution=512,
+ img_channels=3,
+ is_last=True,
+ use_fp16=use_fp16,
+ conv_clamp=(256 if use_fp16 else None),
+ **block_kwargs)
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter([1, 3, 3, 1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] != self.input_resolution:
+ x = torch.nn.functional.interpolate(x,
+ size=(self.input_resolution,
+ self.input_resolution),
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb,
+ size=(self.input_resolution,
+ self.input_resolution),
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs) # block_kwargs: {'noise_mode': 'none'}
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+
+#----------------------------------------------------------------------------
+
+
+# for 256x256 generation
+@persistence.persistent_class
+class SuperresolutionHybrid4X(torch.nn.Module):
+ def __init__(
+ self,
+ channels,
+ img_resolution,
+ sr_num_fp16_res,
+ sr_antialias,
+ num_fp16_res=4,
+ conv_clamp=None,
+ channel_base=None,
+ channel_max=None, # IGNORE
+ **block_kwargs):
+ super().__init__()
+ # assert img_resolution == 256
+ use_fp16 = sr_num_fp16_res > 0
+ self.sr_antialias = sr_antialias
+ self.input_resolution = 128
+ self.block0 = SynthesisBlockNoUp(
+ channels,
+ 128,
+ w_dim=512,
+ resolution=128,
+ img_channels=3,
+ is_last=False,
+ use_fp16=use_fp16,
+ conv_clamp=(256 if use_fp16 else None),
+ **block_kwargs)
+ self.block1 = SynthesisBlock(128,
+ 64,
+ w_dim=512,
+ resolution=256,
+ img_channels=3,
+ is_last=True,
+ use_fp16=use_fp16,
+ conv_clamp=(256 if use_fp16 else None),
+ **block_kwargs)
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter([1, 3, 3, 1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] < self.input_resolution:
+ x = torch.nn.functional.interpolate(x,
+ size=(self.input_resolution,
+ self.input_resolution),
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb,
+ size=(self.input_resolution,
+ self.input_resolution),
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+
+#----------------------------------------------------------------------------
+
+
+# for 128 x 128 generation
+@persistence.persistent_class
+class SuperresolutionHybrid2X(torch.nn.Module):
+ def __init__(
+ self,
+ channels,
+ img_resolution,
+ sr_num_fp16_res,
+ sr_antialias,
+ num_fp16_res=4,
+ conv_clamp=None,
+ channel_base=None,
+ channel_max=None, # IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 128
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 64
+ # self.input_resolution = 128
+
+ self.sr_antialias = sr_antialias
+ self.block0 = SynthesisBlockNoUp(
+ channels,
+ 128,
+ w_dim=512,
+ resolution=64,
+ # resolution=128,
+ img_channels=3,
+ is_last=False,
+ use_fp16=use_fp16,
+ conv_clamp=(256 if use_fp16 else None),
+ **block_kwargs)
+ self.block1 = SynthesisBlock(128,
+ 64,
+ w_dim=512,
+ resolution=128,
+ # resolution=256,
+ img_channels=3,
+ is_last=True,
+ use_fp16=use_fp16,
+ conv_clamp=(256 if use_fp16 else None),
+ **block_kwargs)
+ self.register_buffer('resample_filter',
+ upfirdn2d.setup_filter([1, 3, 3, 1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] != self.input_resolution:
+ x = torch.nn.functional.interpolate(x,
+ size=(self.input_resolution,
+ self.input_resolution),
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb,
+ size=(self.input_resolution,
+ self.input_resolution),
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+
+#----------------------------------------------------------------------------
+
+
+# for 512x512 generation
+@persistence.persistent_class
+class SuperresolutionHybrid8XDC(torch.nn.Module):
+ def __init__(
+ self,
+ channels,
+ img_resolution,
+ sr_num_fp16_res,
+ sr_antialias,
+ num_fp16_res=4,
+ conv_clamp=None,
+ channel_base=None,
+ channel_max=None, # IGNORE
+ **block_kwargs):
+ super().__init__()
+ # assert img_resolution == 512
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 128
+ self.sr_antialias = sr_antialias
+ self.block0 = SynthesisBlock(channels,
+ 256,
+ w_dim=512,
+ resolution=256,
+ img_channels=3,
+ is_last=False,
+ use_fp16=use_fp16,
+ conv_clamp=(256 if use_fp16 else None),
+ **block_kwargs)
+ self.block1 = SynthesisBlock(256,
+ 128,
+ w_dim=512,
+ resolution=512,
+ img_channels=3,
+ is_last=True,
+ use_fp16=use_fp16,
+ conv_clamp=(256 if use_fp16 else None),
+ **block_kwargs)
+
+ def forward(self, rgb, x, ws, base_x=None, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1) # BS 3 512
+
+ # st()
+ if x.shape[-1] != self.input_resolution: # resize 64 => 128
+ x = torch.nn.functional.interpolate(x,
+ size=(self.input_resolution,
+ self.input_resolution),
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb,
+ size=(self.input_resolution,
+ self.input_resolution),
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ # print(f'device={self.block0.conv1.weight.device}')
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ # print(f'device={self.block1.conv1.weight.device}')
+ return rgb
+
+
+#----------------------------------------------------------------------------
diff --git a/nsr/train_nv_util.py b/nsr/train_nv_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42df7b3b0d0ec266314af45c22d434a0e842bea
--- /dev/null
+++ b/nsr/train_nv_util.py
@@ -0,0 +1,1642 @@
+import copy
+# import imageio.v3
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+from einops import rearrange
+import webdataset as wds
+
+import traceback
+import blobfile as bf
+import imageio
+import numpy as np
+# from sympy import O
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+from guided_diffusion.train_util import (calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+
+from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics
+
+from .train_util import TrainLoop3DRec
+
+
+class TrainLoop3DRecNV(TrainLoop3DRec):
+ # supervise the training of novel view
+ def __init__(self,
+ *,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ **kwargs)
+ self.rec_cano = True
+
+ def forward_backward(self, batch, *args, **kwargs):
+ # return super().forward_backward(batch, *args, **kwargs)
+
+ self.mp_trainer_rec.zero_grad()
+ batch_size = batch['img_to_encoder'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ # st()
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev())
+ for k, v in batch.items()
+ }
+
+ # ! concat novel-view? next version. also add self reconstruction, patch-based loss in the next version. verify novel-view prediction first.
+
+ # wrap forward within amp
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ target_nvs = {}
+ target_cano = {}
+
+ latent = self.rec_model(img=micro['img_to_encoder'],
+ behaviour='enc_dec_wo_triplane')
+
+ pred = self.rec_model(
+ latent=latent,
+ c=micro['nv_c'], # predict novel view here
+ behaviour='triplane_dec')
+
+ for k, v in micro.items():
+ if k[:2] == 'nv':
+ orig_key = k.replace('nv_', '')
+ target_nvs[orig_key] = v
+ target_cano[orig_key] = micro[orig_key]
+
+ with self.rec_model.no_sync(): # type: ignore
+ loss, loss_dict, fg_mask = self.loss_class(
+ pred,
+ target_nvs,
+ step=self.step + self.resume_step,
+ test_mode=False,
+ return_fg_mask=True,
+ conf_sigma_l1=None,
+ conf_sigma_percl=None)
+ log_rec3d_loss_dict(loss_dict)
+
+ if self.rec_cano:
+
+ pred_cano = self.rec_model(latent=latent,
+ c=micro['c'],
+ behaviour='triplane_dec')
+
+ with self.rec_model.no_sync(): # type: ignore
+
+ fg_mask = target_cano['depth_mask'].unsqueeze(
+ 1).repeat_interleave(3, 1).float()
+
+ loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
+ pred_cano['image_raw'],
+ target_cano['img'],
+ fg_mask,
+ step=self.step + self.resume_step,
+ test_mode=False,
+ )
+
+ loss = loss + loss_cano
+
+ # remove redundant log
+ log_rec3d_loss_dict({
+ f'cano_{k}': v
+ for k, v in loss_cano_dict.items()
+ # if "loss" in k
+ })
+
+ self.mp_trainer_rec.backward(loss)
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ if self.rec_cano:
+ self.log_img(micro, pred, pred_cano)
+ else:
+ self.log_img(micro, pred, None)
+
+ @th.inference_mode()
+ def log_img(self, micro, pred, pred_cano):
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ def norm_depth(pred_depth): # to [-1,1]
+ # pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ return -(pred_depth * 2 - 1)
+
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ # infer novel view also
+ # if self.loss_class.opt.symmetry_loss:
+ # pred_nv_img = nvs_pred
+ # else:
+ # ! replace with novel view prediction
+
+ # ! log another novel-view prediction
+ # pred_nv_img = self.rec_model(
+ # img=micro['img_to_encoder'],
+ # c=self.novel_view_poses) # pred: (B, 3, 64, 64)
+
+ # if 'depth' in micro:
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = norm_depth(gt_depth)
+ # gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ # gt_depth.min())
+ # if True:
+ fg_mask = pred['image_mask'] * 2 - 1 # 0-1
+ input_fg_mask = pred_cano['image_mask'] * 2 - 1 # 0-1
+ if 'image_depth' in pred:
+ pred_depth = norm_depth(pred['image_depth'])
+ pred_nv_depth = norm_depth(pred_cano['image_depth'])
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+ pred_nv_depth = th.zeros_like(gt_depth)
+
+ if 'image_sr' in pred:
+ if pred['image_sr'].shape[-1] == 512:
+ pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ elif pred['image_sr'].shape[-1] == 256:
+ pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_256(pred_depth)
+ gt_depth = self.pool_256(gt_depth)
+
+ else:
+ pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
+ dim=-1)
+ gt_depth = self.pool_128(gt_depth)
+ pred_depth = self.pool_128(pred_depth)
+ else:
+ gt_img = self.pool_64(gt_img)
+ gt_depth = self.pool_64(gt_depth)
+
+ pred_vis = th.cat([
+ pred_img,
+ pred_depth.repeat_interleave(3, dim=1),
+ fg_mask.repeat_interleave(3, dim=1),
+ ],
+ dim=-1) # B, 3, H, W
+
+ pred_vis_nv = th.cat([
+ pred_cano['image_raw'],
+ pred_nv_depth.repeat_interleave(3, dim=1),
+ input_fg_mask.repeat_interleave(3, dim=1),
+ ],
+ dim=-1) # B, 3, H, W
+
+ pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2) # cat in H dim
+
+ gt_vis = th.cat([
+ gt_img,
+ gt_depth.repeat_interleave(3, dim=1),
+ th.zeros_like(gt_img)
+ ],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ if 'conf_sigma' in pred:
+ gt_vis = th.cat([gt_vis, fg_mask], dim=-1) # placeholder
+
+ # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ vis = th.cat([gt_vis, pred_vis], dim=-2)
+ # .permute(
+ # 0, 2, 3, 1).cpu()
+ vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
+ 64) # HWC
+ torchvision.utils.save_image(
+ vis_tensor,
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
+ value_range=(-1, 1),
+ normalize=True)
+ # vis = vis.numpy() * 127.5 + 127.5
+ # vis = vis.clip(0, 255).astype(np.uint8)
+
+ # Image.fromarray(vis).save(
+ # f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+
+ logger.log('log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+
+ # self.writer.add_image(f'images',
+ # vis,
+ # self.step + self.resume_step,
+ # dataformats='HWC')
+
+
+# return pred
+
+
+class TrainLoop3DRecNVPatch(TrainLoop3DRecNV):
+ # add patch rendering
+ def __init__(self,
+ *,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ **kwargs)
+ # the rendrer
+ self.eg3d_model = self.rec_model.module.decoder.triplane_decoder # type: ignore
+ # self.rec_cano = False
+ self.rec_cano = True
+
+ def forward_backward(self, batch, *args, **kwargs):
+ # add patch sampling
+
+ self.mp_trainer_rec.zero_grad()
+ batch_size = batch['img_to_encoder'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev())
+ for k, v in batch.items()
+ }
+
+ # ! sample rendering patch
+ target = {
+ **self.eg3d_model(
+ c=micro['nv_c'], # type: ignore
+ ws=None,
+ planes=None,
+ sample_ray_only=True,
+ fg_bbox=micro['nv_bbox']), # rays o / dir
+ }
+
+ patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
+ 'patch_rendering_resolution'] # type: ignore
+ cropped_target = {
+ k:
+ th.empty_like(v)
+ [..., :patch_rendering_resolution, :patch_rendering_resolution]
+ if k not in [
+ 'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
+ 'nv_img_sr', 'c'
+ ] else v
+ for k, v in micro.items()
+ }
+
+ # crop according to uv sampling
+ for j in range(micro['img'].shape[0]):
+ top, left, height, width = target['ray_bboxes'][
+ j] # list of tuple
+ # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
+ for key in ('img', 'depth_mask', 'depth'): # type: ignore
+ # target[key][i:i+1] = torchvision.transforms.functional.crop(
+ # cropped_target[key][
+ # j:j + 1] = torchvision.transforms.functional.crop(
+ # micro[key][j:j + 1], top, left, height, width)
+
+ cropped_target[f'{key}'][ # ! no nv_ here
+ j:j + 1] = torchvision.transforms.functional.crop(
+ micro[f'nv_{key}'][j:j + 1], top, left, height,
+ width)
+
+ # target.update(cropped_target)
+
+ # wrap forward within amp
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ # target_nvs = {}
+ # target_cano = {}
+
+ latent = self.rec_model(img=micro['img_to_encoder'],
+ behaviour='enc_dec_wo_triplane')
+
+ pred_nv = self.rec_model(
+ latent=latent,
+ c=micro['nv_c'], # predict novel view here
+ behaviour='triplane_dec',
+ ray_origins=target['ray_origins'],
+ ray_directions=target['ray_directions'],
+ )
+
+ # ! directly retrieve from target
+ # for k, v in target.items():
+ # if k[:2] == 'nv':
+ # orig_key = k.replace('nv_', '')
+ # target_nvs[orig_key] = v
+ # target_cano[orig_key] = target[orig_key]
+
+ with self.rec_model.no_sync(): # type: ignore
+ loss, loss_dict, _ = self.loss_class(pred_nv,
+ cropped_target,
+ step=self.step +
+ self.resume_step,
+ test_mode=False,
+ return_fg_mask=True,
+ conf_sigma_l1=None,
+ conf_sigma_percl=None)
+ log_rec3d_loss_dict(loss_dict)
+
+ if self.rec_cano:
+
+ cano_target = {
+ **self.eg3d_model(
+ c=micro['c'], # type: ignore
+ ws=None,
+ planes=None,
+ sample_ray_only=True,
+ fg_bbox=micro['bbox']), # rays o / dir
+ }
+
+ cano_cropped_target = {
+ k: th.empty_like(v)
+ for k, v in cropped_target.items()
+ }
+
+ for j in range(micro['img'].shape[0]):
+ top, left, height, width = cano_target['ray_bboxes'][
+ j] # list of tuple
+ # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
+ for key in ('img', 'depth_mask',
+ 'depth'): # type: ignore
+ # target[key][i:i+1] = torchvision.transforms.functional.crop(
+ cano_cropped_target[key][
+ j:j +
+ 1] = torchvision.transforms.functional.crop(
+ micro[key][j:j + 1], top, left, height,
+ width)
+
+ # cano_target.update(cano_cropped_target)
+
+ pred_cano = self.rec_model(
+ latent=latent,
+ c=micro['c'],
+ behaviour='triplane_dec',
+ ray_origins=cano_target['ray_origins'],
+ ray_directions=cano_target['ray_directions'],
+ )
+
+ with self.rec_model.no_sync(): # type: ignore
+
+ fg_mask = cano_cropped_target['depth_mask'].unsqueeze(
+ 1).repeat_interleave(3, 1).float()
+
+ loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
+ pred_cano['image_raw'],
+ cano_cropped_target['img'],
+ fg_mask,
+ step=self.step + self.resume_step,
+ test_mode=False,
+ )
+
+ loss = loss + loss_cano
+
+ # remove redundant log
+ log_rec3d_loss_dict({
+ f'cano_{k}': v
+ for k, v in loss_cano_dict.items()
+ # if "loss" in k
+ })
+
+ self.mp_trainer_rec.backward(loss)
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ self.log_patch_img(cropped_target, pred_nv, pred_cano)
+
+ @th.inference_mode()
+ def log_patch_img(self, micro, pred, pred_cano):
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ def norm_depth(pred_depth): # to [-1,1]
+ # pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ return -(pred_depth * 2 - 1)
+
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ # infer novel view also
+ # if self.loss_class.opt.symmetry_loss:
+ # pred_nv_img = nvs_pred
+ # else:
+ # ! replace with novel view prediction
+
+ # ! log another novel-view prediction
+ # pred_nv_img = self.rec_model(
+ # img=micro['img_to_encoder'],
+ # c=self.novel_view_poses) # pred: (B, 3, 64, 64)
+
+ # if 'depth' in micro:
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = norm_depth(gt_depth)
+ # gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ # gt_depth.min())
+ # if True:
+ fg_mask = pred['image_mask'] * 2 - 1 # 0-1
+ input_fg_mask = pred_cano['image_mask'] * 2 - 1 # 0-1
+ if 'image_depth' in pred:
+ pred_depth = norm_depth(pred['image_depth'])
+ pred_cano_depth = norm_depth(pred_cano['image_depth'])
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+ pred_cano_depth = th.zeros_like(gt_depth)
+
+ # if 'image_sr' in pred:
+ # if pred['image_sr'].shape[-1] == 512:
+ # pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # pred_depth = self.pool_512(pred_depth)
+ # gt_depth = self.pool_512(gt_depth)
+
+ # elif pred['image_sr'].shape[-1] == 256:
+ # pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # pred_depth = self.pool_256(pred_depth)
+ # gt_depth = self.pool_256(gt_depth)
+
+ # else:
+ # pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # gt_depth = self.pool_128(gt_depth)
+ # pred_depth = self.pool_128(pred_depth)
+ # else:
+ # gt_img = self.pool_64(gt_img)
+ # gt_depth = self.pool_64(gt_depth)
+
+ pred_vis = th.cat([
+ pred_img,
+ pred_depth.repeat_interleave(3, dim=1),
+ fg_mask.repeat_interleave(3, dim=1),
+ ],
+ dim=-1) # B, 3, H, W
+
+ pred_vis_nv = th.cat([
+ pred_cano['image_raw'],
+ pred_cano_depth.repeat_interleave(3, dim=1),
+ input_fg_mask.repeat_interleave(3, dim=1),
+ ],
+ dim=-1) # B, 3, H, W
+
+ pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2) # cat in H dim
+
+ gt_vis = th.cat([
+ gt_img,
+ gt_depth.repeat_interleave(3, dim=1),
+ th.zeros_like(gt_img)
+ ],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ # if 'conf_sigma' in pred:
+ # gt_vis = th.cat([gt_vis, fg_mask], dim=-1) # placeholder
+
+ # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ # st()
+ vis = th.cat([gt_vis, pred_vis], dim=-2)
+ # .permute(
+ # 0, 2, 3, 1).cpu()
+ vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
+ 64) # HWC
+ torchvision.utils.save_image(
+ vis_tensor,
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
+ value_range=(-1, 1),
+ normalize=True)
+
+ logger.log('log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+
+ # self.writer.add_image(f'images',
+ # vis,
+ # self.step + self.resume_step,
+ # dataformats='HWC')
+
+
+class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch):
+
+ def __init__(self,
+ *,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ **kwargs)
+
+ def forward_backward(self, batch, *args, **kwargs):
+ # add patch sampling
+
+ self.mp_trainer_rec.zero_grad()
+ batch_size = batch['img_to_encoder'].shape[0]
+
+ batch.pop('caption') # not required
+ batch.pop('ins') # not required
+ # batch.pop('nv_caption') # not required
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k:
+ v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v[i:i + self.microbatch]
+ for k, v in batch.items()
+ }
+
+ # ! sample rendering patch
+ target = {
+ **self.eg3d_model(
+ c=micro['nv_c'], # type: ignore
+ ws=None,
+ planes=None,
+ sample_ray_only=True,
+ fg_bbox=micro['nv_bbox']), # rays o / dir
+ }
+
+ patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
+ 'patch_rendering_resolution'] # type: ignore
+ cropped_target = {
+ k:
+ th.empty_like(v)
+ [..., :patch_rendering_resolution, :patch_rendering_resolution]
+ if k not in [
+ 'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
+ 'nv_img_sr', 'c', 'caption', 'nv_caption'
+ ] else v
+ for k, v in micro.items()
+ }
+
+ # crop according to uv sampling
+ for j in range(micro['img'].shape[0]):
+ top, left, height, width = target['ray_bboxes'][
+ j] # list of tuple
+ # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
+ for key in ('img', 'depth_mask', 'depth'): # type: ignore
+ # target[key][i:i+1] = torchvision.transforms.functional.crop(
+ # cropped_target[key][
+ # j:j + 1] = torchvision.transforms.functional.crop(
+ # micro[key][j:j + 1], top, left, height, width)
+
+ cropped_target[f'{key}'][ # ! no nv_ here
+ j:j + 1] = torchvision.transforms.functional.crop(
+ micro[f'nv_{key}'][j:j + 1], top, left, height,
+ width)
+
+ # ! cano view loss
+ cano_target = {
+ **self.eg3d_model(
+ c=micro['c'], # type: ignore
+ ws=None,
+ planes=None,
+ sample_ray_only=True,
+ fg_bbox=micro['bbox']), # rays o / dir
+ }
+
+ # cano_cropped_target = {
+ # k: th.empty_like(v)
+ # for k, v in cropped_target.items()
+ # }
+
+ # for j in range(micro['img'].shape[0]):
+ # top, left, height, width = cano_target['ray_bboxes'][
+ # j] # list of tuple
+ # # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
+ # for key in ('img', 'depth_mask', 'depth'): # type: ignore
+ # # target[key][i:i+1] = torchvision.transforms.functional.crop(
+ # cano_cropped_target[key][
+ # j:j + 1] = torchvision.transforms.functional.crop(
+ # micro[key][j:j + 1], top, left, height, width)
+
+ # ! vit no amp
+ latent = self.rec_model(img=micro['img_to_encoder'],
+ behaviour='enc_dec_wo_triplane')
+
+ # wrap forward within amp
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ # c = th.cat([micro['nv_c'], micro['c']]), # predict novel view here
+ # c = th.cat([micro['nv_c'].repeat(3, 1), micro['c']]), # predict novel view here
+ instance_mv_num = batch_size // 4 # 4 pairs by default
+ # instance_mv_num = 4
+ # ! roll views for multi-view supervision
+ c = th.cat([
+ micro['nv_c'].roll(instance_mv_num * i, dims=0)
+ for i in range(1, 4)
+ ]
+ # + [micro['c']]
+ ) # predict novel view here
+
+ ray_origins = th.cat(
+ [
+ target['ray_origins'].roll(instance_mv_num * i, dims=0)
+ for i in range(1, 4)
+ ]
+ # + [cano_target['ray_origins'] ]
+ ,
+ 0)
+
+ ray_directions = th.cat([
+ target['ray_directions'].roll(instance_mv_num * i, dims=0)
+ for i in range(1, 4)
+ ]
+ # + [cano_target['ray_directions'] ]
+ )
+
+ pred_nv_cano = self.rec_model(
+ # latent=latent.expand(2,),
+ latent={
+ 'latent_after_vit': # ! triplane for rendering
+ # latent['latent_after_vit'].repeat(2, 1, 1, 1)
+ latent['latent_after_vit'].repeat(3, 1, 1, 1)
+ },
+ c=c,
+ behaviour='triplane_dec',
+ # ray_origins=target['ray_origins'],
+ # ray_directions=target['ray_directions'],
+ ray_origins=ray_origins,
+ ray_directions=ray_directions,
+ )
+
+ pred_nv_cano.update(
+ latent
+ ) # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)
+ # gt = {
+ # k: th.cat([v, cano_cropped_target[k]], 0)
+ # for k, v in cropped_target.items()
+ # }
+ gt = {
+ k:
+ th.cat(
+ [
+ v.roll(instance_mv_num * i, dims=0)
+ for i in range(1, 4)
+ ]
+ # + [cano_cropped_target[k] ]
+ ,
+ 0)
+ for k, v in cropped_target.items()
+ } # torchvision.utils.save_image(gt['img'], 'gt.png', normalize=True)
+
+ with self.rec_model.no_sync(): # type: ignore
+ loss, loss_dict, _ = self.loss_class(
+ pred_nv_cano,
+ gt, # prepare merged data
+ step=self.step + self.resume_step,
+ test_mode=False,
+ return_fg_mask=True,
+ conf_sigma_l1=None,
+ conf_sigma_percl=None)
+ log_rec3d_loss_dict(loss_dict)
+
+ self.mp_trainer_rec.backward(loss)
+
+ # for name, p in self.rec_model.named_parameters():
+ # if p.grad is None:
+ # logger.log(f"found rec unused param: {name}")
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ micro_bs = micro['img_to_encoder'].shape[0]
+ self.log_patch_img( # record one cano view and one novel view
+ cropped_target,
+ {
+ k: pred_nv_cano[k][-micro_bs:]
+ for k in ['image_raw', 'image_depth', 'image_mask']
+ },
+ {
+ k: pred_nv_cano[k][:micro_bs]
+ for k in ['image_raw', 'image_depth', 'image_mask']
+ },
+ )
+
+ def eval_loop(self):
+ return super().eval_loop()
+
+ @th.inference_mode()
+ # def eval_loop(self, c_list:list):
+ def eval_novelview_loop_old(self, camera=None):
+ # novel view synthesis given evaluation camera trajectory
+
+ all_loss_dict = []
+ novel_view_micro = {}
+
+ # ! randomly inference an instance
+
+ export_mesh = True
+ if export_mesh:
+ Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True,
+ exist_ok=True)
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+
+ batch = {}
+ # if camera is not None:
+ # # batch['c'] = camera.to(batch['c'].device())
+ # batch['c'] = camera.clone()
+ # else:
+ # batch =
+
+ for eval_idx, render_reference in enumerate(tqdm(self.eval_data)):
+
+ if eval_idx > 500:
+ break
+
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4',
+ mode='I',
+ fps=25,
+ codec='libx264')
+
+ with open(
+ f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt',
+ 'w') as f:
+ f.write(render_reference['caption'])
+
+ for key in ['ins', 'bbox', 'caption']:
+ if key in render_reference:
+ render_reference.pop(key)
+
+ real_flag = False
+ mv_flag = False # TODO, use full-instance for evaluation? Calculate the metrics.
+ if render_reference['c'].shape[:2] == (1, 40):
+ real_flag = True
+ # real img monocular reconstruction
+ # compat lst for enumerate
+ render_reference = [{
+ k: v[0][idx:idx + 1]
+ for k, v in render_reference.items()
+ } for idx in range(40)]
+
+ elif render_reference['c'].shape[0] == 8:
+ mv_flag = True
+
+ render_reference = {
+ k: v[:4]
+ for k, v in render_reference.items()
+ }
+
+ # save gt
+ torchvision.utils.save_image(
+ render_reference[0:4]['img'],
+ logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx),
+ padding=0,
+ normalize=True,
+ value_range=(-1, 1),
+ )
+ # torchvision.utils.save_image(render_reference[4:8]['img'],
+ # logger.get_dir() + '/FID_Cals/{}_inp2.png'.format(eval_idx),
+ # padding=0,
+ # normalize=True,
+ # value_range=(-1,1),
+ # )
+
+ else:
+ # compat lst for enumerate
+ st()
+ render_reference = [{
+ k: v[idx:idx + 1]
+ for k, v in render_reference.items()
+ } for idx in range(40)]
+
+ # ! single-view version
+ render_reference[0]['img_to_encoder'] = render_reference[14][
+ 'img_to_encoder'] # encode side view
+ render_reference[0]['img'] = render_reference[14][
+ 'img'] # encode side view
+
+ # save gt
+ torchvision.utils.save_image(
+ render_reference[0]['img'],
+ logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx),
+ padding=0,
+ normalize=True,
+ value_range=(-1, 1))
+
+ # ! TODO, merge with render_video_given_triplane later
+ for i, batch in enumerate(render_reference):
+ # for i in range(0, 8, self.microbatch):
+ # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ st()
+ if i == 0:
+ if mv_flag:
+ novel_view_micro = None
+ else:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ # v[14:15].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0],
+ 0) if isinstance(v, th.Tensor) else v[0:1]
+ for k, v in batch.items()
+ }
+
+ else:
+ if i == 1:
+
+ # ! output mesh
+ if export_mesh:
+
+ # ! get planes first
+ # self.latent_name = 'latent_normalized' # normalized triplane latent
+
+ # ddpm_latent = {
+ # self.latent_name: planes,
+ # }
+ # ddpm_latent.update(self.rec_model(latent=ddpm_latent, behaviour='decode_after_vae_no_render'))
+
+ # mesh_size = 512
+ # mesh_size = 256
+ mesh_size = 384
+ # mesh_size = 320
+ # mesh_thres = 3 # TODO, requires tuning
+ # mesh_thres = 5 # TODO, requires tuning
+ mesh_thres = 10 # TODO, requires tuning
+ import mcubes
+ import trimesh
+ dump_path = f'{logger.get_dir()}/mesh/'
+
+ os.makedirs(dump_path, exist_ok=True)
+
+ grid_out = self.rec_model(
+ latent=pred,
+ grid_size=mesh_size,
+ behaviour='triplane_decode_grid',
+ )
+
+ vtx, faces = mcubes.marching_cubes(
+ grid_out['sigma'].squeeze(0).squeeze(
+ -1).cpu().numpy(), mesh_thres)
+ vtx = vtx / (mesh_size - 1) * 2 - 1
+
+ # vtx_tensor = th.tensor(vtx, dtype=th.float32, device=dist_util.dev()).unsqueeze(0)
+ # vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1)
+ # vtx_colors = (vtx_colors * 255).astype(np.uint8)
+
+ # mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
+ mesh = trimesh.Trimesh(
+ vertices=vtx,
+ faces=faces,
+ )
+
+ mesh_dump_path = os.path.join(
+ dump_path, f'{eval_idx}.ply')
+ mesh.export(mesh_dump_path, 'ply')
+
+ print(f"Mesh dumped to {dump_path}")
+ del grid_out, mesh
+ th.cuda.empty_cache()
+ # return
+ # st()
+
+ # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in novel_view_micro.items()
+ }
+
+ pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+ # target = {
+ # 'img': micro['img'],
+ # 'depth': micro['depth'],
+ # 'depth_mask': micro['depth_mask']
+ # }
+ # targe
+
+ # if not export_mesh:
+ if not real_flag:
+ _, loss_dict = self.loss_class(pred, micro, test_mode=True)
+ all_loss_dict.append(loss_dict)
+
+ # ! move to other places, add tensorboard
+
+ # pred_vis = th.cat([
+ # pred['image_raw'],
+ # -pred['image_depth'].repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ if 'image_sr' in pred:
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 256:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_256(pred['image_raw']), pred['image_sr'],
+ self.pool_256(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ else:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred['image_sr']),
+ self.pool_128(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ else:
+ # pred_vis = th.cat([
+ # self.pool_64(micro['img']), pred['image_raw'],
+ # pred_depth.repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1) # B, 3, H, W
+
+ pooled_depth = self.pool_128(pred_depth).repeat_interleave(
+ 3, dim=1)
+ pred_vis = th.cat(
+ [
+ # self.pool_128(micro['img']),
+ self.pool_128(novel_view_micro['img']
+ ), # use the input here
+ self.pool_128(pred['image_raw']),
+ pooled_depth,
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ if export_mesh:
+ # save image
+ torchvision.utils.save_image(
+ pred['image_raw'],
+ logger.get_dir() +
+ '/FID_Cals/{}_{}.png'.format(eval_idx, i),
+ padding=0,
+ normalize=True,
+ value_range=(-1, 1))
+
+ torchvision.utils.save_image(
+ pooled_depth,
+ logger.get_dir() +
+ '/FID_Cals/{}_{}_dpeth.png'.format(eval_idx, i),
+ padding=0,
+ normalize=True,
+ value_range=(0, 1))
+
+ # st()
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ # if not export_mesh:
+ if not real_flag or mv_flag:
+ val_scores_for_logging = calc_average_loss(all_loss_dict)
+ with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
+ 'a') as f:
+ json.dump({'step': self.step, **val_scores_for_logging}, f)
+
+ # * log to tensorboard
+ for k, v in val_scores_for_logging.items():
+ self.writer.add_scalar(f'Eval/NovelView/{k}', v,
+ self.step + self.resume_step)
+
+ del video_out
+ # del pred_vis
+ # del pred
+
+ th.cuda.empty_cache()
+
+ @th.inference_mode()
+ # def eval_loop(self, c_list:list):
+ def eval_novelview_loop(self, camera=None, save_latent=False):
+ # novel view synthesis given evaluation camera trajectory
+ if save_latent: # for diffusion learning
+ latent_dir = Path(f'{logger.get_dir()}/latent_dir')
+ latent_dir.mkdir(exist_ok=True, parents=True)
+
+ # wds_path = os.path.join(logger.get_dir(), 'latent_dir',
+ # f'wds-%06d.tar')
+ # sink = wds.ShardWriter(wds_path, start_shard=0)
+
+ # eval_batch_size = 20
+ # eval_batch_size = 1
+ eval_batch_size = 40 # ! for i23d
+
+ for eval_idx, micro in enumerate(tqdm(self.eval_data)):
+
+ # if eval_idx > 500:
+ # break
+
+ latent = self.rec_model(
+ img=micro['img_to_encoder'][:4],
+ behaviour='encoder_vae') # pred: (B, 3, 64, 64)
+ # torchvision.utils.save_image(micro['img'], 'inp.jpg')
+ if micro['img'].shape[0] == 40:
+ assert eval_batch_size == 40
+
+ if save_latent:
+ # np.save(f'{logger.get_dir()}/latent_dir/{eval_idx}.npy', latent[self.latent_name].cpu().numpy())
+
+ latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
+ Path(latent_save_dir).mkdir(parents=True, exist_ok=True)
+
+ np.save(f'{latent_save_dir}/latent.npy',
+ latent[self.latent_name][0].cpu().numpy())
+ assert all([
+ micro['ins'][0] == micro['ins'][i]
+ for i in range(micro['c'].shape[0])
+ ]) # ! assert same instance
+
+ # for i in range(micro['img'].shape[0]):
+
+ # compressed_sample = {
+ # 'latent':latent[self.latent_name][0].cpu().numpy(), # 12 32 32
+ # 'caption': micro['caption'][0].encode('utf-8'),
+ # 'ins': micro['ins'][0].encode('utf-8'),
+ # 'c': micro['c'][i].cpu().numpy(),
+ # 'img': micro['img'][i].cpu().numpy() # 128x128, for diffusion log
+ # }
+
+ # sink.write({
+ # "__key__": f"sample_{eval_idx*eval_batch_size+i:07d}",
+ # 'sample.pyd': compressed_sample
+ # })
+
+ if eval_idx < 50:
+ # if False:
+ self.render_video_given_triplane(
+ latent[self.latent_name], # B 12 32 32
+ self.rec_model, # compatible with join_model
+ name_prefix=f'{self.step + self.resume_step}_{eval_idx}',
+ save_img=False,
+ render_reference={'c': camera},
+ save_mesh=True)
+
+
+class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward):
+
+ def __init__(self,
+ *,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ **kwargs)
+
+ def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
+ # add patch sampling
+
+ self.mp_trainer_rec.zero_grad()
+ batch_size = batch['img_to_encoder'].shape[0]
+
+ batch.pop('caption') # not required
+ batch.pop('ins') # not required
+ if '__key__' in batch.keys():
+ batch.pop('__key__')
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k:
+ v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
+ v, th.Tensor) else v[i:i + self.microbatch]
+ for k, v in batch.items()
+ }
+
+ # ! sample rendering patch
+ # nv_c = th.cat([micro['nv_c'], micro['c']])
+ nv_c = th.cat([micro['nv_c'], micro['c']])
+ # nv_c = micro['nv_c']
+ target = {
+ **self.eg3d_model(
+ c=nv_c, # type: ignore
+ ws=None,
+ planes=None,
+ sample_ray_only=True,
+ fg_bbox=th.cat([micro['nv_bbox'], micro['bbox']])), # rays o / dir
+ }
+
+ patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
+ 'patch_rendering_resolution'] # type: ignore
+ cropped_target = {
+ k:
+ th.empty_like(v).repeat_interleave(2, 0)
+ # th.empty_like(v).repeat_interleave(1, 0)
+ [..., :patch_rendering_resolution, :patch_rendering_resolution]
+ if k not in [
+ 'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
+ 'nv_img_sr', 'c', 'caption', 'nv_caption'
+ ] else v
+ for k, v in micro.items()
+ }
+
+ # crop according to uv sampling
+ for j in range(2 * self.microbatch):
+ top, left, height, width = target['ray_bboxes'][
+ j] # list of tuple
+ # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
+ for key in ('img', 'depth_mask', 'depth'): # type: ignore
+
+ if j < self.microbatch:
+ cropped_target[f'{key}'][ # ! no nv_ here
+ j:j + 1] = torchvision.transforms.functional.crop(
+ micro[f'nv_{key}'][j:j + 1], top, left, height,
+ width)
+ else:
+ cropped_target[f'{key}'][ # ! no nv_ here
+ j:j + 1] = torchvision.transforms.functional.crop(
+ micro[f'{key}'][j - self.microbatch:j -
+ self.microbatch + 1], top,
+ left, height, width)
+
+ # for j in range(batch_size, 2*batch_size, 1):
+ # top, left, height, width = target['ray_bboxes'][
+ # j] # list of tuple
+ # # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
+ # for key in ('img', 'depth_mask', 'depth'): # type: ignore
+
+ # cropped_target[f'{key}'][ # ! no nv_ here
+ # j:j + 1] = torchvision.transforms.functional.crop(
+ # micro[f'{key}'][j-batch_size:j-batch_size + 1], top, left, height,
+ # width)
+
+ # ! vit no amp
+ latent = self.rec_model(img=micro['img_to_encoder'],
+ behaviour='enc_dec_wo_triplane')
+
+ # wrap forward within amp
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ # c = th.cat([micro['nv_c'], micro['c']]), # predict novel view here
+ # c = th.cat([micro['nv_c'].repeat(3, 1), micro['c']]), # predict novel view here
+ # instance_mv_num = batch_size // 4 # 4 pairs by default
+ # instance_mv_num = 4
+ # ! roll views for multi-view supervision
+ # c = micro['nv_c']
+ ray_origins = target['ray_origins']
+ ray_directions = target['ray_directions']
+
+ pred_nv_cano = self.rec_model(
+ # latent=latent.expand(2,),
+ latent={
+ 'latent_after_vit': # ! triplane for rendering
+ latent['latent_after_vit'].repeat_interleave(4, dim=0).repeat(2,1,1,1) # NV=4
+ # latent['latent_after_vit'].repeat_interleave(8, dim=0) # NV=4
+ },
+ c=nv_c,
+ behaviour='triplane_dec',
+ ray_origins=ray_origins,
+ ray_directions=ray_directions,
+ )
+
+ pred_nv_cano.update(
+ latent
+ ) # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)
+ gt = cropped_target
+
+ with self.rec_model.no_sync(): # type: ignore
+ loss, loss_dict, _ = self.loss_class(
+ pred_nv_cano,
+ gt, # prepare merged data
+ step=self.step + self.resume_step,
+ test_mode=False,
+ return_fg_mask=True,
+ behaviour=behaviour,
+ conf_sigma_l1=None,
+ conf_sigma_percl=None)
+ log_rec3d_loss_dict(loss_dict)
+
+ self.mp_trainer_rec.backward(loss)
+
+ # for name, p in self.rec_model.named_parameters():
+ # if p.grad is None:
+ # logger.log(f"found rec unused param: {name}")
+ # torchvision.utils.save_image(cropped_target['img'], 'gt.png', normalize=True)
+ # torchvision.utils.save_image( pred_nv_cano['image_raw'], 'pred.png', normalize=True)
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0:
+ try:
+ torchvision.utils.save_image(
+ th.cat(
+ [cropped_target['img'], pred_nv_cano['image_raw']
+ ], ),
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
+ normalize=True)
+
+ logger.log(
+ 'log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+ except Exception as e:
+ logger.log(e)
+
+ # micro_bs = micro['img_to_encoder'].shape[0]
+ # self.log_patch_img( # record one cano view and one novel view
+ # cropped_target,
+ # {
+ # k: pred_nv_cano[k][0:1]
+ # for k in ['image_raw', 'image_depth', 'image_mask']
+ # },
+ # {
+ # k: pred_nv_cano[k][1:2]
+ # for k in ['image_raw', 'image_depth', 'image_mask']
+ # },
+ # )
+
+ # def save(self):
+ # return super().save()
+
+
+class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
+ TrainLoop3DRecNVPatchSingleForwardMV):
+
+ def __init__(self,
+ *,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ **kwargs)
+
+ # create discriminator
+ disc_params = self.loss_class.get_trainable_parameters()
+
+ self.mp_trainer_disc = MixedPrecisionTrainer(
+ model=self.loss_class.discriminator,
+ use_fp16=self.use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ model_name='disc',
+ use_amp=use_amp,
+ model_params=disc_params)
+
+ # st() # check self.lr
+ self.opt_disc = AdamW(
+ self.mp_trainer_disc.master_params,
+ lr=self.lr, # follow sd code base
+ betas=(0, 0.999),
+ eps=1e-8)
+
+ # TODO, is loss cls already in the DDP?
+ if self.use_ddp:
+ self.ddp_disc = DDP(
+ self.loss_class.discriminator,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ )
+ else:
+ self.ddp_disc = self.loss_class.discriminator
+
+ # def run_st
+
+ # def run_step(self, batch, *args):
+ # self.forward_backward(batch)
+ # took_step = self.mp_trainer_rec.optimize(self.opt)
+ # if took_step:
+ # self._update_ema()
+ # self._anneal_lr()
+ # self.log_step()
+
+ def save(self, mp_trainer=None, model_name='rec'):
+ if mp_trainer is None:
+ mp_trainer = self.mp_trainer_rec
+
+ def save_checkpoint(rate, params):
+ state_dict = mp_trainer.master_params_to_state_dict(params)
+ if dist_util.get_rank() == 0:
+ logger.log(f"saving model {model_name} {rate}...")
+ if not rate:
+ filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
+ else:
+ filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
+ with bf.BlobFile(bf.join(get_blob_logdir(), filename),
+ "wb") as f:
+ th.save(state_dict, f)
+
+ save_checkpoint(0, mp_trainer.master_params)
+
+ dist.barrier()
+
+ def run_step(self, batch, step='g_step'):
+ # self.forward_backward(batch)
+
+ if step == 'g_step':
+ self.forward_backward(batch, behaviour='g_step')
+ took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)
+
+ if took_step_g_rec:
+ self._update_ema() # g_ema
+
+ elif step == 'd_step':
+ self.forward_backward(batch, behaviour='d_step')
+ _ = self.mp_trainer_disc.optimize(self.opt_disc)
+
+ self._anneal_lr()
+ self.log_step()
+
+ def run_loop(self, batch=None):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ batch = next(self.data)
+ self.run_step(batch, 'g_step')
+
+ batch = next(self.data)
+ self.run_step(batch, 'd_step')
+
+ if self.step % 1000 == 0:
+ dist_util.synchronize()
+ if self.step % 10000 == 0:
+ th.cuda.empty_cache() # avoid memory leak
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ if self.step % self.eval_interval == 0 and self.step != 0:
+ if dist_util.get_rank() == 0:
+ try:
+ self.eval_loop()
+ except Exception as e:
+ logger.log(e)
+ dist_util.synchronize()
+
+ # if self.step % self.save_interval == 0 and self.step != 0:
+ if self.step % self.save_interval == 0:
+ self.save()
+ self.save(self.mp_trainer_disc,
+ self.mp_trainer_disc.model_name)
+ dist_util.synchronize()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ logger.log('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step -
+ 1) % self.save_interval != 0 and self.step != 1:
+ self.save()
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ # if (self.step - 1) % self.save_interval != 0 and self.step != 1:
+ if (self.step - 1) % self.save_interval != 0:
+ self.save() # save rec
+ self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)
diff --git a/nsr/train_util.py b/nsr/train_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2dd1b7dcaf7e80c8ce7b073ab23b7715777c51d
--- /dev/null
+++ b/nsr/train_util.py
@@ -0,0 +1,1902 @@
+import copy
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+
+import matplotlib.pyplot as plt
+import traceback
+import blobfile as bf
+import imageio
+import numpy as np
+# from sympy import O
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+from guided_diffusion.train_util import (calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+
+from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics
+
+# from ..guided_diffusion.train_util import TrainLoop
+
+
+def flip_yaw(pose_matrix):
+ flipped = pose_matrix.clone()
+ flipped[:, 0, 1] *= -1
+ flipped[:, 0, 2] *= -1
+ flipped[:, 1, 0] *= -1
+ flipped[:, 2, 0] *= -1
+ flipped[:, 0, 3] *= -1
+ # st()
+ return flipped
+
+
+# basic reconstruction model
+class TrainLoopBasic:
+
+ def __init__(
+ self,
+ *,
+ rec_model,
+ loss_class,
+ # diffusion,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ # schedule_sampler=None,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ compile=False,
+ **kwargs):
+ self.pool_512 = th.nn.AdaptiveAvgPool2d((512, 512))
+ self.pool_256 = th.nn.AdaptiveAvgPool2d((256, 256))
+ self.pool_128 = th.nn.AdaptiveAvgPool2d((128, 128))
+ self.pool_64 = th.nn.AdaptiveAvgPool2d((64, 64))
+ self.rec_model = rec_model
+ self.loss_class = loss_class
+ # self.diffusion = diffusion
+ # self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
+ self.data = data
+ self.eval_data = eval_data
+ self.batch_size = batch_size
+ self.microbatch = microbatch if microbatch > 0 else batch_size
+ self.lr = lr
+ self.ema_rate = ([ema_rate] if isinstance(ema_rate, float) else
+ [float(x) for x in ema_rate.split(",")])
+ self.log_interval = log_interval
+ self.eval_interval = eval_interval
+ self.save_interval = save_interval
+ self.iterations = iterations
+ self.resume_checkpoint = resume_checkpoint
+ self.use_fp16 = use_fp16
+ self.fp16_scale_growth = fp16_scale_growth
+ self.weight_decay = weight_decay
+ self.lr_anneal_steps = lr_anneal_steps
+
+ self.step = 0
+ self.resume_step = 0
+ # self.global_batch = self.batch_size * dist.get_world_size()
+ self.global_batch = self.batch_size * dist_util.get_world_size()
+
+ self.sync_cuda = th.cuda.is_available()
+
+ # self._load_and_sync_parameters(load_submodule_name)
+ self._load_and_sync_parameters()
+
+ self.mp_trainer_rec = MixedPrecisionTrainer(
+ model=self.rec_model,
+ use_fp16=self.use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ model_name=model_name,
+ use_amp=use_amp)
+ self.writer = SummaryWriter(log_dir=f'{logger.get_dir()}/runs')
+
+ self.opt = AdamW(self._init_optim_groups(kwargs))
+
+ if dist_util.get_rank() == 0:
+ logger.log(self.opt)
+
+ if self.resume_step:
+ if not ignore_resume_opt:
+ self._load_optimizer_state()
+ else:
+ logger.warn("Ignoring optimizer state from checkpoint.")
+ # Model was resumed, either due to a restart or a checkpoint
+ # being specified at the command line.
+ # self.ema_params = [
+ # self._load_ema_parameters(rate, load_submodule_name) for rate in self.ema_rate
+ # ]
+
+ self.ema_params = [
+ self._load_ema_parameters(
+ rate,
+ self.rec_model,
+ self.mp_trainer_rec,
+ model_name=self.mp_trainer_rec.model_name)
+ for rate in self.ema_rate
+ ]
+ else:
+ self.ema_params = [
+ copy.deepcopy(self.mp_trainer_rec.master_params)
+ for _ in range(len(self.ema_rate))
+ ]
+
+ # compile
+ if compile:
+ logger.log('compiling... ignore vit_decoder')
+ # self.rec_model.encoder = th.compile(self.rec_model.encoder)
+ self.rec_model.decoder.decoder_pred = th.compile(
+ self.rec_model.decoder.decoder_pred)
+ # self.rec_model.decoder.triplane_decoder = th.compile(self.rec_model.decoder.triplane_decoder)
+ for module_k, sub_module in self.rec_model.decoder.superresolution.items(
+ ):
+ self.rec_model.decoder.superresolution[module_k] = th.compile(
+ sub_module)
+
+ if th.cuda.is_available():
+ self.use_ddp = True
+
+ self.rec_model = th.nn.SyncBatchNorm.convert_sync_batchnorm(
+ self.rec_model)
+
+ self.rec_model = DDP(
+ self.rec_model,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ )
+ else:
+ if dist_util.get_world_size() > 1:
+ logger.warn("Distributed training requires CUDA. "
+ "Gradients will not be synchronized properly!")
+ self.use_ddp = False
+ self.rec_model = self.rec_model
+
+ self.novel_view_poses = None
+ th.cuda.empty_cache()
+
+ def _init_optim_groups(self, kwargs):
+ raise NotImplementedError('')
+
+ def _load_and_sync_parameters(self, submodule_name=''):
+ # resume_checkpoint, self.resume_step = find_resume_checkpoint() or self.resume_checkpoint
+ resume_checkpoint = self.resume_checkpoint # * default behaviour
+ # logger.log('resume_checkpoint', resume_checkpoint, self.resume_checkpoint)
+
+ if resume_checkpoint:
+ self.resume_step = parse_resume_step_from_filename(
+ resume_checkpoint)
+ if dist_util.get_rank() == 0:
+ logger.log(
+ f"loading model from checkpoint: {resume_checkpoint}...")
+ map_location = {
+ 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
+ } # configure map_location properly
+
+ resume_state_dict = dist_util.load_state_dict(
+ resume_checkpoint, map_location=map_location)
+ if submodule_name != '':
+ model_state_dict = getattr(self.rec_model,
+ submodule_name).state_dict()
+ if dist_util.get_rank() == 0:
+ logger.log('loading submodule: ', submodule_name)
+ else:
+ model_state_dict = self.rec_model.state_dict()
+
+ model = self.rec_model
+
+ # for k, v in resume_state_dict.items():
+ # if k in model_state_dict.keys() and v.size(
+ # ) == model_state_dict[k].size():
+ # model_state_dict[k] = v
+ # else:
+ # logger.log('!!!! ignore key: ', k, ": ", v.size())
+
+ for k, v in resume_state_dict.items():
+ if '._orig_mod' in k: # prefix in torch.compile
+ k = k.replace('._orig_mod', '')
+ if k in model_state_dict.keys():
+ if v.size() == model_state_dict[k].size():
+ model_state_dict[k] = v
+ # model_state_dict[k].copy_(v)
+ else:
+ # if v.ndim > 1:
+ # model_state_dict[k][:v.shape[0], :v.shape[1], ...] = v # load the decoder
+ # model_state_dict[k][v.shape[0]:, v.shape[1]:, ...] = 0
+ # else:
+ # model_state_dict[k][:v.shape[0], ...] = v # load the decoder
+ # model_state_dict[k][v.shape[0]:, ...] = 0
+ # logger.log('!!!! size mismatch, partially load: ', k, ": ", v.size(), "state_dict: ", model_state_dict[k].size())
+ logger.log('!!!! size mismatch, ignore: ', k, ": ",
+ v.size(), "state_dict: ",
+ model_state_dict[k].size())
+
+ elif 'decoder.vit_decoder.blocks' in k:
+ # st()
+ # load from 2D ViT pre-trained into 3D ViT blocks.
+ assert len(model.decoder.vit_decoder.blocks[0].vit_blks
+ ) == 2 # assert depth=2 here.
+ fusion_ca_depth = len(
+ model.decoder.vit_decoder.blocks[0].vit_blks)
+ vit_subblk_index = int(k.split('.')[3])
+ vit_blk_keyname = ('.').join(k.split('.')[4:])
+ fusion_blk_index = vit_subblk_index // fusion_ca_depth
+ fusion_blk_subindex = vit_subblk_index % fusion_ca_depth
+ model_state_dict[
+ f'decoder.vit_decoder.blocks.{fusion_blk_index}.vit_blks.{fusion_blk_subindex}.{vit_blk_keyname}'] = v
+ logger.log('load 2D ViT weight: {}'.format(
+ f'decoder.vit_decoder.blocks.{fusion_blk_index}.vit_blks.{fusion_blk_subindex}.{vit_blk_keyname}'
+ ))
+
+ else:
+ logger.log(
+ '!!!! ignore key, not in the model_state_dict: ',
+ k, ": ", v.size())
+
+ logger.log('model loading finished')
+
+ if submodule_name != '':
+ getattr(self.rec_model,
+ submodule_name).load_state_dict(model_state_dict,
+ strict=True)
+ else:
+ self.rec_model.load_state_dict(model_state_dict,
+ strict=False)
+ # strict=True)
+
+ if dist_util.get_world_size() > 1:
+ # dist_util.sync_params(self.model.named_parameters())
+ dist_util.sync_params(self.rec_model.parameters())
+ logger.log('synced params')
+
+ def _load_ema_parameters(self,
+ rate,
+ model=None,
+ mp_trainer=None,
+ model_name='ddpm'):
+
+ if mp_trainer is None:
+ mp_trainer = self.mp_trainer_rec
+ if model is None:
+ model = self.rec_model
+
+ ema_params = copy.deepcopy(mp_trainer.master_params)
+
+ # main_checkpoint, _ = find_resume_checkpoint(
+ # self.resume_checkpoint, model_name) or self.resume_checkpoint
+
+ main_checkpoint = self.resume_checkpoint
+ ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step,
+ rate, model_name)
+ if ema_checkpoint and model_name == 'ddpm':
+
+ if dist_util.get_rank() == 0:
+
+ if not Path(ema_checkpoint).exists():
+ logger.log(
+ f"failed to load EMA from checkpoint: {ema_checkpoint}, not exist"
+ )
+ return
+
+ logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
+
+ map_location = {
+ 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
+ } # configure map_location properly
+
+ state_dict = dist_util.load_state_dict(
+ ema_checkpoint, map_location=map_location)
+
+ model_ema_state_dict = model.state_dict()
+
+ for k, v in state_dict.items():
+ if k in model_ema_state_dict.keys() and v.size(
+ ) == model_ema_state_dict[k].size():
+ model_ema_state_dict[k] = v
+
+ elif 'IN' in k and getattr(model, 'decomposed_IN', False):
+ model_ema_state_dict[k.replace(
+ 'IN', 'IN.IN')] = v # decomposed IN
+
+ else:
+ logger.log('ignore key: ', k, ": ", v.size())
+
+ ema_params = mp_trainer.state_dict_to_master_params(
+ model_ema_state_dict)
+
+ del state_dict
+
+ # logger.log('ema mark 3, ', model_name, )
+
+ # ! debugging, remove to check which key fails.
+ if dist_util.get_world_size() > 1:
+ dist_util.sync_params(ema_params)
+
+ # logger.log('ema mark 4, ', model_name, )
+ # del ema_params
+ return ema_params
+
+ def _load_optimizer_state(self):
+ main_checkpoint, _ = find_resume_checkpoint() or self.resume_checkpoint
+ opt_checkpoint = bf.join(bf.dirname(main_checkpoint),
+ f"opt{self.resume_step:06}.pt")
+ if bf.exists(opt_checkpoint):
+ logger.log(
+ f"loading optimizer state from checkpoint: {opt_checkpoint}")
+
+ map_location = {
+ 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
+ } # configure map_location properly
+
+ state_dict = dist_util.load_state_dict(opt_checkpoint,
+ map_location=map_location)
+ self.opt.load_state_dict(state_dict)
+ # self.opt.load_state_dict({k: v for k, v in state_dict.items() if k in self.opt.state_dict()})
+
+ del state_dict
+
+ def run_loop(self, batch=None):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # batch, cond = next(self.data)
+ # if batch is None:
+ if isinstance(self.data, list):
+ if self.step <= self.data[2]:
+ batch = next(self.data[1])
+ else:
+ batch = next(self.data[0])
+ else:
+ batch = next(self.data)
+
+ # batch = next(self.data)
+ if self.novel_view_poses is None:
+ self.novel_view_poses = th.roll(batch['c'], 1, 0).to(
+ dist_util.dev()) # save for eval visualization use
+
+ self.run_step(batch)
+
+ if self.step % 1000 == 0:
+ dist_util.synchronize()
+ th.cuda.empty_cache() # avoid memory leak
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ if self.step % self.eval_interval == 0 and self.step != 0:
+ # if self.step % self.eval_interval == 0 and (self.step +
+ # self.resume_step) != 0:
+ # if self.step % self.eval_interval == 0: # ! for debugging
+ # if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ try:
+ self.eval_loop()
+ except Exception as e:
+ logger.log(e)
+ # self.eval_novelview_loop()
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0 and self.step != 0:
+ self.save()
+ dist_util.synchronize()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ logger.log('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step -
+ 1) % self.save_interval != 0 and self.step != 1:
+ self.save()
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0 and self.step != 1:
+ self.save()
+
+ @th.no_grad()
+ def eval_loop(self):
+ raise NotImplementedError('')
+
+ def run_step(self, batch, *args):
+ self.forward_backward(batch)
+ took_step = self.mp_trainer_rec.optimize(self.opt)
+ if took_step:
+ self._update_ema()
+ self._anneal_lr()
+ self.log_step()
+
+ def forward_backward(self, batch, *args, **kwargs):
+ # th.cuda.empty_cache()
+ raise NotImplementedError('')
+
+ def _update_ema(self):
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ update_ema(params, self.mp_trainer_rec.master_params, rate=rate)
+
+ def _anneal_lr(self):
+ if not self.lr_anneal_steps:
+ return
+ frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
+ lr = self.lr * (1 - frac_done)
+ for param_group in self.opt.param_groups:
+ param_group["lr"] = lr
+
+ def log_step(self):
+ logger.logkv("step", self.step + self.resume_step)
+ logger.logkv("samples",
+ (self.step + self.resume_step + 1) * self.global_batch)
+
+ def save(self):
+
+ def save_checkpoint(rate, params):
+ state_dict = self.mp_trainer_rec.master_params_to_state_dict(
+ params)
+ if dist_util.get_rank() == 0:
+ logger.log(f"saving model {rate}...")
+ if not rate:
+ filename = f"model_rec{(self.step+self.resume_step):07d}.pt"
+ else:
+ filename = f"ema_{rate}_{(self.step+self.resume_step):07d}.pt"
+ with bf.BlobFile(bf.join(get_blob_logdir(), filename),
+ "wb") as f:
+ th.save(state_dict, f)
+
+ save_checkpoint(
+ 0, self.mp_trainer_rec.master_params) # avoid OOM when saving ckpt
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ save_checkpoint(rate, params)
+ th.cuda.empty_cache()
+
+ dist.barrier()
+
+
+class TrainLoop3DRec(TrainLoopBasic):
+
+ def __init__(
+ self,
+ *,
+ rec_model,
+ loss_class,
+ # diffusion,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ # schedule_sampler=None,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ compile=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ compile=compile,
+ **kwargs)
+
+ # self.rec_model = self.ddp_model
+ # self._prepare_nvs_pose() # for eval novelview visualization
+
+ self.triplane_scaling_divider = 1.0
+ self.latent_name = 'latent_normalized_2Ddiffusion' # normalized triplane latent
+ self.render_latent_behaviour = 'decode_after_vae' # directly render using triplane operations
+
+ th.cuda.empty_cache()
+
+ @th.inference_mode()
+ def render_video_given_triplane(self,
+ planes,
+ rec_model,
+ name_prefix='0',
+ save_img=False,
+ render_reference=None,
+ save_mesh=False):
+
+ planes *= self.triplane_scaling_divider # if setting clip_denoised=True, the sampled planes will lie in [-1,1]. Thus, values beyond [+- std] will be abandoned in this version. Move to IN for later experiments.
+
+ # sr_w_code = getattr(self.ddp_rec_model.module.decoder, 'w_avg', None)
+ # sr_w_code = None
+ batch_size = planes.shape[0]
+
+ # if sr_w_code is not None:
+ # sr_w_code = sr_w_code.reshape(1, 1,
+ # -1).repeat_interleave(batch_size, 0)
+
+ # used during diffusion sampling inference
+ # if not save_img:
+
+ # ! mesh
+
+ if planes.shape[1] == 16: # ffhq/car
+ ddpm_latent = {
+ self.latent_name: planes[:, :12],
+ 'bg_plane': planes[:, 12:16],
+ }
+ else:
+ ddpm_latent = {
+ self.latent_name: planes,
+ }
+
+ ddpm_latent.update(
+ rec_model(latent=ddpm_latent,
+ behaviour='decode_after_vae_no_render'))
+
+ # if export_mesh:
+ # if True:
+ if save_mesh:
+ # mesh_size = 512
+ mesh_size = 256
+ # mesh_size = 384
+ # mesh_size = 320
+ # mesh_thres = 3 # TODO, requires tuning
+ # mesh_thres = 5 # TODO, requires tuning
+ mesh_thres = 10 # TODO, requires tuning
+ import mcubes
+ import trimesh
+ dump_path = f'{logger.get_dir()}/mesh/'
+
+ os.makedirs(dump_path, exist_ok=True)
+
+ grid_out = rec_model(
+ latent=ddpm_latent,
+ grid_size=mesh_size,
+ behaviour='triplane_decode_grid',
+ )
+
+ vtx, faces = mcubes.marching_cubes(
+ grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
+ mesh_thres)
+ vtx = vtx / (mesh_size - 1) * 2 - 1
+
+ # vtx_tensor = th.tensor(vtx, dtype=th.float32, device=dist_util.dev()).unsqueeze(0)
+ # vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1)
+ # vtx_colors = (vtx_colors * 255).astype(np.uint8)
+
+ # mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
+ mesh = trimesh.Trimesh(
+ vertices=vtx,
+ faces=faces,
+ )
+
+ mesh_dump_path = os.path.join(dump_path, f'{name_prefix}.ply')
+ mesh.export(mesh_dump_path, 'ply')
+
+ print(f"Mesh dumped to {dump_path}")
+ del grid_out, mesh
+ th.cuda.empty_cache()
+ # return
+
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/triplane_{name_prefix}.mp4',
+ mode='I',
+ fps=15,
+ codec='libx264')
+
+ if planes.shape[1] == 16: # ffhq/car
+ ddpm_latent = {
+ self.latent_name: planes[:, :12],
+ 'bg_plane': planes[:, 12:16],
+ }
+ else:
+ ddpm_latent = {
+ self.latent_name: planes,
+ }
+
+ ddpm_latent.update(
+ rec_model(latent=ddpm_latent,
+ behaviour='decode_after_vae_no_render'))
+
+ # planes = planes.repeat_interleave(micro['c'].shape[0], 0)
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ # micro_batchsize = 2
+ # micro_batchsize = batch_size
+
+ if render_reference is None:
+ render_reference = self.eval_data # compat
+ else: # use train_traj
+ for key in ['ins', 'bbox', 'caption']:
+ if key in render_reference:
+ render_reference.pop(key)
+ # render_reference.pop('bbox')
+ # render_reference.pop('caption')
+
+ # compat lst for enumerate
+ render_reference = [{
+ k: v[idx:idx + 1]
+ for k, v in render_reference.items()
+ } for idx in range(40)]
+
+ # for i, batch in enumerate(tqdm(self.eval_data)):
+ for i, batch in enumerate(tqdm(render_reference)):
+ micro = {
+ k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+ # micro = {'c': batch['c'].to(dist_util.dev()).repeat_interleave(batch_size, 0)}
+
+ # all_pred = []
+ pred = rec_model(
+ img=None,
+ c=micro['c'],
+ latent=ddpm_latent,
+ # latent={
+ # # k: v.repeat_interleave(micro['c'].shape[0], 0) if v is not None else None
+ # k: v.repeat_interleave(micro['c'].shape[0], 0) if v is not None else None
+ # for k, v in ddpm_latent.items()
+ # },
+ behaviour='triplane_dec')
+
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+
+ # save viridis_r depth
+ pred_depth = pred_depth.cpu()[0].permute(1, 2, 0).numpy()
+ pred_depth = (plt.cm.viridis(pred_depth[..., 0])[..., :3]) * 2 - 1
+ pred_depth = th.from_numpy(pred_depth).to(
+ pred['image_raw'].device).permute(2, 0, 1).unsqueeze(0)
+ # st()
+ # pred_depth =
+
+ if 'image_sr' in pred:
+
+ gen_img = pred['image_sr']
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), gen_img,
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 128:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']), pred['image_sr'],
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ gen_img = pred['image_raw']
+
+ pred_vis = th.cat(
+ [
+ # self.pool_128(micro['img']),
+ self.pool_128(gen_img),
+ # self.pool_128(pred_depth.repeat_interleave(3, dim=1))
+ self.pool_128(pred_depth)
+ ],
+ dim=-1) # B, 3, H, W
+
+ if save_img:
+ for batch_idx in range(gen_img.shape[0]):
+ sampled_img = Image.fromarray(
+ (gen_img[batch_idx].permute(1, 2, 0).cpu().numpy() *
+ 127.5 + 127.5).clip(0, 255).astype(np.uint8))
+ if sampled_img.size != (512, 512):
+ sampled_img = sampled_img.resize(
+ (128, 128), Image.HAMMING) # for shapenet
+ sampled_img.save(logger.get_dir() +
+ '/FID_Cals/{}_{}.png'.format(
+ int(name_prefix) * batch_size +
+ batch_idx, i))
+ # print('FID_Cals/{}_{}.png'.format(int(name_prefix)*batch_size+batch_idx, i))
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ # if vis.shape[0] > 1:
+ # vis = np.concatenate(np.split(vis, vis.shape[0], axis=0),
+ # axis=-3)
+
+ # if not save_img:
+ for j in range(vis.shape[0]
+ ): # ! currently only export one plane at a time
+ video_out.append_data(vis[j])
+
+ # if not save_img:
+ video_out.close()
+ del video_out
+ print('logged video to: ',
+ f'{logger.get_dir()}/triplane_{name_prefix}.mp4')
+
+ del vis, pred_vis, micro, pred,
+
+ def _init_optim_groups(self, kwargs):
+ if kwargs.get('decomposed', False): # AE
+
+ optim_groups = [
+ # vit encoder
+ {
+ 'name': 'encoder',
+ 'params': self.mp_trainer_rec.model.encoder.parameters(),
+ 'lr': kwargs['encoder_lr'],
+ 'weight_decay': kwargs['encoder_weight_decay']
+ },
+
+ # vit decoder backbone
+ {
+ 'name':
+ 'decoder.vit_decoder',
+ 'params':
+ self.mp_trainer_rec.model.decoder.vit_decoder.parameters(),
+ 'lr':
+ kwargs['vit_decoder_lr'],
+ 'weight_decay':
+ kwargs['vit_decoder_wd']
+ },
+
+ # triplane decoder, may include bg synthesis network
+ {
+ 'name':
+ 'decoder.triplane_decoder',
+ 'params':
+ self.mp_trainer_rec.model.decoder.triplane_decoder.
+ parameters(),
+ 'lr':
+ kwargs['triplane_decoder_lr'],
+ # 'weight_decay': self.weight_decay
+ },
+ ]
+
+ if self.mp_trainer_rec.model.decoder.superresolution is not None:
+ optim_groups.append({
+ 'name':
+ 'decoder.superresolution',
+ 'params':
+ self.mp_trainer_rec.model.decoder.superresolution.
+ parameters(),
+ 'lr':
+ kwargs['super_resolution_lr'],
+ })
+
+ if self.mp_trainer_rec.model.dim_up_mlp is not None:
+ optim_groups.append({
+ 'name':
+ 'dim_up_mlp',
+ 'params':
+ self.mp_trainer_rec.model.dim_up_mlp.parameters(),
+ 'lr':
+ kwargs['encoder_lr'],
+ # 'weight_decay':
+ # self.weight_decay
+ })
+
+ # add 3D aware operators
+ if self.mp_trainer_rec.model.decoder.decoder_pred_3d is not None:
+ optim_groups.append({
+ 'name':
+ 'decoder_pred_3d',
+ 'params':
+ self.mp_trainer_rec.model.decoder.decoder_pred_3d.
+ parameters(),
+ 'lr':
+ kwargs['vit_decoder_lr'],
+ 'weight_decay':
+ kwargs['vit_decoder_wd']
+ })
+
+ if self.mp_trainer_rec.model.decoder.transformer_3D_blk is not None:
+ optim_groups.append({
+ 'name':
+ 'decoder_transformer_3D_blk',
+ 'params':
+ self.mp_trainer_rec.model.decoder.transformer_3D_blk.
+ parameters(),
+ 'lr':
+ kwargs['vit_decoder_lr'],
+ 'weight_decay':
+ kwargs['vit_decoder_wd']
+ })
+
+ if self.mp_trainer_rec.model.decoder.logvar is not None:
+ optim_groups.append({
+ 'name':
+ 'decoder_logvar',
+ 'params':
+ self.mp_trainer_rec.model.decoder.logvar,
+ 'lr':
+ kwargs['vit_decoder_lr'],
+ 'weight_decay':
+ kwargs['vit_decoder_wd']
+ })
+
+ if self.mp_trainer_rec.model.decoder.decoder_pred is not None:
+ optim_groups.append(
+ # MLP triplane SR
+ {
+ 'name':
+ 'decoder.decoder_pred',
+ 'params':
+ self.mp_trainer_rec.model.decoder.decoder_pred.
+ parameters(),
+ 'lr':
+ kwargs['vit_decoder_lr'],
+ # 'weight_decay': 0
+ 'weight_decay':
+ kwargs['vit_decoder_wd']
+ }, )
+
+ if self.mp_trainer_rec.model.confnet is not None:
+ optim_groups.append({
+ 'name':
+ 'confnet',
+ 'params':
+ self.mp_trainer_rec.model.confnet.parameters(),
+ 'lr':
+ 1e-5, # as in unsup3d
+ })
+
+ # self.opt = AdamW(optim_groups)
+
+ if dist_util.get_rank() == 0:
+ logger.log('using independent optimizer for each components')
+ else:
+ optim_groups = [
+ dict(name='mp_trainer.master_params',
+ params=self.mp_trainer_rec.master_params,
+ lr=self.lr,
+ weight_decay=self.weight_decay)
+ ]
+
+ logger.log(optim_groups)
+
+ return optim_groups
+
+ @th.no_grad()
+ # def eval_loop(self, c_list:list):
+ def eval_novelview_loop(self):
+ # novel view synthesis given evaluation camera trajectory
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+
+ all_loss_dict = []
+ novel_view_micro = {}
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ # for i in range(0, 8, self.microbatch):
+ # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ if i == 0:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ if isinstance(v, th.Tensor) else v[0:1]
+ for k, v in batch.items()
+ }
+ else:
+ # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in novel_view_micro.items()
+ }
+
+ pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+ # target = {
+ # 'img': micro['img'],
+ # 'depth': micro['depth'],
+ # 'depth_mask': micro['depth_mask']
+ # }
+ # targe
+
+ _, loss_dict = self.loss_class(pred, micro, test_mode=True)
+ all_loss_dict.append(loss_dict)
+
+ # ! move to other places, add tensorboard
+
+ # pred_vis = th.cat([
+ # pred['image_raw'],
+ # -pred['image_depth'].repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ if 'image_sr' in pred:
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 256:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_256(pred['image_raw']), pred['image_sr'],
+ self.pool_256(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred['image_sr']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ # pred_vis = th.cat([
+ # self.pool_64(micro['img']), pred['image_raw'],
+ # pred_depth.repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1) # B, 3, H, W
+
+ pred_vis = th.cat([
+ self.pool_128(micro['img']),
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ val_scores_for_logging = calc_average_loss(all_loss_dict)
+ with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
+ 'a') as f:
+ json.dump({'step': self.step, **val_scores_for_logging}, f)
+
+ # * log to tensorboard
+ for k, v in val_scores_for_logging.items():
+ self.writer.add_scalar(f'Eval/NovelView/{k}', v,
+ self.step + self.resume_step)
+ del video_out
+ # del pred_vis
+ # del pred
+
+ th.cuda.empty_cache()
+
+ # @th.no_grad()
+ # def eval_loop(self, c_list:list):
+ @th.inference_mode()
+ def eval_loop(self):
+ # novel view synthesis given evaluation camera trajectory
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+ all_loss_dict = []
+ self.rec_model.eval()
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ # for i in range(0, 8, self.microbatch):
+ # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ micro = {
+ k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+
+ pred = self.rec_model(img=micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+ # target = {
+ # 'img': micro['img'],
+ # 'depth': micro['depth'],
+ # 'depth_mask': micro['depth_mask']
+ # }
+
+ # if last_batch or not self.use_ddp:
+ # loss, loss_dict = self.loss_class(pred, target)
+ # else:
+ # with self.ddp_model.no_sync(): # type: ignore
+ _, loss_dict = self.loss_class(pred, micro, test_mode=True)
+ all_loss_dict.append(loss_dict)
+
+ # ! move to other places, add tensorboard
+ # gt_vis = th.cat([micro['img'], micro['img']], dim=-1) # TODO, fail to load depth. range [0, 1]
+ # pred_vis = th.cat([
+ # pred['image_raw'],
+ # -pred['image_depth'].repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+ # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(1,2,0).cpu().numpy() # ! pred in range[-1, 1]
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+
+ if 'image_sr' in pred:
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 256:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_256(pred['image_raw']), pred['image_sr'],
+ self.pool_256(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred['image_sr']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ pred_vis = th.cat([
+ self.pool_128(micro['img']),
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ val_scores_for_logging = calc_average_loss(all_loss_dict)
+ with open(os.path.join(logger.get_dir(), 'scores.json'), 'a') as f:
+ json.dump({'step': self.step, **val_scores_for_logging}, f)
+
+ # * log to tensorboard
+ for k, v in val_scores_for_logging.items():
+ self.writer.add_scalar(f'Eval/Rec/{k}', v,
+ self.step + self.resume_step)
+
+ th.cuda.empty_cache()
+ # if 'SuperresolutionHybrid8X' in self.rendering_kwargs: # ffhq/afhq
+ # self.eval_novelview_loop_trajectory()
+ # else:
+ self.eval_novelview_loop()
+ self.rec_model.train()
+
+ @th.inference_mode()
+ def eval_novelview_loop_trajectory(self):
+ # novel view synthesis given evaluation camera trajectory
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_batch_{i}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+
+ for idx, c in enumerate(self.all_nvs_params):
+ pred = self.rec_model(img=micro['img_to_encoder'],
+ c=c.unsqueeze(0).repeat_interleave(
+ micro['img'].shape[0],
+ 0)) # pred: (B, 3, 64, 64)
+ # c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ if 'image_sr' in pred:
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 256:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_256(pred['image_raw']), pred['image_sr'],
+ self.pool_256(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ else:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred['image_sr']),
+ self.pool_128(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ else:
+
+ # st()
+ pred_vis = th.cat([
+ self.pool_128(micro['img']),
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ # ! cooncat h dim
+ pred_vis = pred_vis.permute(0, 2, 3, 1).flatten(0,
+ 1) # H W 3
+
+ # vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ # vis = pred_vis.permute(1,2,0).cpu().numpy()
+ vis = pred_vis.cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ # for j in range(vis.shape[0]):
+ # video_out.append_data(vis[j])
+ video_out.append_data(vis)
+
+ video_out.close()
+
+ th.cuda.empty_cache()
+
+ def _prepare_nvs_pose(self):
+
+ device = dist_util.dev()
+
+ fov_deg = 18.837 # for ffhq/afhq
+ intrinsics = FOV_to_intrinsics(fov_deg, device=device)
+
+ all_nvs_params = []
+
+ pitch_range = 0.25
+ yaw_range = 0.35
+ num_keyframes = 10 # how many nv poses to sample from
+ w_frames = 1
+
+ cam_pivot = th.Tensor(
+ self.rendering_kwargs.get('avg_camera_pivot')).to(device)
+ cam_radius = self.rendering_kwargs.get('avg_camera_radius')
+
+ for frame_idx in range(num_keyframes):
+
+ cam2world_pose = LookAtPoseSampler.sample(
+ 3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_idx /
+ (num_keyframes * w_frames)),
+ 3.14 / 2 - 0.05 +
+ pitch_range * np.cos(2 * 3.14 * frame_idx /
+ (num_keyframes * w_frames)),
+ cam_pivot,
+ radius=cam_radius,
+ device=device)
+
+ camera_params = th.cat(
+ [cam2world_pose.reshape(-1, 16),
+ intrinsics.reshape(-1, 9)], 1)
+
+ all_nvs_params.append(camera_params)
+
+ self.all_nvs_params = th.cat(all_nvs_params, 0)
+
+ def forward_backward(self, batch, *args, **kwargs):
+ # th.cuda.empty_cache()
+ self.mp_trainer_rec.zero_grad()
+ batch_size = batch['img_to_encoder'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev())
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ # wrap forward within amp
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ pred = self.rec_model(img=micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+ target = micro
+
+ # ! only enable in ffhq dataset
+ conf_sigma_percl = None
+ conf_sigma_percl_flip = None
+ if 'conf_sigma' in pred:
+ # all_conf_sigma_l1, all_conf_sigma_percl = pred['conf_sigma']
+ # all_conf_sigma_l1 = pred['conf_sigma']
+ all_conf_sigma_l1 = th.nn.functional.interpolate(
+ pred['conf_sigma'],
+ size=pred['image_raw'].shape[-2:],
+ mode='bilinear'
+ ) # dynamically resize to target img size
+ conf_sigma_l1 = all_conf_sigma_l1[:, :1]
+ conf_sigma_l1_flip = all_conf_sigma_l1[:, 1:]
+ # conf_sigma_percl = all_conf_sigma_percl[:,:1]
+ # conf_sigma_percl_flip = all_conf_sigma_percl[:,1:]
+ else:
+ conf_sigma = None
+ conf_sigma_l1 = None
+ conf_sigma_l1_flip = None
+
+ with self.rec_model.no_sync(): # type: ignore
+ loss, loss_dict, fg_mask = self.loss_class(
+ pred,
+ target,
+ step=self.step + self.resume_step,
+ test_mode=False,
+ return_fg_mask=True,
+ conf_sigma_l1=conf_sigma_l1,
+ conf_sigma_percl=conf_sigma_percl)
+
+ if self.loss_class.opt.symmetry_loss:
+ loss_dict['conf_sigma_log'] = conf_sigma_l1.log()
+ pose, intrinsics = micro['c'][:, :16].reshape(
+ -1, 4, 4), micro['c'][:, 16:]
+ flipped_pose = flip_yaw(pose)
+ mirror_c = th.cat(
+ [flipped_pose.reshape(-1, 16), intrinsics], -1)
+
+ nvs_pred = self.rec_model(latent={
+ k: v
+ for k, v in pred.items() if 'latent' in k
+ },
+ c=mirror_c,
+ behaviour='triplane_dec',
+ return_raw_only=True)
+
+ # concat data for supervision
+ nvs_gt = {
+ k: th.flip(target[k], [-1])
+ for k in
+ ['img'] # fliplr leads to wrong color; B 3 H W shape
+ }
+ flipped_fg_mask = th.flip(fg_mask, [-1])
+
+ # if 'conf_sigma' in pred:
+ # conf_sigma = th.flip(pred['conf_sigma'], [-1])
+ # conf_sigma = th.nn.AdaptiveAvgPool2d(fg_mask.shape[-2:])(conf_sigma) # dynamically resize to target img size
+ # else:
+ # conf_sigma=None
+
+ with self.rec_model.no_sync(): # type: ignore
+ loss_symm, loss_dict_symm = self.loss_class.calc_2d_rec_loss(
+ nvs_pred['image_raw'],
+ nvs_gt['img'],
+ flipped_fg_mask,
+ # test_mode=True,
+ test_mode=False,
+ step=self.step + self.resume_step,
+ # conf_sigma=conf_sigma,
+ conf_sigma_l1=conf_sigma_l1_flip,
+ conf_sigma_percl=conf_sigma_percl_flip)
+ # )
+ loss += (loss_symm * 1.0) # as in unsup3d
+ # loss += (loss_symm * 0.5) # as in unsup3d
+ # loss += (loss_symm * 0.01) # as in unsup3d
+ # if conf_sigma is not None:
+ # loss += th.nn.functional.mse_loss(conf_sigma, flipped_fg_mask) * 0.001 # a log that regularizes all confidence to 1
+ for k, v in loss_dict_symm.items():
+ loss_dict[f'{k}_symm'] = v
+ loss_dict[
+ 'flip_conf_sigma_log'] = conf_sigma_l1_flip.log()
+
+ # ! add density-reg in eg3d, tv-loss
+
+ if self.loss_class.opt.density_reg > 0 and self.step % self.loss_class.opt.density_reg_every == 0:
+
+ initial_coordinates = th.rand(
+ (batch_size, 1000, 3),
+ device=dist_util.dev()) * 2 - 1 # [-1, 1]
+ perturbed_coordinates = initial_coordinates + th.randn_like(
+ initial_coordinates
+ ) * self.loss_class.opt.density_reg_p_dist
+ all_coordinates = th.cat(
+ [initial_coordinates, perturbed_coordinates], dim=1)
+
+ sigma = self.rec_model(
+ latent=pred['latent'],
+ coordinates=all_coordinates,
+ directions=th.randn_like(all_coordinates),
+ behaviour='triplane_renderer',
+ )['sigma']
+
+ sigma_initial = sigma[:, :sigma.shape[1] // 2]
+ sigma_perturbed = sigma[:, sigma.shape[1] // 2:]
+
+ TVloss = th.nn.functional.l1_loss(
+ sigma_initial,
+ sigma_perturbed) * self.loss_class.opt.density_reg
+
+ loss_dict.update(dict(tv_loss=TVloss))
+ loss += TVloss
+
+ self.mp_trainer_rec.backward(loss)
+ log_rec3d_loss_dict(loss_dict)
+
+ # for name, p in self.rec_model.named_parameters():
+ # if p.grad is None:
+ # logger.log(f"found rec unused param: {name}")
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ def norm_depth(pred_depth): # to [-1,1]
+ # pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ return -(pred_depth * 2 - 1)
+
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ # infer novel view also
+ if self.loss_class.opt.symmetry_loss:
+ pred_nv_img = nvs_pred
+ else:
+ pred_nv_img = self.rec_model(
+ img=micro['img_to_encoder'],
+ c=self.novel_view_poses) # pred: (B, 3, 64, 64)
+
+ # if 'depth' in micro:
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = norm_depth(gt_depth)
+ # gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ # gt_depth.min())
+ # if True:
+ fg_mask = pred['image_mask'] * 2 - 1 # 0-1
+ nv_fg_mask = pred_nv_img['image_mask'] * 2 - 1 # 0-1
+ if 'image_depth' in pred:
+ pred_depth = norm_depth(pred['image_depth'])
+ pred_nv_depth = norm_depth(pred_nv_img['image_depth'])
+ else:
+ pred_depth = th.zeros_like(gt_depth)
+ pred_nv_depth = th.zeros_like(gt_depth)
+
+ if 'image_sr' in pred:
+ if pred['image_sr'].shape[-1] == 512:
+ pred_img = th.cat(
+ [self.pool_512(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ elif pred['image_sr'].shape[-1] == 256:
+ pred_img = th.cat(
+ [self.pool_256(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_256(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_256(pred_depth)
+ gt_depth = self.pool_256(gt_depth)
+
+ else:
+ pred_img = th.cat(
+ [self.pool_128(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_128(micro['img']), micro['img_sr']],
+ dim=-1)
+ gt_depth = self.pool_128(gt_depth)
+ pred_depth = self.pool_128(pred_depth)
+ else:
+ gt_img = self.pool_128(gt_img)
+ gt_depth = self.pool_128(gt_depth)
+
+ pred_vis = th.cat([
+ pred_img,
+ pred_depth.repeat_interleave(3, dim=1),
+ fg_mask.repeat_interleave(3, dim=1),
+ ],
+ dim=-1) # B, 3, H, W
+
+ if 'conf_sigma' in pred:
+ conf_sigma_l1 = (1 / (conf_sigma_l1 + 1e-7)
+ ).repeat_interleave(3, dim=1) * 2 - 1
+ pred_vis = th.cat([
+ pred_vis,
+ conf_sigma_l1,
+ ], dim=-1) # B, 3, H, W
+
+ pred_vis_nv = th.cat([
+ pred_nv_img['image_raw'],
+ pred_nv_depth.repeat_interleave(3, dim=1),
+ nv_fg_mask.repeat_interleave(3, dim=1),
+ ],
+ dim=-1) # B, 3, H, W
+
+ if 'conf_sigma' in pred:
+ # conf_sigma_for_vis = (1/conf_sigma).repeat_interleave(3, dim=1)
+ # conf_sigma_for_vis = (conf_sigma_for_vis / conf_sigma_for_vis.max() ) * 2 - 1 # normalize to [-1,1]
+ conf_sigma_for_vis_flip = (
+ 1 / (conf_sigma_l1_flip + 1e-7)).repeat_interleave(
+ 3, dim=1) * 2 - 1
+ pred_vis_nv = th.cat(
+ [
+ pred_vis_nv,
+ conf_sigma_for_vis_flip,
+ # th.cat([conf_sigma_for_vis, flipped_fg_mask*2-1], -1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ pred_vis = th.cat([pred_vis, pred_vis_nv],
+ dim=-2) # cat in H dim
+
+ gt_vis = th.cat(
+ [
+ gt_img,
+ gt_depth.repeat_interleave(3, dim=1),
+ th.zeros_like(gt_img)
+ ],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ if 'conf_sigma' in pred:
+ gt_vis = th.cat([gt_vis, fg_mask],
+ dim=-1) # placeholder
+
+ # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ # st()
+ vis = th.cat([gt_vis, pred_vis], dim=-2)
+ # .permute(
+ # 0, 2, 3, 1).cpu()
+ vis_tensor = torchvision.utils.make_grid(
+ vis, nrow=vis.shape[-1] // 64) # HWC
+ torchvision.utils.save_image(
+ vis_tensor,
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
+ value_range=(-1, 1),
+ normalize=True)
+ # vis = vis.numpy() * 127.5 + 127.5
+ # vis = vis.clip(0, 255).astype(np.uint8)
+
+ # Image.fromarray(vis).save(
+ # f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+
+ logger.log(
+ 'log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+
+ # self.writer.add_image(f'images',
+ # vis,
+ # self.step + self.resume_step,
+ # dataformats='HWC')
+ return pred
+
+
+class TrainLoop3DTriplaneRec(TrainLoop3DRec):
+
+ def __init__(self,
+ *,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ compile=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ compile=compile,
+ **kwargs)
+
+ @th.inference_mode()
+ def eval_loop(self):
+ # novel view synthesis given evaluation camera trajectory
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+ all_loss_dict = []
+ self.rec_model.eval()
+
+ device = dist_util.dev()
+
+ # to get intrinsics
+ demo_pose = next(self.data)
+ intrinsics = demo_pose['c'][0][16:25].to(device)
+
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=24,
+ bitrate='10M',
+ codec='libx264')
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ # for i, batch in enumerate(tqdm(self.eval_data)):
+
+ cam_pivot = th.tensor([0, 0, 0], device=dist_util.dev())
+ cam_radius = 1.8
+
+ pitch_range = 0.45
+ yaw_range = 3.14 # 0.35
+ frames = 72
+
+ # TODO, use PanoHead trajectory
+ # for frame_idx in range(frames):
+
+ for pose_idx, (angle_y, angle_p) in enumerate(
+ # zip(np.linspace(-0.4, 0.4, 72), [-0.2] * 72)):
+ # zip(np.linspace(-1.57, 1.57, 72), [-1.57] * 72)):
+ # zip(np.linspace(0,3.14, 72), [0] * 72)): # check canonical pose
+ zip([0.2] * 72, np.linspace(-3.14, 3.14, 72))):
+
+ # cam2world_pose = LookAtPoseSampler.sample(3.14/2 + yaw_range * np.cos(2 * 3.14 * frame_idx / (frames)),
+ # 3.14/2 -0.05 + pitch_range * np.sin(2 * 3.14 * frame_idx / (frames)),
+ # cam_pivot,
+ # radius=cam_radius, device=device)
+
+ cam2world_pose = LookAtPoseSampler.sample(
+ np.pi / 2 + angle_y,
+ np.pi / 2 + angle_p,
+ # angle_p,
+ cam_pivot,
+ # horizontal_stddev=0.1, # 0.25
+ # vertical_stddev=0.125, # 0.35,
+ radius=cam_radius,
+ device=device)
+
+ camera_params = th.cat(
+ [cam2world_pose.reshape(-1, 16),
+ intrinsics.reshape(-1, 9)], 1).to(dist_util.dev())
+
+ # micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+ micro = {'c': camera_params}
+
+ pred = self.rec_model(c=micro['c'])
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+
+ pred_vis = th.cat([
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ self.rec_model.train()
+
+
+class TrainLoop3DRecTrajVis(TrainLoop3DRec):
+
+ def __init__(self,
+ *,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ **kwargs)
+ self.rendering_kwargs = self.rec_model.module.decoder.triplane_decoder.rendering_kwargs # type: ignore
+ self._prepare_nvs_pose() # for eval novelview visualization
+
+ @th.inference_mode()
+ def eval_novelview_loop(self):
+ # novel view synthesis given evaluation camera trajectory
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_batch_{i}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+
+ for idx, c in enumerate(self.all_nvs_params):
+ pred = self.rec_model(img=micro['img_to_encoder'],
+ c=c.unsqueeze(0).repeat_interleave(
+ micro['img'].shape[0],
+ 0)) # pred: (B, 3, 64, 64)
+ # c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ if 'image_sr' in pred:
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 256:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_256(pred['image_raw']), pred['image_sr'],
+ self.pool_256(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ else:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred['image_sr']),
+ self.pool_128(pred_depth).repeat_interleave(3,
+ dim=1)
+ ],
+ dim=-1)
+
+ else:
+
+ # st()
+ pred_vis = th.cat([
+ self.pool_128(micro['img']),
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ # ! cooncat h dim
+ pred_vis = pred_vis.permute(0, 2, 3, 1).flatten(0,
+ 1) # H W 3
+
+ # vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ # vis = pred_vis.permute(1,2,0).cpu().numpy()
+ vis = pred_vis.cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ # for j in range(vis.shape[0]):
+ # video_out.append_data(vis[j])
+ video_out.append_data(vis)
+
+ video_out.close()
+
+ th.cuda.empty_cache()
+
+ def _prepare_nvs_pose(self):
+
+ device = dist_util.dev()
+
+ fov_deg = 18.837 # for ffhq/afhq
+ intrinsics = FOV_to_intrinsics(fov_deg, device=device)
+
+ all_nvs_params = []
+
+ pitch_range = 0.25
+ yaw_range = 0.35
+ num_keyframes = 10 # how many nv poses to sample from
+ w_frames = 1
+
+ cam_pivot = th.Tensor(
+ self.rendering_kwargs.get('avg_camera_pivot')).to(device)
+ cam_radius = self.rendering_kwargs.get('avg_camera_radius')
+
+ for frame_idx in range(num_keyframes):
+
+ cam2world_pose = LookAtPoseSampler.sample(
+ 3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_idx /
+ (num_keyframes * w_frames)),
+ 3.14 / 2 - 0.05 +
+ pitch_range * np.cos(2 * 3.14 * frame_idx /
+ (num_keyframes * w_frames)),
+ cam_pivot,
+ radius=cam_radius,
+ device=device)
+
+ camera_params = th.cat(
+ [cam2world_pose.reshape(-1, 16),
+ intrinsics.reshape(-1, 9)], 1)
+
+ all_nvs_params.append(camera_params)
+
+ self.all_nvs_params = th.cat(all_nvs_params, 0)
diff --git a/nsr/train_util_cvD.py b/nsr/train_util_cvD.py
new file mode 100644
index 0000000000000000000000000000000000000000..70a1a89da6576457ada7d85dac7fc25f216e6888
--- /dev/null
+++ b/nsr/train_util_cvD.py
@@ -0,0 +1,637 @@
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+import torchvision
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from tqdm import tqdm
+
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion import dist_util, logger
+from guided_diffusion.train_util import (calc_average_loss,
+ log_rec3d_loss_dict,
+ find_resume_checkpoint)
+
+from torch.optim import AdamW
+
+from .train_util import TrainLoopBasic, TrainLoop3DRec
+import vision_aided_loss
+from dnnlib.util import calculate_adaptive_weight
+
+
+def get_blob_logdir():
+ # You can change this to be a separate path to save checkpoints to
+ # a blobstore or some external drive.
+ return logger.get_dir()
+
+
+class TrainLoop3DcvD(TrainLoop3DRec):
+
+ def __init__(
+ self,
+ *,
+ rec_model,
+ loss_class,
+ # diffusion,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ # schedule_sampler=None,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ use_amp=False,
+ cvD_name='cvD',
+ model_name='rec',
+ # SR_TRAINING=True,
+ SR_TRAINING=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ cvD_name=cvD_name,
+ **kwargs)
+
+ # self.rec_model = self.ddp_model
+
+ # device = loss_class.device
+ device = dist_util.dev()
+ # * create vision aided model
+ # TODO, load model
+ self.nvs_cvD = vision_aided_loss.Discriminator(
+ cv_type='clip', loss_type='multilevel_sigmoid_s',
+ device=device).to(device)
+ self.nvs_cvD.cv_ensemble.requires_grad_(False) # Freeze feature extractor
+ # self.nvs_cvD.train()
+
+ #
+ # SR_TRAINING = False
+ cvD_model_params=list(self.nvs_cvD.decoder.parameters())
+ self.SR_TRAINING = SR_TRAINING
+ # SR_TRAINING = True
+ if SR_TRAINING:
+ # width, patch_size = self.nvs_cvD.cv_ensemble
+ vision_width, vision_patch_size = [self.nvs_cvD.cv_ensemble.models[0].model.conv1.weight.shape[k] for k in [0, -1]]
+ self.nvs_cvD.cv_ensemble.models[0].model.conv1 = th.nn.Conv2d(in_channels=6, out_channels=vision_width, kernel_size=vision_patch_size, stride=vision_patch_size, bias=False).to(dist_util.dev())
+ self.nvs_cvD.cv_ensemble.models[0].model.conv1.requires_grad_(True)
+ cvD_model_params += list(self.nvs_cvD.cv_ensemble.models[0].model.conv1.parameters())
+
+ # change normalization metrics
+ self.nvs_cvD.cv_ensemble.models[0].image_mean = self.nvs_cvD.cv_ensemble.models[0].image_mean.repeat(2)
+ self.nvs_cvD.cv_ensemble.models[0].image_std = self.nvs_cvD.cv_ensemble.models[0].image_std.repeat(2)
+
+ # logger.log(f'nvs_cvD_model_params: {cvD_model_params}')
+
+ self._load_and_sync_parameters(model=self.nvs_cvD, model_name='cvD')
+
+ self.mp_trainer_cvD = MixedPrecisionTrainer(
+ model=self.nvs_cvD,
+ use_fp16=self.use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ model_name=cvD_name,
+ use_amp=use_amp,
+ model_params=cvD_model_params
+ )
+
+ # cvD_lr = 4e-5*(lr/1e-5)
+ # cvD_lr = 4e-4*(lr/1e-5)
+ cvD_lr = 1e-4*(lr/1e-5) * self.loss_class.opt.nvs_D_lr_mul
+ # cvD_lr = 1e-5*(lr/1e-5)
+ self.opt_cvD = AdamW(
+ self.mp_trainer_cvD.master_params,
+ lr=cvD_lr,
+ betas=(0, 0.999),
+ eps=1e-8) # dlr in biggan cfg
+
+ logger.log(f'cpt_cvD lr: {cvD_lr}')
+
+ if self.use_ddp:
+ self.ddp_nvs_cvD = DDP(
+ self.nvs_cvD,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ )
+ else:
+ self.ddp_nvs_cvD = self.nvs_cvD
+
+ th.cuda.empty_cache()
+
+ def run_step(self, batch, step='g_step'):
+ # self.forward_backward(batch)
+
+ if step == 'g_step_rec':
+ self.forward_G_rec(batch)
+ took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)
+
+ if took_step_g_rec:
+ self._update_ema() # g_ema
+
+ elif step == 'g_step_nvs':
+ self.forward_G_nvs(batch)
+ took_step_g_nvs = self.mp_trainer_rec.optimize(self.opt)
+
+ if took_step_g_nvs:
+ self._update_ema() # g_ema
+
+ elif step == 'd_step':
+ self.forward_D(batch)
+ _ = self.mp_trainer_cvD.optimize(self.opt_cvD)
+
+ self._anneal_lr()
+ self.log_step()
+
+ def run_loop(self):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # batch, cond = next(self.data)
+ # if batch is None:
+ batch = next(self.data)
+ self.run_step(batch, 'g_step_rec')
+
+ batch = next(self.data)
+ self.run_step(batch, 'g_step_nvs')
+
+ batch = next(self.data)
+ self.run_step(batch, 'd_step')
+
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ if self.step % self.eval_interval == 0 and self.step != 0:
+ if dist_util.get_rank() == 0:
+ self.eval_loop()
+ # self.eval_novelview_loop()
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0:
+ self.save()
+ self.save(self.mp_trainer_cvD, 'cvD')
+ dist_util.synchronize()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ logger.log('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+ self.save(self.mp_trainer_cvD, 'cvD')
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+ self.save(self.mp_trainer_cvD, 'cvD')
+
+ # def forward_backward(self, batch, *args, **kwargs):
+ # blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
+
+ def run_D_Diter(self, real, fake, D=None):
+ # Dmain: Minimize logits for generated images and maximize logits for real images.
+ if D is None:
+ D = self.ddp_nvs_cvD
+
+ lossD = D(real, for_real=True).mean() + D(
+ fake, for_real=False).mean()
+ return lossD
+
+ def forward_D(self, batch): # update D
+ self.mp_trainer_cvD.zero_grad()
+ self.ddp_nvs_cvD.requires_grad_(True)
+ self.rec_model.requires_grad_(False)
+
+ batch_size = batch['img'].shape[0]
+
+ # * sample a new batch for D training
+ for i in range(0, batch_size, self.microbatch):
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev()).contiguous()
+ for k, v in batch.items()
+ }
+
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_cvD.use_amp):
+
+ # pred = self.rec_model(img=micro['img_to_encoder'],
+ # c=micro['c']) # pred: (B, 3, 64, 64)
+
+ pred = self.rec_model(
+ img=micro['img_to_encoder'],
+ c=th.cat([
+ micro['c'][1:],
+ micro['c'][:1], # half novel view, half orig view
+ ]))
+
+ real_logits_cv = self.run_D_Diter(
+ real=micro['img_to_encoder'],
+ fake=pred['image_raw']) # TODO, add SR for FFHQ
+
+ log_rec3d_loss_dict({'vision_aided_loss/D': real_logits_cv})
+
+ self.mp_trainer_cvD.backward(real_logits_cv)
+
+ def forward_G_rec(self, batch): # update G
+
+ self.mp_trainer_rec.zero_grad()
+ self.rec_model.requires_grad_(True)
+ self.ddp_nvs_cvD.requires_grad_(False)
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev()).contiguous()
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ # VQ3D novel view d loss
+ # duplicated_for_nvs = th.cat([
+ # micro['img_to_encoder'][:batch_size - 2],
+ # micro['img_to_encoder'][:2]
+ # ], 0)
+
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ pred = self.rec_model(
+ img=micro['img_to_encoder'], c=micro['c']
+ ) # render novel view for first half of the batch for D loss
+
+ target_for_rec = micro
+ pred_for_rec = pred
+
+ # pred_for_rec = {
+ # k: v[:batch_size - 2] if v is not None else None
+ # for k, v in pred.items()
+ # }
+ # target_for_rec = {
+ # k: v[:batch_size - 2] if v is not None else None
+ # for k, v in target.items()
+ # }
+
+ if last_batch or not self.use_ddp:
+ loss, loss_dict = self.loss_class(pred_for_rec,
+ target_for_rec,
+ test_mode=False)
+ else:
+ with self.rec_model.no_sync(): # type: ignore
+ loss, loss_dict = self.loss_class(pred_for_rec,
+ target_for_rec,
+ test_mode=False)
+
+ # add cvD supervision
+ vision_aided_loss = self.ddp_nvs_cvD(
+ pred_for_rec['image_raw'],
+ for_G=True).mean() # [B, 1] shape
+
+ last_layer = self.rec_model.module.decoder.triplane_decoder.decoder.net[ # type: ignore
+ -1].weight # type: ignore
+
+ d_weight = calculate_adaptive_weight(
+ loss, vision_aided_loss, last_layer,
+ # disc_weight_max=0.1) * 0.1
+ # disc_weight_max=0.1) * 0.05
+ disc_weight_max=1)
+ loss += vision_aided_loss * d_weight
+
+ loss_dict.update({
+ 'vision_aided_loss/G_rec': vision_aided_loss,
+ 'd_weight': d_weight
+ })
+
+ log_rec3d_loss_dict(loss_dict)
+
+ self.mp_trainer_rec.backward(loss)
+
+ # ! move to other places, add tensorboard
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ if 'image_sr' in pred:
+ if pred['image_sr'].shape[-1] == 512:
+ pred_img = th.cat(
+ [self.pool_512(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ elif pred['image_sr'].shape[-1] == 256:
+ pred_img = th.cat(
+ [self.pool_256(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_256(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_256(pred_depth)
+ gt_depth = self.pool_256(gt_depth)
+
+ else:
+ pred_img = th.cat(
+ [self.pool_128(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_128(micro['img']), micro['img_sr']],
+ dim=-1)
+ gt_depth = self.pool_128(gt_depth)
+ pred_depth = self.pool_128(pred_depth)
+
+ gt_vis = th.cat(
+ [gt_img, gt_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ pred_vis = th.cat(
+ [pred_img,
+ pred_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # B, 3, H, W
+
+ vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ 1, 2, 0).cpu() # ! pred in range[-1, 1]
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}_rec.jpg'
+ )
+ logger.log(
+ 'log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}_rec.jpg'
+ )
+
+ def forward_G_nvs(self, batch): # update G
+
+ self.mp_trainer_rec.zero_grad()
+ self.rec_model.requires_grad_(True)
+ self.ddp_nvs_cvD.requires_grad_(False)
+
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev()).contiguous()
+ for k, v in batch.items()
+ }
+
+ # last_batch = (i + self.microbatch) >= batch_size
+
+ # VQ3D novel view d loss
+ # duplicated_for_nvs = th.cat([
+ # micro['img_to_encoder'][batch_size // 2:],
+ # micro['img_to_encoder'][:batch_size // 2]
+ # ], 0)
+
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ pred = self.rec_model(
+ # img=duplicated_for_nvs, c=micro['c']
+ img=micro['img_to_encoder'],
+ c=th.cat([
+ micro['c'][1:],
+ micro['c'][:1],
+ ])
+ ) # render novel view for first half of the batch for D loss
+
+ # add cvD supervision
+ vision_aided_loss = self.ddp_nvs_cvD(
+ pred['image_raw'], for_G=True).mean() # [B, 1] shape
+
+ # loss = vision_aided_loss * 0.01
+ # loss = vision_aided_loss * 0.005
+ # loss = vision_aided_loss * 0.1
+ loss = vision_aided_loss * 0.01
+
+ log_rec3d_loss_dict({
+ 'vision_aided_loss/G_nvs':
+ vision_aided_loss,
+ })
+
+ self.mp_trainer_rec.backward(loss)
+
+ # ! move to other places, add tensorboard
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ if 'image_sr' in pred:
+ pred_img = th.cat(
+ [self.pool_512(pred_img), pred['image_sr']],
+ dim=-1)
+ gt_img = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr']],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ gt_vis = th.cat(
+ [gt_img, gt_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ pred_vis = th.cat(
+ [pred_img,
+ pred_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # B, 3, H, W
+
+ # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ # 1, 2, 0).cpu() # ! pred in range[-1, 1]
+ vis = th.cat([gt_vis, pred_vis], dim=-2)
+
+ vis = torchvision.utils.make_grid(
+ vis,
+ normalize=True,
+ scale_each=True,
+ value_range=(-1, 1)).cpu().permute(1, 2, 0) # H W 3
+ vis = vis.numpy() * 255
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ # logger.log(vis.shape)
+
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}_nvs.jpg'
+ )
+ logger.log(
+ 'log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}_nvs.jpg'
+ )
+
+ def save(self, mp_trainer=None, model_name='rec'):
+ if mp_trainer is None:
+ mp_trainer = self.mp_trainer_rec
+
+ def save_checkpoint(rate, params):
+ state_dict = mp_trainer.master_params_to_state_dict(params)
+ if dist_util.get_rank() == 0:
+ logger.log(f"saving model {model_name} {rate}...")
+ if not rate:
+ filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
+ else:
+ filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
+ with bf.BlobFile(bf.join(get_blob_logdir(), filename),
+ "wb") as f:
+ th.save(state_dict, f)
+
+ save_checkpoint(0, mp_trainer.master_params)
+
+ if model_name == 'ddpm':
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ save_checkpoint(rate, params)
+
+ dist.barrier()
+
+ def _load_and_sync_parameters(self, model=None, model_name='rec'):
+ resume_checkpoint, self.resume_step = find_resume_checkpoint(
+ self.resume_checkpoint, model_name) or self.resume_checkpoint
+
+ if model is None:
+ model = self.rec_model # default model in the parent class
+
+ logger.log(resume_checkpoint)
+
+ if resume_checkpoint and Path(resume_checkpoint).exists():
+ if dist_util.get_rank() == 0:
+
+ logger.log(
+ f"loading model from checkpoint: {resume_checkpoint}...")
+ map_location = {
+ 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
+ } # configure map_location properly
+
+ logger.log(f'mark {model_name} loading ', )
+ resume_state_dict = dist_util.load_state_dict(
+ resume_checkpoint, map_location=map_location)
+ logger.log(f'mark {model_name} loading finished', )
+
+ model_state_dict = model.state_dict()
+
+ for k, v in resume_state_dict.items():
+
+ if k in model_state_dict.keys() and v.size(
+ ) == model_state_dict[k].size():
+ model_state_dict[k] = v
+
+ # elif 'IN' in k and model_name == 'rec' and getattr(model.decoder, 'decomposed_IN', False):
+ # model_state_dict[k.replace('IN', 'superresolution.norm.norm_layer')] = v # decomposed IN
+ elif 'attn.wk' in k or 'attn.wv' in k: # old qkv
+ logger.log('ignore ', k)
+
+ elif 'decoder.vit_decoder.blocks' in k:
+ # st()
+ # load from 2D ViT pre-trained into 3D ViT blocks.
+ assert len(model.decoder.vit_decoder.blocks[0].vit_blks) == 2 # assert depth=2 here.
+ fusion_ca_depth = len(model.decoder.vit_decoder.blocks[0].vit_blks)
+ vit_subblk_index = int(k.split('.')[3])
+ vit_blk_keyname = ('.').join(k.split('.')[4:])
+ fusion_blk_index = vit_subblk_index // fusion_ca_depth
+ fusion_blk_subindex = vit_subblk_index % fusion_ca_depth
+ model_state_dict[f'decoder.vit_decoder.blocks.{fusion_blk_index}.vit_blks.{fusion_blk_subindex}.{vit_blk_keyname}'] = v
+ # logger.log('load 2D ViT weight: {}'.format(f'decoder.vit_decoder.blocks.{fusion_blk_index}.vit_blks.{fusion_blk_subindex}.{vit_blk_keyname}'))
+
+ elif 'IN' in k:
+ logger.log('ignore ', k)
+
+ elif 'quant_conv' in k:
+ logger.log('ignore ', k)
+
+ else:
+ logger.log('!!!! ignore key: ', k, ": ", v.size(),)
+ if k in model_state_dict:
+ logger.log('shape in model: ', model_state_dict[k].size())
+ else:
+ logger.log(k, 'not in model_state_dict')
+
+ model.load_state_dict(model_state_dict, strict=True)
+ del model_state_dict
+
+ if dist_util.get_world_size() > 1:
+ dist_util.sync_params(model.parameters())
+ logger.log(f'synced {model_name} params')
diff --git a/nsr/train_util_diffusion.py b/nsr/train_util_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..4718ca441d954a34df8c094d8027209c42a0dd68
--- /dev/null
+++ b/nsr/train_util_diffusion.py
@@ -0,0 +1,1736 @@
+import copy
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+# from PIL import Image
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard.writer import SummaryWriter
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+from guided_diffusion.gaussian_diffusion import _extract_into_tensor
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+# from .train_util import TrainLoop3DRec
+from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_loss_dict,
+ log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+
+import dnnlib
+
+from nsr.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
+
+# AMP
+# from accelerate import Accelerator
+
+# from ..guided_diffusion.train_util import TrainLoop
+
+# use_amp = False
+# use_amp = True
+
+
+class TrainLoopDiffusionWithRec(TrainLoop):
+ """an interface with rec_model required apis
+ """
+
+ def __init__(
+ self,
+ *,
+ model,
+ diffusion,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ triplane_scaling_divider=1,
+ use_amp=False,
+ diffusion_input_size=224,
+ schedule_sampler=None,
+ model_name='ddpm',
+ **kwargs,
+ ):
+ super().__init__(
+ model=model,
+ diffusion=diffusion,
+ data=data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ schedule_sampler=schedule_sampler,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ use_amp=use_amp,
+ model_name=model_name,
+ **kwargs,
+ )
+
+ self.latent_name = 'latent_normalized' # normalized triplane latent
+ self.diffusion_input_size = diffusion_input_size
+ self.render_latent_behaviour = 'triplane_dec' # directly render using triplane operations
+
+ self.loss_class = loss_class
+ # self.rec_model = rec_model
+ self.eval_interval = eval_interval
+ self.eval_data = eval_data
+ self.iterations = iterations
+ # self.triplane_std = 10
+ self.triplane_scaling_divider = triplane_scaling_divider
+
+ if dist_util.get_rank() == 0:
+ self.writer = SummaryWriter(log_dir=f'{logger.get_dir()}/runs')
+
+ # def _init_optim_groups(self, rec_model):
+ # """for initializing the reconstruction model.
+ # """
+ # kwargs = self.kwargs
+ # optim_groups = [
+ # # vit encoder
+ # {
+ # 'name': 'vit_encoder',
+ # 'params': rec_model.encoder.parameters(),
+ # 'lr': kwargs['encoder_lr'],
+ # 'weight_decay': kwargs['encoder_weight_decay']
+ # },
+ # # vit decoder
+ # {
+ # 'name': 'vit_decoder',
+ # 'params': rec_model.decoder.vit_decoder.parameters(),
+ # 'lr': kwargs['vit_decoder_lr'],
+ # 'weight_decay': kwargs['vit_decoder_wd']
+ # },
+ # {
+ # 'name': 'vit_decoder_pred',
+ # 'params': rec_model.decoder.decoder_pred.parameters(),
+ # 'lr': kwargs['vit_decoder_lr'],
+ # # 'weight_decay': 0
+ # 'weight_decay': kwargs['vit_decoder_wd']
+ # },
+
+ # # triplane decoder
+ # {
+ # 'name': 'triplane_decoder',
+ # 'params': rec_model.decoder.triplane_decoder.parameters(),
+ # 'lr': kwargs['triplane_decoder_lr'],
+ # # 'weight_decay': self.weight_decay
+ # },
+ # ]
+
+ # if rec_model.decoder.superresolution is not None:
+ # optim_groups.append({
+ # 'name':
+ # 'triplane_decoder_superresolution',
+ # 'params':
+ # rec_model.decoder.superresolution.parameters(),
+ # 'lr':
+ # kwargs['super_resolution_lr'],
+ # })
+
+ # return optim_groups
+
+ @th.inference_mode()
+ def render_video_given_triplane(self,
+ planes,
+ rec_model,
+ name_prefix='0',
+ save_img=False,
+ render_reference=None,
+ export_mesh=False,
+ render_all=False):
+
+ planes *= self.triplane_scaling_divider # if setting clip_denoised=True, the sampled planes will lie in [-1,1]. Thus, values beyond [+- std] will be abandoned in this version. Move to IN for later experiments.
+
+ # sr_w_code = getattr(self.ddp_rec_model.module.decoder, 'w_avg', None)
+ # sr_w_code = None
+ batch_size = planes.shape[0]
+
+
+
+ # ! ffhq visualization
+ # cam_radius = 2.7
+ # cam_pivot = th.tensor([0,0,0.2], device=dist_util.dev())
+ # fov_deg = 18.837 # ! fixed for FFHQ
+
+ # device = dist_util.dev()
+ # intrinsics = FOV_to_intrinsics(fov_deg, device=dist_util.dev())
+ # all_camera = []
+
+ # angle_p = -0.2
+ # for pose_idx, (angle_y, angle_p) in enumerate(
+ # zip(np.linspace(-1.57/2, 1.57*3/2, 72), [-0.2] * 36)):
+
+ # cam2world_pose = LookAtPoseSampler.sample(
+ # np.pi / 2 + angle_y,
+ # np.pi / 2 + angle_p,
+ # cam_pivot,
+ # radius=cam_radius,
+ # device=device)
+ # camera_params = th.cat(
+ # [cam2world_pose.reshape(-1, 16),
+ # intrinsics.reshape(-1, 9)],
+ # 1)
+ # all_camera.append(camera_params)
+ # all_camera = th.cat(all_camera, 0)
+ # st() # th.save(all_camera, 'assets/ffhq_eval_pose.pt')
+
+
+ # if sr_w_code is not None:
+ # sr_w_code = sr_w_code.reshape(1, 1,
+ # -1).repeat_interleave(batch_size, 0)
+
+ # used during diffusion sampling inference
+ # if not save_img:
+
+ # ! mesh
+
+ if planes.shape[1] == 16: # ffhq/car
+ ddpm_latent = {
+ self.latent_name: planes[:, :12],
+ 'bg_plane': planes[:, 12:16],
+ }
+ else:
+ ddpm_latent = {
+ self.latent_name: planes,
+ }
+
+ ddpm_latent.update(
+ rec_model(latent=ddpm_latent,
+ behaviour='decode_after_vae_no_render'))
+
+ if export_mesh:
+ # if True:
+ # mesh_size = 512
+ mesh_size = 256 # avoid OOM on V100
+ # mesh_size = 384 # only available on A100, otherwise OOM
+ # mesh_size = 320
+ mesh_thres = 10 # TODO, requires tuning
+ import mcubes
+ import trimesh
+ dump_path = f'{logger.get_dir()}/mesh/'
+
+ os.makedirs(dump_path, exist_ok=True)
+
+ grid_out = rec_model(
+ latent=ddpm_latent,
+ grid_size=mesh_size,
+ behaviour='triplane_decode_grid',
+ )
+
+ vtx, faces = mcubes.marching_cubes(
+ grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
+ mesh_thres)
+ vtx = vtx / (mesh_size - 1) * 2 - 1
+
+ # vtx_tensor = th.tensor(vtx, dtype=th.float32, device=dist_util.dev()).unsqueeze(0)
+ # vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1)
+ # vtx_colors = (vtx_colors * 255).astype(np.uint8)
+
+ # mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
+ mesh = trimesh.Trimesh(
+ vertices=vtx,
+ faces=faces,
+ )
+
+ mesh_dump_path = os.path.join(dump_path, f'{name_prefix}.ply')
+ mesh.export(mesh_dump_path, 'ply')
+
+ print(f"Mesh dumped to {dump_path}")
+ del grid_out, mesh
+ th.cuda.empty_cache()
+ # return
+
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/triplane_{name_prefix}.mp4',
+ mode='I',
+ fps=15,
+ codec='libx264')
+
+ if planes.shape[1] == 16: # ffhq/car
+ ddpm_latent = {
+ self.latent_name: planes[:, :12],
+ 'bg_plane': planes[:, 12:16],
+ }
+ else:
+ ddpm_latent = {
+ self.latent_name: planes,
+ }
+
+ ddpm_latent.update(
+ rec_model(latent=ddpm_latent,
+ behaviour='decode_after_vae_no_render'))
+
+ # planes = planes.repeat_interleave(micro['c'].shape[0], 0)
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ # micro_batchsize = 2
+ # micro_batchsize = batch_size
+
+ if render_reference is None:
+ render_reference = self.eval_data # compat
+ else: # use train_traj
+ for key in ['ins', 'bbox', 'caption']:
+ if key in render_reference:
+ render_reference.pop(key)
+ # render_reference.pop('bbox')
+ # render_reference.pop('caption')
+
+ # compat lst for enumerate
+ if render_all: # render 50 or 250 views, for shapenet
+ render_reference = [{
+ k: v[idx:idx + 1]
+ for k, v in render_reference.items()
+ } for idx in range(render_reference['c'].shape[0])]
+ else:
+ render_reference = [{
+ k: v[idx:idx + 1]
+ for k, v in render_reference.items()
+ } for idx in range(40)]
+
+ # for i, batch in enumerate(tqdm(self.eval_data)):
+ for i, batch in enumerate(tqdm(render_reference)):
+ micro = {
+ k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
+ for k, v in batch.items()
+ }
+ # micro = {'c': batch['c'].to(dist_util.dev()).repeat_interleave(batch_size, 0)}
+
+ # all_pred = []
+ pred = rec_model(
+ img=None,
+ c=micro['c'],
+ latent=ddpm_latent,
+ # latent={
+ # # k: v.repeat_interleave(micro['c'].shape[0], 0) if v is not None else None
+ # k: v.repeat_interleave(micro['c'].shape[0], 0) if v is not None else None
+ # for k, v in ddpm_latent.items()
+ # },
+ behaviour='triplane_dec')
+
+ # if True:
+ # pred_depth = pred['image_depth']
+ # pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ # pred_depth.min())
+
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+
+ # save viridis_r depth
+ pred_depth = pred_depth.cpu()[0].permute(1, 2, 0).numpy()
+ pred_depth = (plt.cm.viridis(pred_depth[..., 0])[..., :3]) * 2 - 1
+ pred_depth = th.from_numpy(pred_depth).to(
+ pred['image_raw'].device).permute(2, 0, 1).unsqueeze(0)
+
+
+ if 'image_sr' in pred:
+
+ gen_img = pred['image_sr']
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), gen_img,
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 128:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']), pred['image_sr'],
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ gen_img = pred['image_raw']
+
+ pred_vis = th.cat(
+ [
+ # self.pool_128(micro['img']),
+ self.pool_128(gen_img),
+ # self.pool_128(pred_depth.repeat_interleave(3, dim=1))
+ self.pool_128(pred_depth)
+ ],
+ dim=-1) # B, 3, H, W
+
+ if save_img:
+ for batch_idx in range(gen_img.shape[0]):
+ sampled_img = Image.fromarray(
+ (gen_img[batch_idx].permute(1, 2, 0).cpu().numpy() *
+ 127.5 + 127.5).clip(0, 255).astype(np.uint8))
+ if sampled_img.size != (512, 512):
+ sampled_img = sampled_img.resize(
+ (128, 128), Image.HAMMING) # for shapenet
+ sampled_img.save(logger.get_dir() +
+ '/FID_Cals/{}_{}.png'.format(
+ int(name_prefix) * batch_size +
+ batch_idx, i))
+ # print('FID_Cals/{}_{}.png'.format(int(name_prefix)*batch_size+batch_idx, i))
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ # if vis.shape[0] > 1:
+ # vis = np.concatenate(np.split(vis, vis.shape[0], axis=0),
+ # axis=-3)
+
+ # if not save_img:
+ for j in range(vis.shape[0]
+ ): # ! currently only export one plane at a time
+ video_out.append_data(vis[j])
+
+ # if not save_img:
+ video_out.close()
+ del video_out
+ print('logged video to: ',
+ f'{logger.get_dir()}/triplane_{name_prefix}.mp4')
+
+ del vis, pred_vis, micro, pred,
+
+ def _init_optim_groups(self, rec_model, freeze_decoder=False):
+ """for initializing the reconstruction model; fixing decoder part.
+ """
+ kwargs = self.kwargs
+ optim_groups = [
+ # vit encoder
+ {
+ 'name': 'vit_encoder',
+ 'params': rec_model.encoder.parameters(),
+ 'lr': kwargs['encoder_lr'],
+ 'weight_decay': kwargs['encoder_weight_decay']
+ },
+ ]
+
+ if not freeze_decoder:
+ optim_groups += [
+ # vit decoder
+ {
+ 'name': 'vit_decoder',
+ 'params': rec_model.decoder.vit_decoder.parameters(),
+ 'lr': kwargs['vit_decoder_lr'],
+ 'weight_decay': kwargs['vit_decoder_wd']
+ },
+ {
+ 'name': 'vit_decoder_pred',
+ 'params': rec_model.decoder.decoder_pred.parameters(),
+ 'lr': kwargs['vit_decoder_lr'],
+ # 'weight_decay': 0
+ 'weight_decay': kwargs['vit_decoder_wd']
+ },
+
+ # triplane decoder
+ {
+ 'name': 'triplane_decoder',
+ 'params': rec_model.decoder.triplane_decoder.parameters(),
+ 'lr': kwargs['triplane_decoder_lr'],
+ # 'weight_decay': self.weight_decay
+ },
+ ]
+
+ if rec_model.decoder.superresolution is not None:
+ optim_groups.append({
+ 'name':
+ 'triplane_decoder_superresolution',
+ 'params':
+ rec_model.decoder.superresolution.parameters(),
+ 'lr':
+ kwargs['super_resolution_lr'],
+ })
+
+ return optim_groups
+
+ # @th.no_grad()
+ # # def eval_loop(self, c_list:list):
+ # def eval_novelview_loop(self, rec_model):
+ # # novel view synthesis given evaluation camera trajectory
+ # video_out = imageio.get_writer(
+ # f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}.mp4',
+ # mode='I',
+ # fps=60,
+ # codec='libx264')
+
+ # all_loss_dict = []
+ # novel_view_micro = {}
+
+ # # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ # for i, batch in enumerate(tqdm(self.eval_data)):
+ # # for i in range(0, 8, self.microbatch):
+ # # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ # micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ # # st()
+
+ # if i == 0:
+ # novel_view_micro = {
+ # 'img_to_encoder': micro['img_to_encoder'][0:1]
+ # }
+
+ # latent = rec_model(img=novel_view_micro['img_to_encoder'],
+ # behaviour='enc_dec_wo_triplane')
+
+ # # else:
+ # # # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
+ # # novel_view_micro = {
+ # # k:
+ # # v[0:1].to(dist_util.dev()).repeat_interleave(
+ # # micro['img'].shape[0], 0)
+ # # for k, v in novel_view_micro.items()
+ # # }
+
+ # # pred = rec_model(img=novel_view_micro['img_to_encoder'].repeat_interleave(micro['img'].shape[0], 0),
+ # # c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # # ! only render
+ # pred = rec_model(
+ # latent={
+ # 'latent_after_vit': latent['latent_after_vit'].repeat_interleave(micro['img'].shape[0], 0)
+ # },
+ # c=micro['c'], # predict novel view here
+ # behaviour='triplane_dec',
+ # )
+
+ # # target = {
+ # # 'img': micro['img'],
+ # # 'depth': micro['depth'],
+ # # 'depth_mask': micro['depth_mask']
+ # # }
+ # # targe
+
+ # _, loss_dict = self.loss_class(pred, micro, test_mode=True)
+ # all_loss_dict.append(loss_dict)
+
+ # # ! move to other places, add tensorboard
+
+ # # pred_vis = th.cat([
+ # # pred['image_raw'],
+ # # -pred['image_depth'].repeat_interleave(3, dim=1)
+ # # ],
+ # # dim=-1)
+
+ # # normalize depth
+ # # if True:
+ # pred_depth = pred['image_depth']
+ # pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ # pred_depth.min())
+ # if 'image_sr' in pred:
+ # if pred['image_sr'].shape[-1] == 512:
+ # pred_vis = th.cat([
+ # micro['img_sr'],
+ # self.pool_512(pred['image_raw']), pred['image_sr'],
+ # self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+ # else:
+ # assert pred['image_sr'].shape[-1] == 128
+ # pred_vis = th.cat([
+ # micro['img_sr'],
+ # self.pool_128(pred['image_raw']), pred['image_sr'],
+ # self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+ # else:
+ # pred_vis = th.cat([
+ # self.pool_128(micro['img']),
+ # self.pool_128(pred['image_raw']),
+ # self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1) # B, 3, H, W
+
+ # vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ # vis = vis * 127.5 + 127.5
+ # vis = vis.clip(0, 255).astype(np.uint8)
+
+ # for j in range(vis.shape[0]):
+ # video_out.append_data(vis[j])
+
+ # video_out.close()
+
+ # del video_out, vis, pred_vis, pred
+ # th.cuda.empty_cache()
+
+ # val_scores_for_logging = calc_average_loss(all_loss_dict)
+ # with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
+ # 'a') as f:
+ # json.dump({'step': self.step, **val_scores_for_logging}, f)
+
+ # # * log to tensorboard
+ # for k, v in val_scores_for_logging.items():
+ # self.writer.add_scalar(f'Eval/NovelView/{k}', v,
+ # self.step + self.resume_step)
+
+ @th.no_grad()
+ # def eval_loop(self, c_list:list):
+ def eval_novelview_loop(self, rec_model):
+ # novel view synthesis given evaluation camera trajectory
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+
+ all_loss_dict = []
+ novel_view_micro = {}
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ # for i in range(0, 8, self.microbatch):
+ # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ if i == 0:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in batch.items()
+ }
+ else:
+ # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in novel_view_micro.items()
+ }
+
+ pred = rec_model(img=novel_view_micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+ # target = {
+ # 'img': micro['img'],
+ # 'depth': micro['depth'],
+ # 'depth_mask': micro['depth_mask']
+ # }
+ # targe
+
+ _, loss_dict = self.loss_class(pred, micro, test_mode=True)
+ all_loss_dict.append(loss_dict)
+
+ # ! move to other places, add tensorboard
+
+ # pred_vis = th.cat([
+ # pred['image_raw'],
+ # -pred['image_depth'].repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+
+ # normalize depth
+ # if True:
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ if 'image_sr' in pred:
+
+ if pred['image_sr'].shape[-1] == 512:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ elif pred['image_sr'].shape[-1] == 256:
+
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_256(pred['image_raw']), pred['image_sr'],
+ self.pool_256(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred['image_sr']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+
+ else:
+ # pred_vis = th.cat([
+ # self.pool_64(micro['img']), pred['image_raw'],
+ # pred_depth.repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1) # B, 3, H, W
+
+ pred_vis = th.cat([
+ self.pool_128(micro['img']),
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ val_scores_for_logging = calc_average_loss(all_loss_dict)
+ with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
+ 'a') as f:
+ json.dump({'step': self.step, **val_scores_for_logging}, f)
+
+ # * log to tensorboard
+ for k, v in val_scores_for_logging.items():
+ self.writer.add_scalar(f'Eval/NovelView/{k}', v,
+ self.step + self.resume_step)
+ del video_out
+ # del pred_vis
+ # del pred
+
+ th.cuda.empty_cache()
+
+ @th.no_grad()
+ def eval_loop(self, rec_model):
+ # novel view synthesis given evaluation camera trajectory
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+ all_loss_dict = []
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ # for i in range(0, 8, self.microbatch):
+ # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ # pred = self.model(img=micro['img_to_encoder'],
+ # c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # pred of rec model
+ pred = rec_model(img=micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+
+ if 'image_sr' in pred:
+ if pred['image_sr'].shape[-1] == 512:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+ else:
+ assert pred['image_sr'].shape[-1] == 128
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_128(pred['image_raw']), pred['image_sr'],
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+ else:
+ pred_vis = th.cat([
+ self.pool_128(micro['img']),
+ self.pool_128(pred['image_raw']),
+ self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ val_scores_for_logging = calc_average_loss(all_loss_dict)
+ with open(os.path.join(logger.get_dir(), 'scores.json'), 'a') as f:
+ json.dump({'step': self.step, **val_scores_for_logging}, f)
+
+ # * log to tensorboard
+ for k, v in val_scores_for_logging.items():
+ self.writer.add_scalar(f'Eval/Rec/{k}', v,
+ self.step + self.resume_step)
+
+ del video_out, vis, pred_vis, pred
+ th.cuda.empty_cache()
+ self.eval_novelview_loop(rec_model)
+
+ def save(self, mp_trainer=None, model_name='ddpm'):
+ if mp_trainer is None:
+ mp_trainer = self.mp_trainer
+
+ def save_checkpoint(rate, params):
+ state_dict = mp_trainer.master_params_to_state_dict(params)
+ if dist_util.get_rank() == 0:
+ logger.log(f"saving model {model_name} {rate}...")
+ if not rate:
+ filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
+ else:
+ filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
+ with bf.BlobFile(bf.join(get_blob_logdir(), filename),
+ "wb") as f:
+ th.save(state_dict, f)
+
+ # save_checkpoint(0, self.mp_trainer_ddpm.master_params)
+ save_checkpoint(0, mp_trainer.master_params)
+ if model_name == 'ddpm':
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ save_checkpoint(rate, params)
+
+ th.cuda.empty_cache()
+ dist_util.synchronize()
+
+ def _load_and_sync_parameters(self,
+ model=None,
+ model_name='ddpm',
+ resume_checkpoint=None):
+ if resume_checkpoint is None:
+ resume_checkpoint, self.resume_step = find_resume_checkpoint(
+ self.resume_checkpoint, model_name) or self.resume_checkpoint
+
+ if model is None:
+ model = self.model
+
+ if resume_checkpoint and Path(resume_checkpoint).exists():
+ if dist_util.get_rank() == 0:
+ # ! rank 0 return will cause all other ranks to hang
+ logger.log(
+ f"loading model from checkpoint: {resume_checkpoint}...")
+ map_location = {
+ 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
+ } # configure map_location properly
+
+ logger.log(f'mark {model_name} loading ')
+ resume_state_dict = dist_util.load_state_dict(
+ resume_checkpoint, map_location=map_location)
+ logger.log(f'mark {model_name} loading finished')
+
+ model_state_dict = model.state_dict()
+
+ for k, v in resume_state_dict.items():
+ if k in model_state_dict.keys() and v.size(
+ ) == model_state_dict[k].size():
+ model_state_dict[k] = v
+
+ else:
+ print(
+ '!!!! ignore key: ',
+ k,
+ ": ",
+ v.size(),
+ )
+ if k in model_state_dict:
+ print('shape in model: ',
+ model_state_dict[k].size())
+ else:
+ print(k, ' not in model')
+
+ model.load_state_dict(model_state_dict, strict=True)
+ del model_state_dict
+ else:
+ logger.log(f'{resume_checkpoint} not found.')
+ # print(resume_checkpoint)
+
+ if dist_util.get_world_size() > 1:
+ dist_util.sync_params(model.parameters())
+ # dist_util.sync_params(model.named_parameters())
+ print(f'synced {model_name} params')
+
+ @th.inference_mode()
+ def apply_model_inference(self,
+ x_noisy,
+ t,
+ c=None,
+ model_kwargs={}): # compatiable api
+ # pred_params = self.ddp_model(x_noisy, t, c=c, model_kwargs=model_kwargs)
+ pred_params = self.ddp_model(x_noisy, t,
+ **model_kwargs) # unconditional model
+ return pred_params
+
+ @th.inference_mode()
+ def eval_ddpm_sample(self, rec_model, **kwargs): # , ddpm_model=None):
+ # rec_model.eval()
+ # self.ddpm_model.eval()
+ self.model.eval()
+
+ # if ddpm_model is None:
+ # ddpm_model = self.ddp_model
+
+ args = dnnlib.EasyDict(
+ dict(
+ batch_size=1,
+ # image_size=224,
+ image_size=self.diffusion_input_size,
+ # ddpm_image_size=224,
+ # denoise_in_channels=self.ddp_rec_model.module.decoder.triplane_decoder.out_chans, # type: ignore
+ denoise_in_channels=self.ddpm_model.
+ in_channels, # type: ignore
+ clip_denoised=False,
+ class_cond=False,
+ use_ddim=False))
+
+ model_kwargs = {}
+
+ if args.class_cond:
+ classes = th.randint(low=0,
+ high=NUM_CLASSES,
+ size=(args.batch_size, ),
+ device=dist_util.dev())
+ model_kwargs["y"] = classes
+
+ diffusion = self.diffusion
+ sample_fn = (diffusion.p_sample_loop
+ if not args.use_ddim else diffusion.ddim_sample_loop)
+
+ # for i in range(2):
+ for i in range(1):
+ triplane_sample = sample_fn(
+ # self.ddp_model,
+ self,
+ (args.batch_size, args.denoise_in_channels,
+ self.diffusion_input_size, self.diffusion_input_size),
+ clip_denoised=args.clip_denoised,
+ # model_kwargs=model_kwargs,
+ mixing_normal=True, # !
+ device=dist_util.dev(),
+ # model_kwargs=model_kwargs,
+ **model_kwargs)
+
+ th.cuda.empty_cache()
+ self.render_video_given_triplane(
+ triplane_sample,
+ rec_model,
+ name_prefix=f'{self.step + self.resume_step}_{i}')
+ th.cuda.empty_cache()
+
+ # rec_model.train()
+ # self.ddpm_model.train()
+ # ddpm_model.train()
+ self.model.train()
+
+ # @th.inference_mode()
+ # def render_video_given_triplane(self,
+ # planes,
+ # rec_model,
+ # name_prefix='0',
+ # save_img=False):
+
+ # planes *= self.triplane_scaling_divider # if setting clip_denoised=True, the sampled planes will lie in [-1,1]. Thus, values beyond [+- std] will be abandoned in this version. Move to IN for later experiments.
+
+ # # sr_w_code = getattr(self.ddp_rec_model.module.decoder, 'w_avg', None)
+ # # sr_w_code = None
+ # batch_size = planes.shape[0]
+
+ # # if sr_w_code is not None:
+ # # sr_w_code = sr_w_code.reshape(1, 1,
+ # # -1).repeat_interleave(batch_size, 0)
+
+ # # used during diffusion sampling inference
+ # # if not save_img:
+ # video_out = imageio.get_writer(
+ # f'{logger.get_dir()}/triplane_{name_prefix}.mp4',
+ # mode='I',
+ # fps=15,
+ # codec='libx264')
+
+ # if planes.shape[1] == 16: # ffhq/car
+ # ddpm_latent = {
+ # self.latent_name: planes[:, :12],
+ # 'bg_plane': planes[:, 12:16],
+ # }
+ # else:
+ # ddpm_latent = {
+ # self.latent_name: planes,
+ # }
+
+ # ddpm_latent.update(rec_model(latent=ddpm_latent, behaviour='decode_after_vae_no_render'))
+
+ # # planes = planes.repeat_interleave(micro['c'].shape[0], 0)
+
+ # # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ # # micro_batchsize = 2
+ # # micro_batchsize = batch_size
+
+ # for i, batch in enumerate(tqdm(self.eval_data)):
+ # micro = {
+ # k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
+ # for k, v in batch.items()
+ # }
+ # # micro = {'c': batch['c'].to(dist_util.dev()).repeat_interleave(batch_size, 0)}
+
+ # # all_pred = []
+ # pred = rec_model(
+ # img=None,
+ # c=micro['c'],
+ # latent=ddpm_latent,
+ # # latent={
+ # # # k: v.repeat_interleave(micro['c'].shape[0], 0) if v is not None else None
+ # # k: v.repeat_interleave(micro['c'].shape[0], 0) if v is not None else None
+ # # for k, v in ddpm_latent.items()
+ # # },
+ # behaviour='triplane_dec')
+
+ # # if True:
+ # pred_depth = pred['image_depth']
+ # pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ # pred_depth.min())
+
+ # if 'image_sr' in pred:
+
+ # gen_img = pred['image_sr']
+
+ # if pred['image_sr'].shape[-1] == 512:
+
+ # pred_vis = th.cat([
+ # micro['img_sr'],
+ # self.pool_512(pred['image_raw']), gen_img,
+ # self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+
+ # elif pred['image_sr'].shape[-1] == 128:
+
+ # pred_vis = th.cat([
+ # micro['img_sr'],
+ # self.pool_128(pred['image_raw']), pred['image_sr'],
+ # self.pool_128(pred_depth).repeat_interleave(3, dim=1)
+ # ],
+ # dim=-1)
+
+ # else:
+ # gen_img = pred['image_raw']
+
+ # pooled_depth = self.pool_128(pred_depth.repeat_interleave(3, dim=1))
+ # pred_vis = th.cat(
+ # [
+ # # self.pool_128(micro['img']),
+ # self.pool_128(gen_img),
+ # pooled_depth,
+ # ],
+ # dim=-1) # B, 3, H, W
+
+ # if save_img:
+ # for batch_idx in range(gen_img.shape[0]):
+ # sampled_img = Image.fromarray(
+ # (gen_img[batch_idx].permute(1, 2, 0).cpu().numpy() *
+ # 127.5 + 127.5).clip(0, 255).astype(np.uint8))
+ # if sampled_img.size != (512, 512):
+ # sampled_img = sampled_img.resize(
+ # (128, 128), Image.HAMMING) # for shapenet
+ # sampled_img.save(logger.get_dir() +
+ # '/FID_Cals/{}_{}.png'.format(
+ # int(name_prefix) * batch_size +
+ # batch_idx, i))
+ # # ! save depth
+ # torchvision.utils.save_image(pooled_depth[batch_idx:batch_idx+1],logger.get_dir() +
+ # '/FID_Cals/{}_{}_depth.png'.format(
+ # int(name_prefix) * batch_size +
+ # batch_idx, i), normalize=True, val_range=(0,1), padding=0)
+
+ # # print('FID_Cals/{}_{}.png'.format(int(name_prefix)*batch_size+batch_idx, i))
+
+ # vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ # vis = vis * 127.5 + 127.5
+ # vis = vis.clip(0, 255).astype(np.uint8)
+
+ # # if vis.shape[0] > 1:
+ # # vis = np.concatenate(np.split(vis, vis.shape[0], axis=0),
+ # # axis=-3)
+
+ # # if not save_img:
+ # for j in range(vis.shape[0]
+ # ): # ! currently only export one plane at a time
+ # video_out.append_data(vis[j])
+
+ # # if not save_img:
+ # video_out.close()
+ # del video_out
+ # print('logged video to: ',
+ # f'{logger.get_dir()}/triplane_{name_prefix}.mp4')
+
+ # del vis, pred_vis, micro, pred,
+
+ @th.inference_mode()
+ def render_video_noise_schedule(self, name_prefix='0'):
+
+ # planes *= self.triplane_std # denormalize for rendering
+
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/triplane_visnoise_{name_prefix}.mp4',
+ mode='I',
+ fps=30,
+ codec='libx264')
+
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ if i % 10 != 0:
+ continue
+
+ # ========= novel view plane settings ====
+ if i == 0:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in batch.items()
+ }
+ else:
+ # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
+ novel_view_micro = {
+ k:
+ v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in novel_view_micro.items()
+ }
+
+ latent = self.ddp_rec_model(
+ img=novel_view_micro['img_to_encoder'],
+ c=micro['c'])[self.latent_name] # pred: (B, 3, 64, 64)
+
+ x_start = latent / self.triplane_scaling_divider # normalize std to 1
+ # x_start = latent
+
+ all_pred_vis = []
+ # for t in th.range(0,
+ # 4001,
+ # 500,
+ # dtype=th.long,
+ # device=dist_util.dev()): # cosine 4k steps
+ for t in th.range(0,
+ 1001,
+ 125,
+ dtype=th.long,
+ device=dist_util.dev()): # cosine 4k steps
+
+ # ========= add noise according to t
+ noise = th.randn_like(x_start) # x_start is the x0 image
+ x_t = self.diffusion.q_sample(
+ x_start, t, noise=noise
+ ) # * add noise according to predefined schedule
+ planes_x_t = (x_t * self.triplane_scaling_divider).clamp(
+ -50, 50) # de-scaling noised x_t
+
+ # planes_x_t = (x_t * 1).clamp(
+ # -50, 50) # de-scaling noised x_t
+
+ # ===== visualize
+ pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'],
+ latent=planes_x_t,
+ behaviour=self.render_latent_behaviour
+ ) # pred: (B, 3, 64, 64)
+
+ # pred_depth = pred['image_depth']
+ # pred_depth = (pred_depth - pred_depth.min()) / (
+ # pred_depth.max() - pred_depth.min())
+ # pred_vis = th.cat([
+ # # self.pool_128(micro['img']),
+ # pred['image_raw'],
+ # ],
+ # dim=-1) # B, 3, H, W
+ pred_vis = pred['image_raw']
+
+ all_pred_vis.append(pred_vis)
+ # TODO, make grid
+
+ all_pred_vis = torchvision.utils.make_grid(
+ th.cat(all_pred_vis, 0),
+ nrow=len(all_pred_vis),
+ normalize=True,
+ value_range=(-1, 1),
+ scale_each=True) # normalized to [-1,1]
+
+ vis = all_pred_vis.permute(1, 2, 0).cpu().numpy() # H W 3
+
+ vis = (vis * 255).clip(0, 255).astype(np.uint8)
+
+ video_out.append_data(vis)
+
+ video_out.close()
+ print('logged video to: ',
+ f'{logger.get_dir()}/triplane_visnoise_{name_prefix}.mp4')
+
+ th.cuda.empty_cache()
+
+ @th.inference_mode()
+ def plot_noise_nsr_curve(self, name_prefix='0'):
+ # planes *= self.triplane_std # denormalize for rendering
+
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ if i % 10 != 0:
+ continue
+
+ # if i == 0:
+ latent = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='enc_dec_wo_triplane') # pred: (B, 3, 64, 64)
+
+ x_start = latent[
+ self.
+ latent_name] / self.triplane_scaling_divider # normalize std to 1
+
+ snr_list = []
+ snr_wo_data_list = []
+ xt_mean = []
+ xt_std = []
+
+ for t in th.range(0,
+ 1001,
+ 5,
+ dtype=th.long,
+ device=dist_util.dev()): # cosine 4k steps
+
+ # ========= add noise according to t
+ noise = th.randn_like(x_start) # x_start is the x0 image
+
+ beta_t = _extract_into_tensor(
+ self.diffusion.sqrt_alphas_cumprod, t, x_start.shape)
+ one_minus_beta_t = _extract_into_tensor(
+ self.diffusion.sqrt_one_minus_alphas_cumprod, t,
+ x_start.shape)
+
+ signal_t = beta_t * x_start
+ noise_t = one_minus_beta_t * noise
+
+ x_t = signal_t + noise_t
+
+ snr = signal_t / (noise_t + 1e-6)
+ snr_wo_data = beta_t / (one_minus_beta_t + 1e-6)
+
+ snr_list.append(abs(snr).mean().cpu().numpy())
+ snr_wo_data_list.append(abs(snr_wo_data).mean().cpu().numpy())
+ xt_mean.append(x_t.mean().cpu().numpy())
+ xt_std.append(x_t.std().cpu().numpy())
+
+ print('xt_mean', xt_mean)
+ print('xt_std', xt_std)
+ print('snr', snr_list)
+
+ th.save(
+ {
+ 'xt_mean': xt_mean,
+ 'xt_std': xt_std,
+ 'snr': snr_list,
+ 'snr_wo_data': snr_wo_data_list,
+ },
+ Path(logger.get_dir()) / f'snr_{i}.pt')
+
+ th.cuda.empty_cache()
+
+
+# a legacy class for direct diffusion training, not joint.
+class TrainLoop3DDiffusion(TrainLoopDiffusionWithRec):
+
+ def __init__(
+ self,
+ *,
+ # model,
+ rec_model,
+ denoise_model,
+ diffusion,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ schedule_sampler=None,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ ignore_resume_opt=False,
+ freeze_ae=False,
+ denoised_ae=True,
+ triplane_scaling_divider=10,
+ use_amp=False,
+ diffusion_input_size=224,
+ **kwargs):
+
+ super().__init__(
+ model=denoise_model,
+ diffusion=diffusion,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ triplane_scaling_divider=triplane_scaling_divider,
+ use_amp=use_amp,
+ diffusion_input_size=diffusion_input_size,
+ schedule_sampler=schedule_sampler,
+ )
+
+ # self.accelerator = Accelerator()
+
+ self._load_and_sync_parameters(model=self.rec_model, model_name='rec')
+
+ # * for loading EMA
+ self.mp_trainer_rec = MixedPrecisionTrainer(
+ model=self.rec_model,
+ use_fp16=self.use_fp16,
+ use_amp=use_amp,
+ fp16_scale_growth=fp16_scale_growth,
+ model_name='rec',
+ )
+ self.denoised_ae = denoised_ae
+
+ if not freeze_ae:
+ self.opt_rec = AdamW(
+ self._init_optim_groups(self.mp_trainer_rec.model))
+ else:
+ print('!! freezing AE !!')
+
+ # if not freeze_ae:
+ if self.resume_step:
+ if not ignore_resume_opt:
+ self._load_optimizer_state()
+ else:
+ logger.warn("Ignoring optimizer state from checkpoint.")
+
+ self.ema_params_rec = [
+ self._load_ema_parameters(
+ rate,
+ self.rec_model,
+ self.mp_trainer_rec,
+ model_name=self.mp_trainer_rec.model_name)
+ for rate in self.ema_rate
+ ] # for sync reconstruction model
+ else:
+ if not freeze_ae:
+ self.ema_params_rec = [
+ copy.deepcopy(self.mp_trainer_rec.master_params)
+ for _ in range(len(self.ema_rate))
+ ]
+
+ if self.use_ddp is True:
+ self.rec_model = th.nn.SyncBatchNorm.convert_sync_batchnorm(
+ self.rec_model)
+ self.ddp_rec_model = DDP(
+ self.rec_model,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ # find_unused_parameters=True,
+ )
+ else:
+ self.ddp_rec_model = self.rec_model
+
+ if freeze_ae:
+ self.ddp_rec_model.eval()
+ self.ddp_rec_model.requires_grad_(False)
+ self.freeze_ae = freeze_ae
+
+ # if use_amp:
+
+ def _update_ema_rec(self):
+ for rate, params in zip(self.ema_rate, self.ema_params_rec):
+ update_ema(params, self.mp_trainer_rec.master_params, rate=rate)
+
+ def run_loop(self, batch=None):
+ th.cuda.empty_cache()
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # if self.step % self.eval_interval == 0 and self.step != 0:
+ if self.step % self.eval_interval == 0:
+ if dist_util.get_rank() == 0:
+ self.eval_ddpm_sample(self.ddp_rec_model)
+ # continue # TODO, diffusion inference
+ # self.eval_loop()
+ # self.eval_novelview_loop()
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+ th.cuda.empty_cache()
+
+ batch = next(self.data)
+ self.run_step(batch)
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ if self.step % self.save_interval == 0 and self.step != 0:
+ self.save()
+ if not self.freeze_ae:
+ self.save(self.mp_trainer_rec, 'rec')
+ dist_util.synchronize()
+
+ th.cuda.empty_cache()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+ if not self.freeze_ae:
+ self.save(self.mp_trainer_rec, 'rec')
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+ if not self.freeze_ae:
+ self.save(self.mp_trainer_rec, 'rec')
+
+ def run_step(self, batch, cond=None):
+ self.forward_backward(batch,
+ cond) # type: ignore # * 3D Reconstruction step
+ took_step_ddpm = self.mp_trainer.optimize(self.opt)
+ if took_step_ddpm:
+ self._update_ema()
+
+ if not self.freeze_ae:
+ took_step_rec = self.mp_trainer_rec.optimize(self.opt_rec)
+ if took_step_rec:
+ self._update_ema_rec()
+
+ self._anneal_lr()
+ self.log_step()
+
+ def forward_backward(self, batch, *args, **kwargs):
+ # return super().forward_backward(batch, *args, **kwargs)
+ self.mp_trainer.zero_grad()
+ # all_denoised_out = dict()
+ batch_size = batch['img'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {
+ k: v[i:i + self.microbatch].to(dist_util.dev())
+ for k, v in batch.items()
+ }
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ # if not freeze_ae:
+
+ # =================================== ae part ===================================
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp
+ and not self.freeze_ae):
+ # with th.cuda.amp.autocast(dtype=th.float16,
+ # enabled=False,): # ! debugging, no AMP on all the input
+
+ latent = self.ddp_rec_model(
+ img=micro['img_to_encoder'],
+ c=micro['c'],
+ behaviour='enc_dec_wo_triplane') # pred: (B, 3, 64, 64)
+
+ if not self.freeze_ae:
+ target = micro
+ pred = self.rec_model(latent=latent,
+ c=micro['c'],
+ behaviour='triplane_dec')
+
+ if last_batch or not self.use_ddp:
+ ae_loss, loss_dict = self.loss_class(pred,
+ target,
+ test_mode=False)
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ ae_loss, loss_dict = self.loss_class(
+ pred, target, test_mode=False)
+
+ log_rec3d_loss_dict(loss_dict)
+ else:
+ ae_loss = th.tensor(0.0).to(dist_util.dev())
+
+ # =================================== prepare for ddpm part ===================================
+
+ micro_to_denoise = latent[
+ self.
+ latent_name] / self.triplane_scaling_divider # normalize std to 1
+
+ t, weights = self.schedule_sampler.sample(
+ micro_to_denoise.shape[0], dist_util.dev())
+
+ model_kwargs = {}
+
+ # print(micro_to_denoise.min(), micro_to_denoise.max())
+ compute_losses = functools.partial(
+ self.diffusion.training_losses,
+ self.ddp_model,
+ micro_to_denoise, # x_start
+ t,
+ model_kwargs=model_kwargs,
+ )
+
+ with th.cuda.amp.autocast(dtype=th.float16,
+ enabled=self.mp_trainer.use_amp):
+
+ if last_batch or not self.use_ddp:
+ losses = compute_losses()
+ # denoised_out = denoised_fn()
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ losses = compute_losses()
+
+ if isinstance(self.schedule_sampler, LossAwareSampler):
+ self.schedule_sampler.update_with_local_losses(
+ t, losses["loss"].detach())
+
+ denoise_loss = (losses["loss"] * weights).mean()
+
+ x_t = losses['x_t']
+ model_output = losses['model_output']
+ losses.pop('x_t')
+ losses.pop('model_output')
+
+ log_loss_dict(self.diffusion, t, {
+ k: v * weights
+ for k, v in losses.items()
+ })
+
+ # self.mp_trainer.backward(denoise_loss)
+ # =================================== denosied ae part ===================================
+ # if self.denoised_ae or self.step % 500 == 0:
+ if self.denoised_ae:
+ with th.cuda.amp.autocast(
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp
+ and not self.freeze_ae):
+ # continue
+ denoised_out = denoised_fn()
+
+ denoised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'],
+ latent=denoised_out['pred_xstart'] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically?
+ behaviour=self.render_latent_behaviour)
+
+ # if self.denoised_ae:
+
+ if last_batch or not self.use_ddp:
+ denoised_ae_loss, loss_dict = self.loss_class(
+ denoised_ae_pred, micro, test_mode=False)
+ else:
+ with self.ddp_model.no_sync(): # type: ignore
+ denoised_ae_loss, loss_dict = self.loss_class(
+ denoised_ae_pred, micro, test_mode=False)
+
+ # * rename
+ loss_dict_denoise_ae = {}
+ for k, v in loss_dict.items():
+ loss_dict_denoise_ae[f'{k}_denoised'] = v.mean()
+ log_rec3d_loss_dict(loss_dict_denoise_ae)
+
+ else:
+ denoised_ae_loss = th.tensor(0.0).to(dist_util.dev())
+
+ loss = ae_loss + denoise_loss + denoised_ae_loss
+ # self.mp_trainer.backward(denosied_ae_loss)
+ # self.mp_trainer.backward(loss)
+
+ # exit AMP before backward
+ self.mp_trainer.backward(loss)
+ # if self.freeze_ae:
+ # else:
+ # self.mp_trainer.backward(denoise_loss)
+
+ # TODO, merge visualization with original AE
+ # =================================== denoised AE log part ===================================
+
+ # if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ if dist_util.get_rank() == 1 and self.step % 500 == 0:
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
+ gt_depth.min())
+ # if True:
+
+ if self.freeze_ae:
+ latent_micro = {
+ k:
+ v[0:1].to(dist_util.dev()) if v is not None else v
+ for k, v in latent.items()
+ }
+
+ pred = self.rec_model(latent=latent_micro,
+ c=micro['c'][0:1],
+ behaviour='triplane_dec')
+ else:
+ assert pred is not None
+
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+ pred_img = pred['image_raw']
+ gt_img = micro['img']
+
+ # if 'image_sr' in pred: # TODO
+ # pred_img = th.cat(
+ # [self.pool_512(pred_img), pred['image_sr']],
+ # dim=-1)
+ # gt_img = th.cat(
+ # [self.pool_512(micro['img']), micro['img_sr']],
+ # dim=-1)
+ # pred_depth = self.pool_512(pred_depth)
+ # gt_depth = self.pool_512(gt_depth)
+
+ gt_vis = th.cat(
+ [
+ gt_img, micro['img'], micro['img'],
+ gt_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1)[0:1] # TODO, fail to load depth. range [0, 1]
+
+ sr_w_code = latent_micro.get('sr_w_code', None)
+ if sr_w_code is not None:
+ sr_w_code = sr_w_code[0:1]
+
+ noised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ latent={
+ 'latent_normalized':
+ x_t[0:1] * self.triplane_scaling_divider,
+ # 'sr_w_code': getattr(self.ddp_rec_model.module.decoder,'w_avg').reshape(1,1,-1)
+ 'sr_w_code': sr_w_code
+ }, # TODO, how to define the scale automatically
+ behaviour=self.render_latent_behaviour)
+
+ denoised_fn = functools.partial(
+ self.diffusion.p_mean_variance,
+ self.ddp_model,
+ x_t, # x_start
+ t,
+ model_kwargs=model_kwargs)
+
+ denoised_out = denoised_fn()
+
+ denoised_ae_pred = self.ddp_rec_model(
+ img=None,
+ c=micro['c'][0:1],
+ # latent=denoised_out['pred_xstart'][0:1] * self.
+ # triplane_scaling_divider, # TODO, how to define the scale automatically
+ latent={
+ 'latent_normalized':
+ denoised_out['pred_xstart'][0:1] * self.
+ triplane_scaling_divider, # TODO, how to define the scale automatically
+ # 'sr_w_code': getattr(self.ddp_rec_model.module.decoder,'w_avg').reshape(1,1,-1)
+ # 'sr_w_code': latent_micro['sr_w_code'][0:1]
+ 'sr_w_code':
+ sr_w_code
+ },
+ behaviour=self.render_latent_behaviour)
+
+ assert denoised_ae_pred is not None
+
+ # print(pred_img.shape)
+ # print('denoised_ae:', self.denoised_ae)
+
+ pred_vis = th.cat([
+ pred_img[0:1], noised_ae_pred['image_raw'],
+ denoised_ae_pred['image_raw'],
+ pred_depth[0:1].repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis = th.cat([
+ # self.pool_128(micro['img']), x_t[:, :3, ...],
+ # denoised_out['pred_xstart'][:, :3, ...]
+ # ],
+ # dim=-1)[0].permute(
+ # 1, 2, 0).cpu() # ! pred in range[-1, 1]
+
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}.jpg'
+ )
+ print(
+ 'log denoised vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}.jpg'
+ )
+
+ th.cuda.empty_cache()
+
+
+# /mnt/lustre/yslan/logs/nips23/LSGM/cldm/inference/car/ablation_nomixing/FID50k
diff --git a/nsr/train_util_with_eg3d.py b/nsr/train_util_with_eg3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..a07240316ca7d963d18b4c7eff3a8aa455343ac6
--- /dev/null
+++ b/nsr/train_util_with_eg3d.py
@@ -0,0 +1,587 @@
+import copy
+import functools
+import json
+import os
+from pathlib import Path
+from pdb import set_trace as st
+
+import blobfile as bf
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+import torchvision
+from PIL import Image
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.fp16_util import MixedPrecisionTrainer
+from guided_diffusion.nn import update_ema
+from guided_diffusion.resample import LossAwareSampler, UniformSampler
+from guided_diffusion.train_util import (calc_average_loss,
+ find_ema_checkpoint,
+ find_resume_checkpoint,
+ get_blob_logdir, log_rec3d_loss_dict,
+ parse_resume_step_from_filename)
+
+from .train_util import TrainLoop3DRec
+
+
+class TrainLoop3DRecEG3D(TrainLoop3DRec):
+
+ def __init__(self,
+ *,
+ G,
+ rec_model,
+ loss_class,
+ data,
+ eval_data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ eval_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=0.001,
+ weight_decay=0,
+ lr_anneal_steps=0,
+ iterations=10001,
+ load_submodule_name='',
+ ignore_resume_opt=False,
+ model_name='rec',
+ use_amp=False,
+ # hybrid_training=False,
+ **kwargs):
+ super().__init__(rec_model=rec_model,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ batch_size=batch_size,
+ microbatch=microbatch,
+ lr=lr,
+ ema_rate=ema_rate,
+ log_interval=log_interval,
+ eval_interval=eval_interval,
+ save_interval=save_interval,
+ resume_checkpoint=resume_checkpoint,
+ use_fp16=use_fp16,
+ fp16_scale_growth=fp16_scale_growth,
+ weight_decay=weight_decay,
+ lr_anneal_steps=lr_anneal_steps,
+ iterations=iterations,
+ load_submodule_name=load_submodule_name,
+ ignore_resume_opt=ignore_resume_opt,
+ model_name=model_name,
+ use_amp=use_amp,
+ **kwargs)
+ self.G = G
+ # self.hybrid_training = hybrid_training
+
+ self.pool_224 = th.nn.AdaptiveAvgPool2d((224, 224))
+
+ @th.no_grad()
+ def run_G(
+ self,
+ z,
+ c,
+ swapping_prob,
+ neural_rendering_resolution,
+ update_emas=False,
+ return_raw_only=False,
+ ):
+ """add truncation psi
+
+ Args:
+ z (_type_): _description_
+ c (_type_): _description_
+ swapping_prob (_type_): _description_
+ neural_rendering_resolution (_type_): _description_
+ update_emas (bool, optional): _description_. Defaults to False.
+
+ Returns:
+ _type_: _description_
+ """
+
+ c_gen_conditioning = th.zeros_like(c)
+
+ # ws = self.G.mapping(z, c_gen_conditioning, update_emas=update_emas)
+
+ ws = self.G.mapping(
+ z,
+ c_gen_conditioning,
+ truncation_psi=0.7,
+ truncation_cutoff=None,
+ update_emas=update_emas,
+ )
+
+ gen_output = self.G.synthesis(
+ ws, # BS * 14 * 512
+ c,
+ neural_rendering_resolution=neural_rendering_resolution,
+ update_emas=update_emas,
+ noise_mode='const',
+ return_raw_only=return_raw_only
+ # return_meta=True # return feature_volume
+ ) # fix the SynthesisLayer modulation noise, otherviwe the same latent code may output two different ID
+
+ return gen_output, ws
+
+ def run_loop(self, batch=None):
+ while (not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps):
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # batch, cond = next(self.data)
+ # if batch is None:
+ batch = next(self.data)
+ # batch = self.run_G()
+
+ self.run_step(batch)
+ if self.step % self.log_interval == 0 and dist_util.get_rank(
+ ) == 0:
+ out = logger.dumpkvs()
+ # * log to tensorboard
+ for k, v in out.items():
+ self.writer.add_scalar(f'Loss/{k}', v,
+ self.step + self.resume_step)
+
+ if self.step % self.eval_interval == 0 and self.step != 0:
+ # if dist_util.get_rank() == 0:
+ # self.eval_loop()
+ # self.eval_novelview_loop()
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ if self.step % self.save_interval == 0:
+ self.save()
+ dist_util.synchronize()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST",
+ "") and self.step > 0:
+ return
+
+ self.step += 1
+
+ if self.step > self.iterations:
+ print('reached maximum iterations, exiting')
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+
+ exit()
+
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+
+ def run_step(self, batch, *args):
+ self.forward_backward(batch)
+ took_step = self.mp_trainer_rec.optimize(self.opt)
+ if took_step:
+ self._update_ema()
+ self._anneal_lr()
+ self.log_step()
+
+ def forward_backward(self, batch, *args, **kwargs):
+
+ self.mp_trainer_rec.zero_grad()
+
+ batch_size = batch['c'].shape[0]
+
+ for i in range(0, batch_size, self.microbatch):
+
+ micro = {'c': batch['c'].to(dist_util.dev())}
+
+ with th.no_grad(): # * infer gt
+ eg3d_batch, ws = self.run_G(
+ z=th.randn(micro['c'].shape[0],
+ 512).to(dist_util.dev()),
+ c=micro['c'].to(dist_util.dev(
+ )), # use real img pose here? or synthesized pose.
+ swapping_prob=0,
+ neural_rendering_resolution=128)
+
+ micro.update({
+ 'img':
+ eg3d_batch['image_raw'], # gt
+ 'img_to_encoder':
+ self.pool_224(eg3d_batch['image']),
+ 'depth':
+ eg3d_batch['image_depth'],
+ 'img_sr': eg3d_batch['image'],
+ })
+
+ last_batch = (i + self.microbatch) >= batch_size
+
+ # wrap forward within amp
+ with th.autocast(device_type='cuda',
+ dtype=th.float16,
+ enabled=self.mp_trainer_rec.use_amp):
+
+ pred_gen_output = self.rec_model(
+ img=micro['img_to_encoder'], # pool from 512
+ c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # target = micro
+ target = dict(
+ img=eg3d_batch['image_raw'],
+ shape_synthesized=eg3d_batch['shape_synthesized'],
+ img_sr=eg3d_batch['image'],
+ )
+
+ pred_gen_output['shape_synthesized_query'] = {
+ 'coarse_densities':
+ pred_gen_output['shape_synthesized']['coarse_densities'],
+ 'image_depth': pred_gen_output['image_depth'],
+ }
+
+ eg3d_batch['shape_synthesized']['image_depth'] = eg3d_batch['image_depth']
+
+ batch_size, num_rays, _, _ = pred_gen_output[
+ 'shape_synthesized']['coarse_densities'].shape
+
+
+ for coord_key in ['fine_coords']: # TODO add surface points
+
+ sigma = self.rec_model(
+ latent=pred_gen_output['latent_denormalized'],
+ coordinates=eg3d_batch['shape_synthesized'][coord_key],
+ directions=th.randn_like(
+ eg3d_batch['shape_synthesized'][coord_key]),
+ behaviour='triplane_renderer',
+ )['sigma']
+
+ rendering_kwargs = self.rec_model(
+ behaviour='get_rendering_kwargs')
+
+ sigma = sigma.reshape(
+ batch_size, num_rays,
+ rendering_kwargs['depth_resolution_importance'], 1)
+
+ pred_gen_output['shape_synthesized_query'][
+ f"{coord_key.split('_')[0]}_densities"] = sigma
+
+ # * 2D reconstruction loss
+ if last_batch or not self.use_ddp:
+ loss, loss_dict = self.loss_class(pred_gen_output,
+ target,
+ test_mode=False)
+ else:
+ with self.rec_model.no_sync(): # type: ignore
+ loss, loss_dict = self.loss_class(pred_gen_output,
+ target,
+ test_mode=False)
+
+ # * fully mimic 3D geometry output
+
+ loss_shape = self.calc_shape_rec_loss(
+ pred_gen_output['shape_synthesized_query'],
+ eg3d_batch['shape_synthesized'])
+
+ loss += loss_shape.mean()
+
+ # * add feature loss on feature_image
+ loss_feature_volume = th.nn.functional.mse_loss(
+ eg3d_batch['feature_volume'],
+ pred_gen_output['feature_volume'])
+ loss += loss_feature_volume * 0.1
+
+ loss_ws = th.nn.functional.mse_loss(
+ ws[:, -1:, :],
+ pred_gen_output['sr_w_code'])
+ loss += loss_ws * 0.1
+
+ loss_dict.update(
+ dict(loss_feature_volume=loss_feature_volume,
+ loss=loss,
+ loss_shape=loss_shape,
+ loss_ws=loss_ws))
+
+ loss_dict.update(dict(loss_feature_volume=loss_feature_volume, loss=loss, loss_shape=loss_shape))
+
+ log_rec3d_loss_dict(loss_dict)
+
+
+ self.mp_trainer_rec.backward(loss)
+
+ # for name, p in self.ddp_model.named_parameters():
+ # if p.grad is None:
+ # print(f"found rec unused param: {name}")
+
+ if dist_util.get_rank() == 0 and self.step % 500 == 0:
+ with th.no_grad():
+ # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
+
+ pred_img = pred_gen_output['image_raw']
+ gt_img = micro['img']
+
+ if 'depth' in micro:
+ gt_depth = micro['depth']
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth.unsqueeze(1)
+ gt_depth = (gt_depth - gt_depth.min()) / (
+ gt_depth.max() - gt_depth.min())
+
+ pred_depth = pred_gen_output['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (
+ pred_depth.max() - pred_depth.min())
+
+ gt_vis = th.cat(
+ [gt_img,
+ gt_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+ else:
+
+ gt_vis = th.cat(
+ [gt_img],
+ dim=-1) # TODO, fail to load depth. range [0, 1]
+
+ if 'image_sr' in pred_gen_output:
+ pred_img = th.cat([
+ self.pool_512(pred_img),
+ pred_gen_output['image_sr']
+ ],
+ dim=-1)
+ pred_depth = self.pool_512(pred_depth)
+ gt_depth = self.pool_512(gt_depth)
+
+ gt_vis = th.cat(
+ [self.pool_512(micro['img']), micro['img_sr'], gt_depth.repeat_interleave(3, dim=1)],
+ dim=-1)
+
+ pred_vis = th.cat(
+ [pred_img,
+ pred_depth.repeat_interleave(3, dim=1)],
+ dim=-1) # B, 3, H, W
+
+ vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
+ 1, 2, 0).cpu() # ! pred in range[-1, 1]
+ # vis_grid = torchvision.utils.make_grid(vis) # HWC
+ vis = vis.numpy() * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+ Image.fromarray(vis).save(
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+ print(
+ 'log vis to: ',
+ f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
+
+ # self.writer.add_image(f'images',
+ # vis,
+ # self.step + self.resume_step,
+ # dataformats='HWC')
+ return pred_gen_output
+
+ def calc_shape_rec_loss(
+ self,
+ pred_shape: dict,
+ gt_shape: dict,
+ ):
+
+ loss_shape, loss_shape_dict = self.loss_class.calc_shape_rec_loss(
+ pred_shape,
+ gt_shape,
+ dist_util.dev(),
+ )
+
+ for loss_k, loss_v in loss_shape_dict.items():
+ # training_stats.report('Loss/E/3D/{}'.format(loss_k), loss_v)
+ log_rec3d_loss_dict({'Loss/3D/{}'.format(loss_k): loss_v})
+
+ return loss_shape
+
+ # @th.inference_mode()
+ def eval_novelview_loop(self):
+ # novel view synthesis given evaluation camera trajectory
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_real_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+
+ all_loss_dict = []
+ novel_view_micro = {}
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ # for i in range(0, 8, self.microbatch):
+ # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ if i == 0:
+ novel_view_micro = {
+ k: v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in batch.items()
+ }
+
+ else:
+ # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
+ novel_view_micro = {
+ k: v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in novel_view_micro.items()
+ }
+
+ # st()
+
+ pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # _, loss_dict = self.loss_class(pred, micro, test_mode=True)
+ # all_loss_dict.append(loss_dict)
+
+ # ! move to other places, add tensorboard
+
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ if 'image_sr' in pred:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+ else:
+ pred_vis = th.cat([
+ self.pool_128(micro['img']), pred['image_raw'],
+ pred_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ # val_scores_for_logging = calc_average_loss(all_loss_dict)
+ # with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
+ # 'a') as f:
+ # json.dump({'step': self.step, **val_scores_for_logging}, f)
+
+ # * log to tensorboard
+ # for k, v in val_scores_for_logging.items():
+ # self.writer.add_scalar(f'Eval/NovelView/{k}', v,
+ # self.step + self.resume_step)
+ del video_out
+ # del pred_vis
+ # del pred
+
+ th.cuda.empty_cache()
+ # self.eval_novelview_loop_eg3d()
+
+
+ @th.inference_mode()
+ def eval_novelview_loop_eg3d(self):
+ # novel view synthesis given evaluation camera trajectory
+ video_out = imageio.get_writer(
+ f'{logger.get_dir()}/video_novelview_synthetic_{self.step+self.resume_step}.mp4',
+ mode='I',
+ fps=60,
+ codec='libx264')
+
+ all_loss_dict = []
+ novel_view_micro = {}
+
+ # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
+ for i, batch in enumerate(tqdm(self.eval_data)):
+ # for i in range(0, 8, self.microbatch):
+ # c = c_list[i].to(dist_util.dev()).reshape(1, -1)
+ micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
+
+ if i == 0:
+ # novel_view_micro = {
+ # k: v[0:1].to(dist_util.dev()).repeat_interleave(
+ # micro['img'].shape[0], 0)
+ # for k, v in batch.items()
+ # }
+
+ with th.no_grad(): # * infer gt
+ eg3d_batch, _ = self.run_G(
+ z=th.randn(micro['c'].shape[0],
+ 512).to(dist_util.dev()),
+ c=micro['c'].to(dist_util.dev(
+ )), # use real img pose here? or synthesized pose.
+ swapping_prob=0,
+ neural_rendering_resolution=128)
+
+ novel_view_micro.update({
+ 'img':
+ eg3d_batch['image_raw'], # gt
+ 'img_to_encoder':
+ self.pool_224(eg3d_batch['image']),
+ 'depth':
+ eg3d_batch['image_depth'],
+ })
+
+ else:
+ # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
+ novel_view_micro = {
+ k: v[0:1].to(dist_util.dev()).repeat_interleave(
+ micro['img'].shape[0], 0)
+ for k, v in novel_view_micro.items()
+ }
+
+ # st()
+
+ pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
+ c=micro['c']) # pred: (B, 3, 64, 64)
+
+ # _, loss_dict = self.loss_class(pred, micro, test_mode=True)
+ # all_loss_dict.append(loss_dict)
+
+ # ! move to other places, add tensorboard
+
+ pred_depth = pred['image_depth']
+ pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
+ pred_depth.min())
+ if 'image_sr' in pred:
+ pred_vis = th.cat([
+ micro['img_sr'],
+ self.pool_512(pred['image_raw']), pred['image_sr'],
+ self.pool_512(pred_depth).repeat_interleave(3, dim=1)
+ ],
+ dim=-1)
+ else:
+ pred_vis = th.cat([
+ self.pool_128(micro['img']), pred['image_raw'],
+ pred_depth.repeat_interleave(3, dim=1)
+ ],
+ dim=-1) # B, 3, H, W
+
+ vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
+ vis = vis * 127.5 + 127.5
+ vis = vis.clip(0, 255).astype(np.uint8)
+
+ for j in range(vis.shape[0]):
+ video_out.append_data(vis[j])
+
+ video_out.close()
+
+ # val_scores_for_logging = calc_average_loss(all_loss_dict)
+ # with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
+ # 'a') as f:
+ # json.dump({'step': self.step, **val_scores_for_logging}, f)
+
+ # # * log to tensorboard
+ # for k, v in val_scores_for_logging.items():
+ # self.writer.add_scalar(f'Eval/NovelView/{k}', v,
+ # self.step + self.resume_step)
+ del video_out
+ # del pred_vis
+ # del pred
+
+ th.cuda.empty_cache()
\ No newline at end of file
diff --git a/nsr/triplane.py b/nsr/triplane.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe317ad864e5c0289115f14946ecce43848e1ad9
--- /dev/null
+++ b/nsr/triplane.py
@@ -0,0 +1,946 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from threading import local
+import torch
+import torch.nn as nn
+from utils.torch_utils import persistence
+from .networks_stylegan2 import Generator as StyleGAN2Backbone
+from .networks_stylegan2 import ToRGBLayer, SynthesisNetwork, MappingNetwork
+from .volumetric_rendering.renderer import ImportanceRenderer
+from .volumetric_rendering.ray_sampler import RaySampler, PatchRaySampler
+import dnnlib
+from pdb import set_trace as st
+import math
+
+import torch.nn.functional as F
+import itertools
+from ldm.modules.diffusionmodules.model import SimpleDecoder, Decoder
+
+
+@persistence.persistent_class
+class TriPlaneGenerator(torch.nn.Module):
+
+ def __init__(
+ self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ sr_num_fp16_res=0,
+ mapping_kwargs={}, # Arguments for MappingNetwork.
+ rendering_kwargs={},
+ sr_kwargs={},
+ bcg_synthesis_kwargs={},
+ # pifu_kwargs={},
+ # ada_kwargs={}, # not used, place holder
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.renderer = ImportanceRenderer()
+ # if 'PatchRaySampler' in rendering_kwargs:
+ # self.ray_sampler = PatchRaySampler()
+ # else:
+ # self.ray_sampler = RaySampler()
+ self.backbone = StyleGAN2Backbone(z_dim,
+ c_dim,
+ w_dim,
+ img_resolution=256,
+ img_channels=32 * 3,
+ mapping_kwargs=mapping_kwargs,
+ **synthesis_kwargs)
+ self.superresolution = dnnlib.util.construct_class_by_name(
+ class_name=rendering_kwargs['superresolution_module'],
+ channels=32,
+ img_resolution=img_resolution,
+ sr_num_fp16_res=sr_num_fp16_res,
+ sr_antialias=rendering_kwargs['sr_antialias'],
+ **sr_kwargs)
+
+ # self.bcg_synthesis = None
+ if rendering_kwargs.get('use_background', False):
+ self.bcg_synthesis = SynthesisNetwork(
+ w_dim,
+ img_resolution=self.superresolution.input_resolution,
+ img_channels=32,
+ **bcg_synthesis_kwargs)
+ self.bcg_mapping = MappingNetwork(z_dim=z_dim,
+ c_dim=c_dim,
+ w_dim=w_dim,
+ num_ws=self.num_ws,
+ **mapping_kwargs)
+ # New mapping network for self-adaptive camera pose, dim = 3
+
+ self.decoder = OSGDecoder(
+ 32, {
+ 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
+ 'decoder_output_dim': 32
+ })
+ self.neural_rendering_resolution = 64
+ self.rendering_kwargs = rendering_kwargs
+
+ self._last_planes = None
+ self.pool_256 = torch.nn.AdaptiveAvgPool2d((256, 256))
+
+ def mapping(self,
+ z,
+ c,
+ truncation_psi=1,
+ truncation_cutoff=None,
+ update_emas=False):
+ if self.rendering_kwargs['c_gen_conditioning_zero']:
+ c = torch.zeros_like(c)
+ return self.backbone.mapping(z,
+ c *
+ self.rendering_kwargs.get('c_scale', 0),
+ truncation_psi=truncation_psi,
+ truncation_cutoff=truncation_cutoff,
+ update_emas=update_emas)
+
+ def synthesis(self,
+ ws,
+ c,
+ neural_rendering_resolution=None,
+ update_emas=False,
+ cache_backbone=False,
+ use_cached_backbone=False,
+ return_meta=False,
+ return_raw_only=False,
+ **synthesis_kwargs):
+
+ return_sampling_details_flag = self.rendering_kwargs.get(
+ 'return_sampling_details_flag', False)
+
+ if return_sampling_details_flag:
+ return_meta = True
+
+ cam2world_matrix = c[:, :16].view(-1, 4, 4)
+ # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0)
+ # c[:, :16] = cam2world_matrix.view(-1, 16)
+ intrinsics = c[:, 16:25].view(-1, 3, 3)
+
+ if neural_rendering_resolution is None:
+ neural_rendering_resolution = self.neural_rendering_resolution
+ else:
+ self.neural_rendering_resolution = neural_rendering_resolution
+
+ H = W = self.neural_rendering_resolution
+ # Create a batch of rays for volume rendering
+ ray_origins, ray_directions = self.ray_sampler(
+ cam2world_matrix, intrinsics, neural_rendering_resolution)
+
+ # Create triplanes by running StyleGAN backbone
+ N, M, _ = ray_origins.shape
+ if use_cached_backbone and self._last_planes is not None:
+ planes = self._last_planes
+ else:
+ planes = self.backbone.synthesis(
+ ws[:, :self.backbone.num_ws, :], # ws, BS 14 512
+ update_emas=update_emas,
+ **synthesis_kwargs)
+ if cache_backbone:
+ self._last_planes = planes
+
+ # Reshape output into three 32-channel planes
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2],
+ planes.shape[-1]) # BS 96 256 256
+
+ # Perform volume rendering
+ # st()
+ rendering_details = self.renderer(
+ planes,
+ self.decoder,
+ ray_origins,
+ ray_directions,
+ self.rendering_kwargs,
+ # return_meta=True)
+ return_meta=return_meta)
+
+ # calibs = create_calib_matrix(c)
+ # all_coords = rendering_details['all_coords']
+ # B, num_rays, S, _ = all_coords.shape
+ # all_coords_B3N = all_coords.reshape(B, -1, 3).permute(0,2,1)
+ # homo_coords = torch.cat([all_coords, torch.zeros_like(all_coords[..., :1])], -1)
+ # homo_coords[..., -1] = 1
+ # homo_coords = homo_coords.reshape(homo_coords.shape[0], -1, 4)
+ # homo_coords = homo_coords.permute(0,2,1)
+ # xyz = calibs @ homo_coords
+ # xyz = xyz.permute(0,2,1).reshape(B, H, W, S, 4)
+ # st()
+
+ # xyz_proj = perspective(all_coords_B3N, calibs)
+ # xyz_proj = xyz_proj.permute(0,2,1).reshape(B, H, W, S, 3) # [0,0] - [1,1]
+ # st()
+
+ feature_samples, depth_samples, weights_samples = (
+ rendering_details[k]
+ for k in ['feature_samples', 'depth_samples', 'weights_samples'])
+
+ if return_sampling_details_flag:
+ shape_synthesized = rendering_details['shape_synthesized']
+ else:
+ shape_synthesized = None
+
+ # Reshape into 'raw' neural-rendered image
+ feature_image = feature_samples.permute(0, 2, 1).reshape(
+ N, feature_samples.shape[-1], H, W).contiguous() # B 32 H W
+ depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+
+ # Run superresolution to get final image
+ rgb_image = feature_image[:, :3] # B 3 H W
+ if not return_raw_only:
+ sr_image = self.superresolution(
+ rgb_image,
+ feature_image,
+ ws[:, -1:, :], # only use the last layer
+ noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
+ **{
+ k: synthesis_kwargs[k]
+ for k in synthesis_kwargs.keys() if k != 'noise_mode'
+ })
+ else:
+ sr_image = rgb_image
+
+ ret_dict = {
+ 'image': sr_image,
+ 'image_raw': rgb_image,
+ 'image_depth': depth_image,
+ 'weights_samples': weights_samples,
+ 'shape_synthesized': shape_synthesized
+ }
+ if return_meta:
+ ret_dict.update({
+ # 'feature_image': feature_image,
+ 'feature_volume':
+ rendering_details['feature_volume'],
+ 'all_coords':
+ rendering_details['all_coords'],
+ 'weights':
+ rendering_details['weights'],
+ })
+
+ return ret_dict
+
+ def sample(self,
+ coordinates,
+ directions,
+ z,
+ c,
+ truncation_psi=1,
+ truncation_cutoff=None,
+ update_emas=False,
+ **synthesis_kwargs):
+ # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
+ ws = self.mapping(z,
+ c,
+ truncation_psi=truncation_psi,
+ truncation_cutoff=truncation_cutoff,
+ update_emas=update_emas)
+ planes = self.backbone.synthesis(ws,
+ update_emas=update_emas,
+ **synthesis_kwargs)
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2],
+ planes.shape[-1])
+ return self.renderer.run_model(planes, self.decoder, coordinates,
+ directions, self.rendering_kwargs)
+
+ def sample_mixed(self,
+ coordinates,
+ directions,
+ ws,
+ truncation_psi=1,
+ truncation_cutoff=None,
+ update_emas=False,
+ **synthesis_kwargs):
+ # Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
+ planes = self.backbone.synthesis(ws,
+ update_emas=update_emas,
+ **synthesis_kwargs)
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2],
+ planes.shape[-1])
+ return self.renderer.run_model(planes, self.decoder, coordinates,
+ directions, self.rendering_kwargs)
+
+ def forward(self,
+ z,
+ c,
+ truncation_psi=1,
+ truncation_cutoff=None,
+ neural_rendering_resolution=None,
+ update_emas=False,
+ cache_backbone=False,
+ use_cached_backbone=False,
+ **synthesis_kwargs):
+ # Render a batch of generated images.
+ ws = self.mapping(z,
+ c,
+ truncation_psi=truncation_psi,
+ truncation_cutoff=truncation_cutoff,
+ update_emas=update_emas)
+ return self.synthesis(
+ ws,
+ c,
+ update_emas=update_emas,
+ neural_rendering_resolution=neural_rendering_resolution,
+ cache_backbone=cache_backbone,
+ use_cached_backbone=use_cached_backbone,
+ **synthesis_kwargs)
+
+
+from .networks_stylegan2 import FullyConnectedLayer
+
+# class OSGDecoder(torch.nn.Module):
+
+# def __init__(self, n_features, options):
+# super().__init__()
+# self.hidden_dim = 64
+# self.output_dim = options['decoder_output_dim']
+# self.n_features = n_features
+
+# self.net = torch.nn.Sequential(
+# FullyConnectedLayer(n_features,
+# self.hidden_dim,
+# lr_multiplier=options['decoder_lr_mul']),
+# torch.nn.Softplus(),
+# FullyConnectedLayer(self.hidden_dim,
+# 1 + options['decoder_output_dim'],
+# lr_multiplier=options['decoder_lr_mul']))
+
+# def forward(self, sampled_features, ray_directions):
+# # Aggregate features
+# sampled_features = sampled_features.mean(1)
+# x = sampled_features
+
+# N, M, C = x.shape
+# x = x.view(N * M, C)
+
+# x = self.net(x)
+# x = x.view(N, M, -1)
+# rgb = torch.sigmoid(x[..., 1:]) * (
+# 1 + 2 * 0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+# sigma = x[..., 0:1]
+# return {'rgb': rgb, 'sigma': sigma}
+
+
+@persistence.persistent_class
+class OSGDecoder(torch.nn.Module):
+
+ def __init__(self, n_features, options):
+ super().__init__()
+ self.hidden_dim = 64
+ self.decoder_output_dim = options['decoder_output_dim']
+
+ self.net = torch.nn.Sequential(
+ FullyConnectedLayer(n_features,
+ self.hidden_dim,
+ lr_multiplier=options['decoder_lr_mul']),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(self.hidden_dim,
+ 1 + options['decoder_output_dim'],
+ lr_multiplier=options['decoder_lr_mul']))
+ self.activation = options.get('decoder_activation', 'sigmoid')
+
+ def forward(self, sampled_features, ray_directions):
+ # Aggregate features
+ sampled_features = sampled_features.mean(1)
+ x = sampled_features
+
+ N, M, C = x.shape
+ x = x.view(N * M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ rgb = x[..., 1:]
+ sigma = x[..., 0:1]
+ if self.activation == "sigmoid":
+ # Original EG3D
+ rgb = torch.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001
+ elif self.activation == "lrelu":
+ # StyleGAN2-style, use with toRGB
+ rgb = torch.nn.functional.leaky_relu(rgb, 0.2,
+ inplace=True) * math.sqrt(2)
+ return {'rgb': rgb, 'sigma': sigma}
+
+
+class LRMOSGDecoder(nn.Module):
+ """
+ Triplane decoder that gives RGB and sigma values from sampled features.
+ Using ReLU here instead of Softplus in the original implementation.
+
+ Reference:
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
+ """
+ def __init__(self, n_features: int,
+ hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
+ super().__init__()
+ self.decoder_output_dim = 3
+ self.net = nn.Sequential(
+ nn.Linear(3 * n_features, hidden_dim),
+ activation(),
+ *itertools.chain(*[[
+ nn.Linear(hidden_dim, hidden_dim),
+ activation(),
+ ] for _ in range(num_layers - 2)]),
+ nn.Linear(hidden_dim, 1 + self.decoder_output_dim),
+ )
+ # init all bias to zero
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.zeros_(m.bias)
+
+ def forward(self, sampled_features, ray_directions):
+ # Aggregate features by mean
+ # sampled_features = sampled_features.mean(1)
+ # Aggregate features by concatenation
+ _N, n_planes, _M, _C = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+ x = sampled_features
+
+ N, M, C = x.shape
+ x = x.contiguous().view(N*M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ sigma = x[..., 0:1]
+
+ return {'rgb': rgb, 'sigma': sigma}
+
+
+class Triplane(torch.nn.Module):
+
+ def __init__(
+ self,
+ c_dim=25, # Conditioning label (C) dimensionality.
+ img_resolution=128, # Output resolution.
+ img_channels=3, # Number of output color channels.
+ out_chans=96,
+ triplane_size=224,
+ rendering_kwargs={},
+ decoder_in_chans=32,
+ decoder_output_dim=32,
+ sr_num_fp16_res=0,
+ sr_kwargs={},
+ create_triplane=False, # for overfitting single instance study
+ bcg_synthesis_kwargs={},
+ lrm_decoder=False,
+ ):
+ super().__init__()
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution # TODO
+ self.img_channels = img_channels
+ self.triplane_size = triplane_size
+
+ self.decoder_in_chans = decoder_in_chans
+ self.out_chans = out_chans
+
+ self.renderer = ImportanceRenderer()
+
+ if 'PatchRaySampler' in rendering_kwargs:
+ self.ray_sampler = PatchRaySampler()
+ else:
+ self.ray_sampler = RaySampler()
+
+ if lrm_decoder:
+ self.decoder = LRMOSGDecoder(
+ decoder_in_chans,)
+ else:
+ self.decoder = OSGDecoder(
+ decoder_in_chans,
+ {
+ 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
+ # 'decoder_output_dim': 32
+ 'decoder_output_dim': decoder_output_dim
+ })
+
+ self.neural_rendering_resolution = img_resolution # TODO
+ # self.neural_rendering_resolution = 128 # TODO
+ self.rendering_kwargs = rendering_kwargs
+ self.create_triplane = create_triplane
+ if create_triplane:
+ self.planes = nn.Parameter(torch.randn(1, out_chans, 256, 256))
+
+ if bool(sr_kwargs): # check whether empty
+ assert decoder_in_chans == decoder_output_dim, 'tradition'
+ if rendering_kwargs['superresolution_module'] in [
+ 'utils.torch_utils.components.PixelUnshuffleUpsample',
+ 'utils.torch_utils.components.NearestConvSR',
+ 'utils.torch_utils.components.NearestConvSR_Residual'
+ ]:
+ self.superresolution = dnnlib.util.construct_class_by_name(
+ class_name=rendering_kwargs['superresolution_module'],
+ # * for PixelUnshuffleUpsample
+ sr_ratio=2, # 2x SR, 128 -> 256
+ output_dim=decoder_output_dim,
+ num_out_ch=3,
+ )
+ else:
+ self.superresolution = dnnlib.util.construct_class_by_name(
+ class_name=rendering_kwargs['superresolution_module'],
+ # * for stylegan upsample
+ channels=decoder_output_dim,
+ img_resolution=img_resolution,
+ sr_num_fp16_res=sr_num_fp16_res,
+ sr_antialias=rendering_kwargs['sr_antialias'],
+ **sr_kwargs)
+ else:
+ self.superresolution = None
+
+ self.bcg_synthesis = None
+
+ # * pure reconstruction
+ def forward(
+ self,
+ planes=None,
+ # img,
+ c=None,
+ ws=None,
+ ray_origins=None,
+ ray_directions=None,
+ z_bcg=None,
+ neural_rendering_resolution=None,
+ update_emas=False,
+ cache_backbone=False,
+ use_cached_backbone=False,
+ return_meta=False,
+ return_raw_only=False,
+ sample_ray_only=False,
+ fg_bbox=None,
+ **synthesis_kwargs):
+
+ cam2world_matrix = c[:, :16].reshape(-1, 4, 4)
+ # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0)
+ # c[:, :16] = cam2world_matrix.view(-1, 16)
+ intrinsics = c[:, 16:25].reshape(-1, 3, 3)
+
+ if neural_rendering_resolution is None:
+ neural_rendering_resolution = self.neural_rendering_resolution
+ else:
+ self.neural_rendering_resolution = neural_rendering_resolution
+
+ if ray_directions is None: # when output video
+ H = W = self.neural_rendering_resolution
+ # Create a batch of rays for volume rendering
+ # ray_origins, ray_directions, ray_bboxes = self.ray_sampler(
+ # cam2world_matrix, intrinsics, neural_rendering_resolution)
+
+ if sample_ray_only: # ! for sampling
+ ray_origins, ray_directions, ray_bboxes = self.ray_sampler(
+ cam2world_matrix, intrinsics,
+ self.rendering_kwargs.get( 'patch_rendering_resolution' ),
+ self.neural_rendering_resolution, fg_bbox)
+
+ # for patch supervision
+ ret_dict = {
+ 'ray_origins': ray_origins,
+ 'ray_directions': ray_directions,
+ 'ray_bboxes': ray_bboxes,
+ }
+
+ return ret_dict
+
+ else: # ! for rendering
+ ray_origins, ray_directions, _ = self.ray_sampler(
+ cam2world_matrix, intrinsics, self.neural_rendering_resolution,
+ self.neural_rendering_resolution)
+
+ else:
+ assert ray_origins is not None
+ H = W = int(ray_directions.shape[1]**
+ 0.5) # dynamically set patch resolution
+
+ # ! match the batch size, if not returned
+ if planes is None:
+ assert self.planes is not None
+ planes = self.planes.repeat_interleave(c.shape[0], dim=0)
+ return_sampling_details_flag = self.rendering_kwargs.get(
+ 'return_sampling_details_flag', False)
+
+ if return_sampling_details_flag:
+ return_meta = True
+
+ # Create triplanes by running StyleGAN backbone
+ N, M, _ = ray_origins.shape
+
+ # Reshape output into three 32-channel planes
+ if planes.shape[1] == 3 * 2 * self.decoder_in_chans:
+ # if isinstance(planes, tuple):
+ # N *= 2
+ triplane_bg = True
+ # planes = torch.cat(planes, 0) # inference in parallel
+ # ray_origins = ray_origins.repeat(2,1,1)
+ # ray_directions = ray_directions.repeat(2,1,1)
+
+ else:
+ triplane_bg = False
+
+ # assert not triplane_bg
+
+ # ! hard coded, will fix later
+ # if planes.shape[1] == 3 * self.decoder_in_chans:
+ # else:
+
+ # planes = planes.view(len(planes), 3, self.decoder_in_chans,
+ planes = planes.reshape(
+ len(planes),
+ 3,
+ -1, # ! support background plane
+ planes.shape[-2],
+ planes.shape[-1]) # BS 96 256 256
+
+ # Perform volume rendering
+ rendering_details = self.renderer(planes,
+ self.decoder,
+ ray_origins,
+ ray_directions,
+ self.rendering_kwargs,
+ return_meta=return_meta)
+
+ feature_samples, depth_samples, weights_samples = (
+ rendering_details[k]
+ for k in ['feature_samples', 'depth_samples', 'weights_samples'])
+
+ if return_sampling_details_flag:
+ shape_synthesized = rendering_details['shape_synthesized']
+ else:
+ shape_synthesized = None
+
+ # Reshape into 'raw' neural-rendered image
+ feature_image = feature_samples.permute(0, 2, 1).reshape(
+ N, feature_samples.shape[-1], H,
+ W).contiguous() # B 32 H W, in [-1,1]
+ depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+ weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+
+ # Generate Background
+ # if self.bcg_synthesis:
+
+ # # bg composition
+ # # if self.decoder.activation == "sigmoid":
+ # # feature_image = feature_image * 2 - 1 # Scale to (-1, 1), taken from ray marcher
+
+ # assert isinstance(
+ # z_bcg, torch.Tensor
+ # ) # 512 latents after reparmaterization, reuse the name
+ # # ws_bcg = ws[:,:self.bcg_synthesis.num_ws] if ws_bcg is None else ws_bcg[:,:self.bcg_synthesis.num_ws]
+
+ # with torch.autocast(device_type='cuda',
+ # dtype=torch.float16,
+ # enabled=False):
+
+ # ws_bcg = self.bcg_mapping(z_bcg, c=None) # reuse the name
+ # if ws_bcg.size(1) < self.bcg_synthesis.num_ws:
+ # ws_bcg = torch.cat([
+ # ws_bcg, ws_bcg[:, -1:].repeat(
+ # 1, self.bcg_synthesis.num_ws - ws_bcg.size(1), 1)
+ # ], 1)
+
+ # bcg_image = self.bcg_synthesis(ws_bcg,
+ # update_emas=update_emas,
+ # **synthesis_kwargs)
+ # bcg_image = torch.nn.functional.interpolate(
+ # bcg_image,
+ # size=feature_image.shape[2:],
+ # mode='bilinear',
+ # align_corners=False,
+ # antialias=self.rendering_kwargs['sr_antialias'])
+ # feature_image = feature_image + (1 - weights_samples) * bcg_image
+
+ # # Generate Raw image
+ # assert self.torgb
+ # rgb_image = self.torgb(feature_image,
+ # ws_bcg[:, -1],
+ # fused_modconv=False)
+ # rgb_image = rgb_image.to(dtype=torch.float32,
+ # memory_format=torch.contiguous_format)
+ # # st()
+ # else:
+
+ mask_image = weights_samples * (1 + 2 * 0.001) - 0.001
+ if triplane_bg:
+ # true_bs = N // 2
+ # weights_samples = weights_samples[:true_bs]
+ # mask_image = mask_image[:true_bs]
+ # feature_image = feature_image[:true_bs] * mask_image + feature_image[true_bs:] * (1-mask_image) # the first is foreground
+ # depth_image = depth_image[:true_bs]
+
+ # ! composited colors
+ # rgb_final = (
+ # 1 - fg_ret_dict['weights']
+ # ) * bg_ret_dict['rgb_final'] + fg_ret_dict[
+ # 'feature_samples'] # https://github.com/SizheAn/PanoHead/blob/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/training/triplane.py#L127C45-L127C64
+
+ # ret_dict.update({
+ # 'feature_samples': rgb_final,
+ # })
+ # st()
+ feature_image = (1 - mask_image) * rendering_details[
+ 'bg_ret_dict']['rgb_final'] + feature_image
+
+ rgb_image = feature_image[:, :3]
+
+ # # Run superresolution to get final image
+ if self.superresolution is not None and not return_raw_only:
+ # assert ws is not None, 'feed in [cls] token here for SR module'
+
+ if ws is not None and ws.ndim == 2:
+ ws = ws.unsqueeze(
+ 1)[:, -1:, :] # follow stylegan tradition, B, N, C
+
+ sr_image = self.superresolution(
+ rgb=rgb_image,
+ x=feature_image,
+ base_x=rgb_image,
+ ws=ws, # only use the last layer
+ noise_mode=self.
+ rendering_kwargs['superresolution_noise_mode'], # none
+ **{
+ k: synthesis_kwargs[k]
+ for k in synthesis_kwargs.keys() if k != 'noise_mode'
+ })
+ else:
+ # sr_image = rgb_image
+ sr_image = None
+
+ if shape_synthesized is not None:
+ shape_synthesized.update({
+ 'image_depth': depth_image,
+ }) # for 3D loss easy computation, wrap all 3D in a single dict
+
+ ret_dict = {
+ 'feature_image': feature_image,
+ # 'image_raw': feature_image[:, :3],
+ 'image_raw': rgb_image,
+ 'image_depth': depth_image,
+ 'weights_samples': weights_samples,
+ # 'silhouette': mask_image,
+ # 'silhouette_normalized_3channel': (mask_image*2-1).repeat_interleave(3,1), # N 3 H W
+ 'shape_synthesized': shape_synthesized,
+ "image_mask": mask_image,
+ }
+
+ if sr_image is not None:
+ ret_dict.update({
+ 'image_sr': sr_image,
+ })
+
+ if return_meta:
+ ret_dict.update({
+ 'feature_volume':
+ rendering_details['feature_volume'],
+ 'all_coords':
+ rendering_details['all_coords'],
+ 'weights':
+ rendering_details['weights'],
+ })
+
+ return ret_dict
+
+
+class Triplane_fg_bg_plane(Triplane):
+ # a separate background plane
+
+ def __init__(self,
+ c_dim=25,
+ img_resolution=128,
+ img_channels=3,
+ out_chans=96,
+ triplane_size=224,
+ rendering_kwargs={},
+ decoder_in_chans=32,
+ decoder_output_dim=32,
+ sr_num_fp16_res=0,
+ sr_kwargs={},
+ bcg_synthesis_kwargs={}):
+ super().__init__(c_dim, img_resolution, img_channels, out_chans,
+ triplane_size, rendering_kwargs, decoder_in_chans,
+ decoder_output_dim, sr_num_fp16_res, sr_kwargs,
+ bcg_synthesis_kwargs)
+
+ self.bcg_decoder = Decoder(
+ ch=64, # half channel size
+ out_ch=32,
+ # ch_mult=(1, 2, 4),
+ ch_mult=(1, 2), # use res=64 for now
+ num_res_blocks=2,
+ dropout=0.0,
+ attn_resolutions=(),
+ z_channels=4,
+ resolution=64,
+ in_channels=3,
+ )
+
+ # * pure reconstruction
+ def forward(
+ self,
+ planes,
+ bg_plane,
+ # img,
+ c,
+ ws=None,
+ z_bcg=None,
+ neural_rendering_resolution=None,
+ update_emas=False,
+ cache_backbone=False,
+ use_cached_backbone=False,
+ return_meta=False,
+ return_raw_only=False,
+ **synthesis_kwargs):
+
+ # ! match the batch size
+ if planes is None:
+ assert self.planes is not None
+ planes = self.planes.repeat_interleave(c.shape[0], dim=0)
+ return_sampling_details_flag = self.rendering_kwargs.get(
+ 'return_sampling_details_flag', False)
+
+ if return_sampling_details_flag:
+ return_meta = True
+
+ cam2world_matrix = c[:, :16].reshape(-1, 4, 4)
+ # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0)
+ # c[:, :16] = cam2world_matrix.view(-1, 16)
+ intrinsics = c[:, 16:25].reshape(-1, 3, 3)
+
+ if neural_rendering_resolution is None:
+ neural_rendering_resolution = self.neural_rendering_resolution
+ else:
+ self.neural_rendering_resolution = neural_rendering_resolution
+
+ H = W = self.neural_rendering_resolution
+ # Create a batch of rays for volume rendering
+ ray_origins, ray_directions, _ = self.ray_sampler(
+ cam2world_matrix, intrinsics, neural_rendering_resolution)
+
+ # Create triplanes by running StyleGAN backbone
+ N, M, _ = ray_origins.shape
+
+ # # Reshape output into three 32-channel planes
+ # if planes.shape[1] == 3 * 2 * self.decoder_in_chans:
+ # # if isinstance(planes, tuple):
+ # # N *= 2
+ # triplane_bg = True
+ # # planes = torch.cat(planes, 0) # inference in parallel
+ # # ray_origins = ray_origins.repeat(2,1,1)
+ # # ray_directions = ray_directions.repeat(2,1,1)
+
+ # else:
+ # triplane_bg = False
+
+ # assert not triplane_bg
+
+ planes = planes.view(
+ len(planes),
+ 3,
+ -1, # ! support background plane
+ planes.shape[-2],
+ planes.shape[-1]) # BS 96 256 256
+
+ # Perform volume rendering
+ rendering_details = self.renderer(planes,
+ self.decoder,
+ ray_origins,
+ ray_directions,
+ self.rendering_kwargs,
+ return_meta=return_meta)
+
+ feature_samples, depth_samples, weights_samples = (
+ rendering_details[k]
+ for k in ['feature_samples', 'depth_samples', 'weights_samples'])
+
+ if return_sampling_details_flag:
+ shape_synthesized = rendering_details['shape_synthesized']
+ else:
+ shape_synthesized = None
+
+ # Reshape into 'raw' neural-rendered image
+ feature_image = feature_samples.permute(0, 2, 1).reshape(
+ N, feature_samples.shape[-1], H,
+ W).contiguous() # B 32 H W, in [-1,1]
+ depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+ weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+
+ bcg_image = self.bcg_decoder(bg_plane)
+ bcg_image = torch.nn.functional.interpolate(
+ bcg_image,
+ size=feature_image.shape[2:],
+ mode='bilinear',
+ align_corners=False,
+ antialias=self.rendering_kwargs['sr_antialias'])
+
+ mask_image = weights_samples * (1 + 2 * 0.001) - 0.001
+
+ # ! fuse fg/bg model output
+ feature_image = feature_image + (1 - weights_samples) * bcg_image
+
+ rgb_image = feature_image[:, :3]
+
+ # # Run superresolution to get final image
+ if self.superresolution is not None and not return_raw_only:
+ # assert ws is not None, 'feed in [cls] token here for SR module'
+
+ if ws is not None and ws.ndim == 2:
+ ws = ws.unsqueeze(
+ 1)[:, -1:, :] # follow stylegan tradition, B, N, C
+
+ sr_image = self.superresolution(
+ rgb=rgb_image,
+ x=feature_image,
+ base_x=rgb_image,
+ ws=ws, # only use the last layer
+ noise_mode=self.
+ rendering_kwargs['superresolution_noise_mode'], # none
+ **{
+ k: synthesis_kwargs[k]
+ for k in synthesis_kwargs.keys() if k != 'noise_mode'
+ })
+ else:
+ # sr_image = rgb_image
+ sr_image = None
+
+ if shape_synthesized is not None:
+ shape_synthesized.update({
+ 'image_depth': depth_image,
+ }) # for 3D loss easy computation, wrap all 3D in a single dict
+
+ ret_dict = {
+ 'feature_image': feature_image,
+ # 'image_raw': feature_image[:, :3],
+ 'image_raw': rgb_image,
+ 'image_depth': depth_image,
+ 'weights_samples': weights_samples,
+ # 'silhouette': mask_image,
+ # 'silhouette_normalized_3channel': (mask_image*2-1).repeat_interleave(3,1), # N 3 H W
+ 'shape_synthesized': shape_synthesized,
+ "image_mask": mask_image,
+ }
+
+ if sr_image is not None:
+ ret_dict.update({
+ 'image_sr': sr_image,
+ })
+
+ if return_meta:
+ ret_dict.update({
+ 'feature_volume':
+ rendering_details['feature_volume'],
+ 'all_coords':
+ rendering_details['all_coords'],
+ 'weights':
+ rendering_details['weights'],
+ })
+
+ return ret_dict
diff --git a/nsr/volumetric_rendering/__init__.py b/nsr/volumetric_rendering/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..daba66567a95beabb103f7996198a9675ab20b4a
--- /dev/null
+++ b/nsr/volumetric_rendering/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
\ No newline at end of file
diff --git a/nsr/volumetric_rendering/__pycache__/__init__.cpython-39.pyc b/nsr/volumetric_rendering/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ec38e1ab0fd3eb3c77e1fd3a34691cd68594378
Binary files /dev/null and b/nsr/volumetric_rendering/__pycache__/__init__.cpython-39.pyc differ
diff --git a/nsr/volumetric_rendering/__pycache__/math_utils.cpython-39.pyc b/nsr/volumetric_rendering/__pycache__/math_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40ec3963bed16722d6e68330b79e2a8c42c4d072
Binary files /dev/null and b/nsr/volumetric_rendering/__pycache__/math_utils.cpython-39.pyc differ
diff --git a/nsr/volumetric_rendering/__pycache__/ray_marcher.cpython-39.pyc b/nsr/volumetric_rendering/__pycache__/ray_marcher.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02fe52ca23b7fd4d9ced4bcd683b7a09a4383a25
Binary files /dev/null and b/nsr/volumetric_rendering/__pycache__/ray_marcher.cpython-39.pyc differ
diff --git a/nsr/volumetric_rendering/__pycache__/ray_sampler.cpython-39.pyc b/nsr/volumetric_rendering/__pycache__/ray_sampler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b59b3209f77b531e3aa46b3c366f29aaccadaea
Binary files /dev/null and b/nsr/volumetric_rendering/__pycache__/ray_sampler.cpython-39.pyc differ
diff --git a/nsr/volumetric_rendering/__pycache__/renderer.cpython-39.pyc b/nsr/volumetric_rendering/__pycache__/renderer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1fe32ba07209ce4062b24f0b1e8d3b76dae18fb
Binary files /dev/null and b/nsr/volumetric_rendering/__pycache__/renderer.cpython-39.pyc differ
diff --git a/nsr/volumetric_rendering/math_utils.py b/nsr/volumetric_rendering/math_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..81378a371d959e0c0b5886f00925651c0caa2124
--- /dev/null
+++ b/nsr/volumetric_rendering/math_utils.py
@@ -0,0 +1,137 @@
+# MIT License
+
+# Copyright (c) 2022 Petr Kellnhofer
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+from pdb import set_trace as st
+
+
+def transform_vectors(matrix: torch.Tensor,
+ vectors4: torch.Tensor) -> torch.Tensor:
+ """
+ Left-multiplies MxM @ NxM. Returns NxM.
+ """
+ res = torch.matmul(vectors4, matrix.T)
+ return res
+
+
+def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize vector lengths.
+ """
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
+
+
+def torch_dot(x: torch.Tensor, y: torch.Tensor):
+ """
+ Dot product of two tensors.
+ """
+ return (x * y).sum(-1)
+
+
+def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor,
+ box_side_length):
+ """
+ Author: Petr Kellnhofer
+ Intersects rays with the [-1, 1] NDC volume.
+ Returns min and max distance of entry.
+ Returns -1 for no intersection.
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
+ """
+ o_shape = rays_o.shape
+ rays_o = rays_o.detach().reshape(-1, 3)
+ rays_d = rays_d.detach().reshape(-1, 3)
+
+ bb_min = [
+ -1 * (box_side_length / 2), -1 * (box_side_length / 2),
+ -1 * (box_side_length / 2)
+ ]
+ bb_max = [
+ 1 * (box_side_length / 2), 1 * (box_side_length / 2),
+ 1 * (box_side_length / 2)
+ ]
+ bounds = torch.tensor([bb_min, bb_max],
+ dtype=rays_o.dtype,
+ device=rays_o.device)
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
+
+ # Precompute inverse for stability.
+ invdir = 1 / rays_d
+ sign = (invdir < 0).long()
+
+ # Intersect with YZ plane.
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] -
+ rays_o[..., 0]) * invdir[..., 0]
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] -
+ rays_o[..., 0]) * invdir[..., 0]
+
+ # Intersect with XZ plane.
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] -
+ rays_o[..., 1]) * invdir[..., 1]
+ tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] -
+ rays_o[..., 1]) * invdir[..., 1]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tymin)
+ tmax = torch.min(tmax, tymax)
+
+ # Intersect with XY plane.
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] -
+ rays_o[..., 2]) * invdir[..., 2]
+ tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] -
+ rays_o[..., 2]) * invdir[..., 2]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tzmin)
+ tmax = torch.min(tmax, tzmax)
+
+ # Mark invalid.
+ tmin[torch.logical_not(is_valid)] = -1
+ tmax[torch.logical_not(is_valid)] = -2
+
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
+
+
+def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
+ """
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
+ Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
+ """
+ # create a tensor of 'num' steps from 0 to 1
+ steps = torch.arange(num, dtype=torch.float32,
+ device=start.device) / (num - 1)
+
+ # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
+ # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
+ # "cannot statically infer the expected size of a list in this contex", hence the code below
+ for i in range(start.ndim):
+ steps = steps.unsqueeze(-1)
+
+ # the output starts at 'start' and increments until 'stop' in each dimension
+ out = start[None] + steps * (stop - start)[None]
+
+ return out
diff --git a/nsr/volumetric_rendering/ray_marcher.py b/nsr/volumetric_rendering/ray_marcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..62e4f0cf06a6e7d2256c810a9b83a58f6a2846ce
--- /dev/null
+++ b/nsr/volumetric_rendering/ray_marcher.py
@@ -0,0 +1,74 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""
+The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
+Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pdb import set_trace as st
+
+
+class MipRayMarcher2(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def run_forward(self, colors, densities, depths, rendering_options):
+ deltas = depths[:, :, 1:] - depths[:, :, :-1]
+ colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
+ densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
+ depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
+
+ if rendering_options['clamp_mode'] == 'softplus':
+ densities_mid = F.softplus(
+ densities_mid -
+ 1) # activation bias of -1 makes things initialize better
+ else:
+ assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!"
+
+ density_delta = densities_mid * deltas
+
+ alpha = 1 - torch.exp(-density_delta)
+
+ alpha_shifted = torch.cat(
+ [torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
+ T = torch.cumprod(alpha_shifted, -2) # transmittance
+ weights = alpha * T[:, :, :-1]
+ visibility = T[:, :,
+ -1] # bg lambda, https://github.com/Kai-46/nerfplusplus/blob/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/ddp_model.py#L101
+ # st()
+
+ composite_rgb = torch.sum(weights * colors_mid, -2)
+ weight_total = weights.sum(2)
+ # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
+ composite_depth = torch.sum(
+ weights * depths_mid,
+ -2) # shapenet white background, no need this.
+
+ # clip the composite to min/max range of depths
+ composite_depth = torch.nan_to_num(composite_depth, float('inf'))
+ composite_depth = torch.clamp(composite_depth, torch.min(depths),
+ torch.max(depths))
+
+ if rendering_options.get('white_back', True):
+ composite_rgb = composite_rgb + 1 - weight_total
+
+ composite_rgb = composite_rgb * 2 - 1 # Scale (0,1) to (-1, 1)
+
+ return composite_rgb, composite_depth, visibility, weights
+
+ def forward(self, colors, densities, depths, rendering_options):
+ composite_rgb, composite_depth, visibility, weights = self.run_forward(
+ colors, densities, depths, rendering_options)
+
+ return composite_rgb, composite_depth, visibility, weights
diff --git a/nsr/volumetric_rendering/ray_sampler.py b/nsr/volumetric_rendering/ray_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b28b17c3bc024f23c519d5c1945542766ddfa0eb
--- /dev/null
+++ b/nsr/volumetric_rendering/ray_sampler.py
@@ -0,0 +1,331 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""
+The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
+Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
+"""
+
+import torch
+from pdb import set_trace as st
+import random
+
+HUGE_NUMBER = 1e10
+TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision
+
+
+######################################################################################
+# wrapper to simplify the use of nerfnet
+######################################################################################
+# https://github.com/Kai-46/nerfplusplus/blob/ebf2f3e75fd6c5dfc8c9d0b533800daaf17bd95f/ddp_model.py#L16
+def depth2pts_outside(ray_o, ray_d, depth):
+ '''
+ ray_o, ray_d: [..., 3]
+ depth: [...]; inverse of distance to sphere origin
+ '''
+ # note: d1 becomes negative if this mid point is behind camera
+ d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
+ p_mid = ray_o + d1.unsqueeze(-1) * ray_d
+ p_mid_norm = torch.norm(p_mid, dim=-1)
+ ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
+ d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos
+ p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d
+
+ rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
+ rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
+ phi = torch.asin(p_mid_norm)
+ theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]
+ rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]
+
+ # now rotate p_sphere
+ # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
+ p_sphere_new = p_sphere * torch.cos(rot_angle) + \
+ torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
+ rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))
+ p_sphere_new = p_sphere_new / torch.norm(
+ p_sphere_new, dim=-1, keepdim=True)
+ pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)
+
+ # now calculate conventional depth
+ depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1
+ return pts, depth_real
+
+
+class RaySampler(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
+
+ def create_patch_uv(self,
+ patch_resolution,
+ resolution,
+ cam2world_matrix,
+ fg_bbox=None):
+
+ def sample_patch_uv(fg_bbox=None):
+ assert patch_resolution <= resolution
+
+ def sample_patch_range():
+ patch_reolution_start = random.randint(
+ 0, resolution -
+ patch_resolution) # alias for randrange(start, stop+1)
+ # patch_reolution_end = patch_reolution_start + patch_resolution
+ return patch_reolution_start # , patch_reolution_end
+
+ def sample_patch_range_oversample_boundary(range_start=None,
+ range_end=None):
+ # left down corner undersampled
+ if range_start is None:
+ # range_start = patch_resolution // 2
+ range_start = patch_resolution
+ if range_end is None:
+ # range_end = resolution + patch_resolution // 2
+ range_end = resolution + patch_resolution
+
+ # oversample the boundary
+ patch_reolution_end = random.randint(
+ range_start,
+ range_end,
+ )
+
+ # clip range
+ if patch_reolution_end <= patch_resolution:
+ patch_reolution_end = patch_resolution
+ elif patch_reolution_end > resolution:
+ patch_reolution_end = resolution
+
+ # patch_reolution_end = patch_reolution_start + patch_resolution
+ return patch_reolution_end # , patch_reolution_end
+
+ # h_start = sample_patch_range()
+ # assert fg_bbox is not None
+ if fg_bbox is not None and random.random(
+ ) > 0.125: # only train foreground. Has 0.1 prob to sample/train background.
+ # if fg_bbox is not None: # only train foreground. Has 0.1 prob to sample/train background.
+ # only return one UV here
+ top_min, left_min = fg_bbox[:, :2].min(dim=0,
+ keepdim=True)[0][0]
+ height_max, width_max = fg_bbox[:, 2:].max(dim=0,
+ keepdim=True)[0][0]
+
+ if top_min + patch_resolution < height_max:
+ h_end = sample_patch_range_oversample_boundary(
+ top_min + patch_resolution, height_max)
+ else:
+ h_end = max(
+ height_max.to(torch.uint8).item(), patch_resolution)
+ if left_min + patch_resolution < width_max:
+ w_end = sample_patch_range_oversample_boundary(
+ left_min + patch_resolution, width_max)
+ else:
+ w_end = max(
+ width_max.to(torch.uint8).item(), patch_resolution)
+
+ h_start = h_end - patch_resolution
+ w_start = w_end - patch_resolution
+
+ try:
+ assert h_start >= 0 and w_start >= 0
+ except:
+ st()
+
+ else:
+ h_end = sample_patch_range_oversample_boundary()
+ h_start = h_end - patch_resolution
+ w_end = sample_patch_range_oversample_boundary()
+ w_start = w_end - patch_resolution
+
+ assert h_start >= 0 and w_start >= 0
+
+ uv = torch.stack(
+ torch.meshgrid(
+ torch.arange(
+ start=h_start,
+ # end=h_start+patch_resolution,
+ end=h_end,
+ dtype=torch.float32,
+ device=cam2world_matrix.device),
+ torch.arange(
+ start=w_start,
+ # end=w_start + patch_resolution,
+ end=w_end,
+ dtype=torch.float32,
+ device=cam2world_matrix.device),
+ indexing='ij')) * (1. / resolution) + (0.5 / resolution)
+
+ uv = uv.flip(0).reshape(2, -1).transpose(1, 0) # ij -> xy
+
+ return uv, (h_start, w_start, patch_resolution, patch_resolution
+ ) # top: int, left: int, height: int, width: int
+
+ all_uv = []
+ ray_bboxes = []
+ for _ in range(cam2world_matrix.shape[0]):
+ uv, bbox = sample_patch_uv(fg_bbox)
+ all_uv.append(uv)
+ ray_bboxes.append(bbox)
+
+ all_uv = torch.stack(all_uv, 0) # B patch_res**2 2
+ # ray_bboxes = torch.stack(ray_bboxes, 0) # B patch_res**2 2
+
+ return all_uv, ray_bboxes
+
+ def create_uv(self, resolution, cam2world_matrix):
+
+ uv = torch.stack(
+ torch.meshgrid(torch.arange(resolution,
+ dtype=torch.float32,
+ device=cam2world_matrix.device),
+ torch.arange(resolution,
+ dtype=torch.float32,
+ device=cam2world_matrix.device),
+ indexing='ij')) * (1. / resolution) + (0.5 /
+ resolution)
+
+ uv = uv.flip(0).reshape(2, -1).transpose(1, 0) # why
+ uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
+
+ return uv
+
+ def forward(self, cam2world_matrix, intrinsics, resolution, fg_mask=None):
+ """
+ Create batches of rays and return origins and directions.
+
+ cam2world_matrix: (N, 4, 4)
+ intrinsics: (N, 3, 3)
+ resolution: int
+
+ ray_origins: (N, M, 3)
+ ray_dirs: (N, M, 2)
+ """
+ N, M = cam2world_matrix.shape[0], resolution**2
+ cam_locs_world = cam2world_matrix[:, :3, 3]
+ fx = intrinsics[:, 0, 0]
+ fy = intrinsics[:, 1, 1]
+ cx = intrinsics[:, 0, 2]
+ cy = intrinsics[:, 1, 2]
+ sk = intrinsics[:, 0, 1]
+
+ # uv = torch.stack(
+ # torch.meshgrid(torch.arange(resolution,
+ # dtype=torch.float32,
+ # device=cam2world_matrix.device),
+ # torch.arange(resolution,
+ # dtype=torch.float32,
+ # device=cam2world_matrix.device),
+ # indexing='ij')) * (1. / resolution) + (0.5 /
+ # resolution)
+ # uv = uv.flip(0).reshape(2, -1).transpose(1, 0) # why
+ # uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
+ uv = self.create_uv(
+ resolution,
+ cam2world_matrix,
+ )
+
+ x_cam = uv[:, :, 0].view(N, -1)
+ y_cam = uv[:, :, 1].view(N, -1) # [0,1] range
+ z_cam = torch.ones((N, M), device=cam2world_matrix.device)
+
+ # basically torch.inverse(intrinsics)
+ x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1) *
+ sk.unsqueeze(-1) / fy.unsqueeze(-1) - sk.unsqueeze(-1) *
+ y_cam / fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
+ y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
+
+ cam_rel_points = torch.stack(
+ (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
+
+ # st()
+
+ world_rel_points = torch.bmm(cam2world_matrix,
+ cam_rel_points.permute(0, 2, 1)).permute(
+ 0, 2, 1)[:, :, :3]
+
+ ray_dirs = world_rel_points - cam_locs_world[:, None, :]
+ ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
+
+ ray_origins = cam_locs_world.unsqueeze(1).repeat(
+ 1, ray_dirs.shape[1], 1)
+
+ return ray_origins, ray_dirs, None
+
+
+class PatchRaySampler(RaySampler):
+
+ def forward(self,
+ cam2world_matrix,
+ intrinsics,
+ patch_resolution,
+ resolution,
+ fg_bbox=None):
+ """
+ Create batches of rays and return origins and directions.
+
+ cam2world_matrix: (N, 4, 4)
+ intrinsics: (N, 3, 3)
+ resolution: int
+
+ ray_origins: (N, M, 3)
+ ray_dirs: (N, M, 2)
+ """
+ N, M = cam2world_matrix.shape[0], patch_resolution**2
+ cam_locs_world = cam2world_matrix[:, :3, 3]
+ fx = intrinsics[:, 0, 0]
+ fy = intrinsics[:, 1, 1]
+ cx = intrinsics[:, 0, 2]
+ cy = intrinsics[:, 1, 2]
+ sk = intrinsics[:, 0, 1]
+
+ # uv = self.create_uv(resolution, cam2world_matrix)
+
+ # all_uv, ray_bboxes = self.create_patch_uv(
+ all_uv_list = []
+ ray_bboxes = []
+ for idx in range(N):
+ uv, bboxes = self.create_patch_uv(
+ patch_resolution, resolution, cam2world_matrix[idx:idx + 1],
+ fg_bbox[idx:idx + 1]
+ if fg_bbox is not None else None) # for debugging, hard coded
+ all_uv_list.append(
+ uv
+ # cam2world_matrix[idx:idx+1], )[0] # for debugging, hard coded
+ )
+ ray_bboxes.extend(bboxes)
+ all_uv = torch.cat(all_uv_list, 0)
+ # ray_bboxes = torch.cat(ray_bboxes_list, 0)
+ # all_uv, _ = self.create_patch_uv(
+ # patch_resolution, resolution,
+ # cam2world_matrix, fg_bbox) # for debugging, hard coded
+ # st()
+
+ x_cam = all_uv[:, :, 0].view(N, -1)
+ y_cam = all_uv[:, :, 1].view(N, -1) # [0,1] range
+ z_cam = torch.ones((N, M), device=cam2world_matrix.device)
+
+ # basically torch.inverse(intrinsics)
+ x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1) *
+ sk.unsqueeze(-1) / fy.unsqueeze(-1) - sk.unsqueeze(-1) *
+ y_cam / fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
+ y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
+
+ cam_rel_points = torch.stack(
+ (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
+
+ world_rel_points = torch.bmm(cam2world_matrix,
+ cam_rel_points.permute(0, 2, 1)).permute(
+ 0, 2, 1)[:, :, :3]
+
+ ray_dirs = world_rel_points - cam_locs_world[:, None, :]
+ ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
+
+ ray_origins = cam_locs_world.unsqueeze(1).repeat(
+ 1, ray_dirs.shape[1], 1)
+
+ return ray_origins, ray_dirs, ray_bboxes
diff --git a/nsr/volumetric_rendering/renderer.py b/nsr/volumetric_rendering/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cee75f5c034af01dcc916c2681f4f7a06cd43b6f
--- /dev/null
+++ b/nsr/volumetric_rendering/renderer.py
@@ -0,0 +1,637 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""
+The renderer is a module that takes in rays, decides where to sample along each
+ray, and computes pixel colors using the volume rendering equation.
+"""
+
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+from .ray_marcher import MipRayMarcher2
+from . import math_utils
+from pdb import set_trace as st
+from .ray_sampler import depth2pts_outside, HUGE_NUMBER, TINY_NUMBER
+
+
+def generate_planes():
+ """
+ Defines planes by the three vectors that form the "axes" of the
+ plane. Should work with arbitrary number of planes and planes of
+ arbitrary orientation.
+ """
+ return torch.tensor(
+ [[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
+ [[0, 0, 1], [1, 0, 0], [0, 1, 0]]],
+ dtype=torch.float32)
+
+
+# def project_onto_planes(planes, coordinates):
+# """
+# Does a projection of a 3D point onto a batch of 2D planes,
+# returning 2D plane coordinates.
+
+# Takes plane axes of shape n_planes, 3, 3
+# # Takes coordinates of shape N, M, 3
+# # returns projections of shape N*n_planes, M, 2
+# """
+# N, M, C = coordinates.shape
+# n_planes, _, _ = planes.shape
+# coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
+# inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
+# projections = torch.bmm(coordinates, inv_planes)
+# return projections[..., :2]
+
+
+def project_onto_planes(planes, coordinates):
+ """
+ Does a projection of a 3D point onto a batch of 2D planes,
+ returning 2D plane coordinates.
+
+ Takes plane axes of shape n_planes, 3, 3
+ # Takes coordinates of shape N, M, 3
+ # returns projections of shape N*n_planes, M, 2
+ """
+
+ # # ORIGINAL
+ # N, M, C = coordinates.shape
+ # xy_coords = coordinates[..., [0, 1]]
+ # xz_coords = coordinates[..., [0, 2]]
+ # zx_coords = coordinates[..., [2, 0]]
+ # return torch.stack([xy_coords, xz_coords, zx_coords], dim=1).reshape(N*3, M, 2)
+
+ # FIXED
+ N, M, _ = coordinates.shape
+ xy_coords = coordinates[..., [0, 1]]
+ yz_coords = coordinates[..., [1, 2]]
+ zx_coords = coordinates[..., [2, 0]]
+ return torch.stack([xy_coords, yz_coords, zx_coords],
+ dim=1).reshape(N * 3, M, 2)
+
+
+def sample_from_planes(plane_axes,
+ plane_features,
+ coordinates,
+ mode='bilinear',
+ padding_mode='zeros',
+ box_warp=None):
+ assert padding_mode == 'zeros'
+ N, n_planes, C, H, W = plane_features.shape
+ _, M, _ = coordinates.shape
+ # st()
+ plane_features = plane_features.view(N * n_planes, C, H, W)
+ # plane_features = plane_features.reshape(N * n_planes, C, H, W)
+
+ coordinates = (2 / box_warp) * coordinates # TODO: add specific box bounds
+
+ projected_coordinates = project_onto_planes(plane_axes,
+ coordinates).unsqueeze(1)
+ output_features = torch.nn.functional.grid_sample(
+ plane_features,
+ projected_coordinates.float(),
+ mode=mode,
+ padding_mode=padding_mode,
+ align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+ return output_features
+
+
+def sample_from_3dgrid(grid, coordinates):
+ """
+ Expects coordinates in shape (batch_size, num_points_per_batch, 3)
+ Expects grid in shape (1, channels, H, W, D)
+ (Also works if grid has batch size)
+ Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
+ """
+ batch_size, n_coords, n_dims = coordinates.shape
+ sampled_features = torch.nn.functional.grid_sample(
+ grid.expand(batch_size, -1, -1, -1, -1),
+ coordinates.reshape(batch_size, 1, 1, -1, n_dims),
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=False)
+ N, C, H, W, D = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 4, 3, 2,
+ 1).reshape(N, H * W * D, C)
+ return sampled_features
+
+
+class ImportanceRenderer(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.ray_marcher = MipRayMarcher2()
+ self.plane_axes = generate_planes()
+
+ def forward(self,
+ planes,
+ decoder,
+ ray_origins,
+ ray_directions,
+ rendering_options,
+ return_meta=False):
+ # return_sampling_details_flag=False):
+ self.plane_axes = self.plane_axes.to(ray_origins.device)
+ # if rendering_options.get('return_sampling_details_flag', None) is not None:
+ shape_synthesized = {}
+
+ if rendering_options['ray_start'] == rendering_options[
+ 'ray_end'] == 'auto':
+ ray_start, ray_end = math_utils.get_ray_limits_box(
+ ray_origins,
+ ray_directions,
+ box_side_length=rendering_options['box_warp'])
+ is_ray_valid = ray_end > ray_start
+ # st()
+ if torch.any(is_ray_valid).item():
+ ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
+ ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
+ depths_coarse = self.sample_stratified(
+ ray_origins, ray_start, ray_end,
+ rendering_options['depth_resolution'],
+ rendering_options['disparity_space_sampling'])
+ else:
+ # Create stratified depth samples
+ depths_coarse = self.sample_stratified(
+ ray_origins, rendering_options['ray_start'],
+ rendering_options['ray_end'],
+ rendering_options['depth_resolution'],
+ rendering_options['disparity_space_sampling'])
+
+ batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape
+
+ # Coarse Pass
+ sample_coordinates = (
+ ray_origins.unsqueeze(-2) +
+ depths_coarse * ray_directions.unsqueeze(-2)).reshape(
+ batch_size, -1, 3)
+ # st() # np.save('sample_coordinates.npy', sample_coordinates.detach().cpu().numpy())
+ sample_directions = ray_directions.unsqueeze(-2).expand(
+ -1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
+
+ colors_coarse, densities_coarse = self.run_model(
+ planes, decoder, sample_coordinates, sample_directions,
+ rendering_options, batch_size, num_rays, samples_per_ray)
+
+ colors_coarse = colors_coarse.reshape(batch_size, num_rays,
+ samples_per_ray,
+ colors_coarse.shape[-1])
+ densities_coarse = densities_coarse.reshape(batch_size, num_rays,
+ samples_per_ray, 1)
+
+ if rendering_options.get('return_sampling_details_flag', False):
+ shape_synthesized.update({
+ # 'coarse_coords': sample_coordinates.detach().clone(),
+ # 'coarse_densities': densities_coarse.detach()
+ 'coarse_coords':
+ sample_coordinates.reshape(batch_size, num_rays,
+ samples_per_ray, 3),
+ 'coarse_densities':
+ densities_coarse
+ })
+
+ # Fine Pass
+ N_importance = rendering_options['depth_resolution_importance']
+ if N_importance > 0:
+ _, _, _, weights = self.ray_marcher(colors_coarse,
+ densities_coarse,
+ depths_coarse,
+ rendering_options)
+
+ depths_fine = self.sample_importance(depths_coarse, weights,
+ N_importance)
+
+ sample_directions = ray_directions.unsqueeze(-2).expand(
+ -1, -1, N_importance, -1).reshape(batch_size, -1, 3)
+ sample_coordinates = (
+ ray_origins.unsqueeze(-2) +
+ depths_fine * ray_directions.unsqueeze(-2)).reshape(
+ batch_size, -1, 3)
+
+ colors_fine, densities_fine = self.run_model(
+ planes, decoder, sample_coordinates, sample_directions,
+ rendering_options, batch_size, num_rays, N_importance)
+ # colors_fine = out['rgb']
+ # densities_fine = out['sigma']
+ colors_fine = colors_fine.reshape(batch_size, num_rays,
+ N_importance,
+ colors_fine.shape[-1])
+ densities_fine = densities_fine.reshape(batch_size, num_rays,
+ N_importance, 1)
+ if rendering_options.get('return_sampling_details_flag', False):
+ shape_synthesized.update({
+ # 'fine_coords': sample_coordinates.detach(),
+ # 'fine_densities': densities_fine.detach()
+ 'fine_coords': sample_coordinates,
+ # 'fine_coords': sample_coordinates.reshape(batch_size, num_rays, N_importance, 3),
+ 'fine_densities': densities_fine,
+ })
+
+ all_depths, all_colors, all_densities, indices = self.unify_samples(
+ depths_coarse, colors_coarse, densities_coarse, depths_fine,
+ colors_fine, densities_fine)
+
+ # Aggregate
+ rgb_final, depth_final, visibility, weights = self.ray_marcher(
+ all_colors, all_densities, all_depths, rendering_options)
+
+ else:
+ rgb_final, depth_final, visibility, weights = self.ray_marcher(
+ colors_coarse, densities_coarse, depths_coarse,
+ rendering_options)
+
+ if rendering_options.get('return_surface', False):
+ weight_total = weights.sum(2)
+
+ all_coords = torch.cat([
+ shape_synthesized['coarse_coords'],
+ shape_synthesized['fine_coords']
+ ],
+ dim=-2) # B 4096 48+48 3
+ all_coords = torch.gather(all_coords, -2,
+ indices.expand(-1, -1, -1, 3))
+
+ composite_surface = torch.sum(weights * all_coords,
+ -2) / weight_total
+
+ # clip the composite to min/max range of depths
+ composite_surface = torch.nan_to_num(composite_surface,
+ float('inf'))
+ composite_surface = torch.clamp(composite_surface,
+ torch.min(all_coords),
+ torch.max(all_coords))
+ shape_synthesized['surface_coords'] = composite_surface
+
+ shape_synthesized.update({
+ # 'depth': depth_final.detach()
+ 'depth': depth_final
+ })
+
+ ret_dict = {
+ 'feature_samples': rgb_final,
+ 'depth_samples': depth_final,
+ 'weights_samples': weights.sum(2),
+ 'shape_synthesized': shape_synthesized,
+ 'visibility': visibility # T[..., -1]
+ }
+
+ if return_meta: # for pifu
+ all_coords = torch.cat([
+ shape_synthesized['coarse_coords'],
+ shape_synthesized['fine_coords'].reshape(
+ batch_size, num_rays, N_importance, 3)
+ ],
+ dim=-2)
+ # 'fine_coords': sample_coordinates,
+ all_coords = torch.gather(all_coords, -2,
+ indices.expand(-1, -1, -1, 3))
+
+ ret_dict.update({
+ 'all_coords': all_coords,
+ 'feature_volume': all_colors,
+ 'weights': weights
+ })
+
+ if rendering_options.get('return_sampling_details_flag', False):
+ ret_dict.update({'shape_synthesized': shape_synthesized})
+ # return rgb_final, depth_final, weights.sum(2), shape_synthesized # rgb_final, B, 4096, 32
+
+ # return rgb_final, depth_final, weights.sum(2)
+ return ret_dict
+
+ # old run_model
+ def _run_model(self, planes, decoder, sample_coordinates,
+ sample_directions, options):
+ sampled_features = sample_from_planes(self.plane_axes,
+ planes,
+ sample_coordinates,
+ padding_mode='zeros',
+ box_warp=options['box_warp'])
+
+ out = decoder(sampled_features, sample_directions)
+ if options.get('density_noise', 0) > 0:
+ out['sigma'] += torch.randn_like(
+ out['sigma']) * options['density_noise']
+ return out
+
+ def run_model(self, planes, decoder, sample_coordinates, sample_directions,
+ rendering_options, batch_size, num_rays, samples_per_ray):
+ """ a compat wrapper for Objaverse (bbox-sampling) and FFHQ/Shapenet-based rendering (ray-start/end sampling).
+
+ returns color and density
+ """
+
+ if rendering_options.get('filter_out_of_bbox', False):
+ # Coarse Pass
+ colors, densities = self._forward_pass(
+ # depths=depths_coarse,
+ # ray_directions=ray_directions,
+ # ray_origins=ray_origins,
+ sample_coordinates,
+ sample_directions,
+ planes=planes,
+ decoder=decoder,
+ rendering_options=rendering_options,
+ batch_size=batch_size,
+ num_rays=num_rays,
+ samples_per_ray=samples_per_ray,
+ )
+ else:
+ out = self._run_model(planes, decoder, sample_coordinates,
+ sample_directions, rendering_options)
+ colors = out['rgb']
+ densities = out['sigma']
+
+ return colors, densities
+
+ def _forward_pass(
+ self,
+ sample_coordinates,
+ sample_directions,
+ # depths: torch.Tensor,
+ # ray_directions: torch.Tensor,
+ # ray_origins: torch.Tensor,
+ planes: torch.Tensor,
+ decoder: nn.Module,
+ rendering_options: dict,
+ batch_size,
+ num_rays,
+ samples_per_ray):
+ """
+ Additional filtering is applied to filter out-of-box samples.
+ Modifications made by Zexin He.
+ """
+
+ # context related variables
+ # batch_size, num_rays, samples_per_ray, _ = depths.shape
+ device = sample_coordinates.device
+
+ # define sample points with depths
+ # sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
+ # sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+
+ # filter out-of-box samples
+ mask_inbox = \
+ (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
+ (sample_coordinates <= rendering_options['sampler_bbox_max'])
+ mask_inbox = mask_inbox.all(-1) # np.save('box.npy', mask_inbox.detach().cpu().numpy())
+
+ # forward model according to all samples
+ _out = self._run_model(planes, decoder, sample_coordinates,
+ sample_directions, rendering_options)
+
+ # set out-of-box samples to zeros(rgb) & -inf(sigma)
+ SAFE_GUARD = 3
+ DATA_TYPE = _out['sigma'].dtype
+ colors_pass = torch.zeros(batch_size,
+ num_rays * samples_per_ray,
+ # 3,
+ decoder.decoder_output_dim,
+ device=device,
+ dtype=DATA_TYPE)
+ densities_pass = torch.nan_to_num(
+ torch.full((batch_size, num_rays * samples_per_ray, 1),
+ -float('inf'),
+ device=device,
+ dtype=DATA_TYPE)) / SAFE_GUARD
+ colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][
+ mask_inbox], _out['sigma'][mask_inbox]
+
+ # reshape back
+ # colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
+ # densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
+
+ return colors_pass, densities_pass
+
+ def sort_samples(self, all_depths, all_colors, all_densities):
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(
+ all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2,
+ indices.expand(-1, -1, -1, 1))
+ return all_depths, all_colors, all_densities
+
+ def unify_samples(self, depths1, colors1, densities1, depths2, colors2,
+ densities2):
+ all_depths = torch.cat([depths1, depths2], dim=-2)
+ all_colors = torch.cat([colors1, colors2], dim=-2)
+ all_densities = torch.cat([densities1, densities2], dim=-2)
+
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(
+ all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2,
+ indices.expand(-1, -1, -1, 1))
+
+ return all_depths, all_colors, all_densities, indices
+
+ def sample_stratified(self,
+ ray_origins,
+ ray_start,
+ ray_end,
+ depth_resolution,
+ disparity_space_sampling=False):
+ """
+ Return depths of approximately uniformly spaced samples along rays.
+ """
+ N, M, _ = ray_origins.shape
+ if disparity_space_sampling:
+ depths_coarse = torch.linspace(0,
+ 1,
+ depth_resolution,
+ device=ray_origins.device).reshape(
+ 1, 1, depth_resolution,
+ 1).repeat(N, M, 1, 1)
+ depth_delta = 1 / (depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+ depths_coarse = 1. / (1. / ray_start * (1. - depths_coarse) +
+ 1. / ray_end * depths_coarse)
+ else:
+ if type(ray_start) == torch.Tensor:
+ depths_coarse = math_utils.linspace(ray_start, ray_end,
+ depth_resolution).permute(
+ 1, 2, 0, 3)
+ depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta[
+ ..., None]
+ else:
+ depths_coarse = torch.linspace(
+ ray_start,
+ ray_end,
+ depth_resolution,
+ device=ray_origins.device).reshape(1, 1, depth_resolution,
+ 1).repeat(N, M, 1, 1)
+ depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+ # print("ignore normal noise!!! for debugging")
+
+ return depths_coarse
+
+ def sample_importance(self, z_vals, weights, N_importance):
+ """
+ Return depths of importance sampled points along rays. See NeRF importance sampling for more.
+ """
+ with torch.no_grad():
+ batch_size, num_rays, samples_per_ray, _ = z_vals.shape
+
+ z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
+ weights = weights.reshape(
+ batch_size * num_rays,
+ -1) # -1 to account for loss of 1 sample in MipRayMarcher
+
+ # smooth weights
+ weights = torch.nn.functional.max_pool1d(
+ weights.unsqueeze(1).float(), 2, 1, padding=1)
+ weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
+ weights = weights + 0.01
+
+ z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
+ importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
+ N_importance).detach().reshape(
+ batch_size, num_rays,
+ N_importance, 1)
+ return importance_z_vals
+
+ def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
+ """
+ Sample @N_importance samples from @bins with distribution defined by @weights.
+ Inputs:
+ bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
+ weights: (N_rays, N_samples_)
+ N_importance: the number of samples to draw from the distribution
+ det: deterministic or not
+ eps: a small number to prevent division by zero
+ Outputs:
+ samples: the sampled samples
+ """
+ N_rays, N_samples_ = weights.shape
+ weights = weights + eps # prevent division by zero (don't do inplace op!)
+ pdf = weights / torch.sum(weights, -1,
+ keepdim=True) # (N_rays, N_samples_)
+ cdf = torch.cumsum(
+ pdf, -1) # (N_rays, N_samples), cumulative distribution function
+ cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf],
+ -1) # (N_rays, N_samples_+1)
+ # padded to 0~1 inclusive
+
+ if det:
+ u = torch.linspace(0, 1, N_importance, device=bins.device)
+ u = u.expand(N_rays, N_importance)
+ else:
+ u = torch.rand(N_rays, N_importance, device=bins.device)
+ u = u.contiguous()
+
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.clamp_min(inds - 1, 0)
+ above = torch.clamp_max(inds, N_samples_)
+
+ inds_sampled = torch.stack([below, above],
+ -1).view(N_rays, 2 * N_importance)
+ cdf_g = torch.gather(cdf, 1,
+ inds_sampled).view(N_rays, N_importance, 2)
+ bins_g = torch.gather(bins, 1,
+ inds_sampled).view(N_rays, N_importance, 2)
+
+ denom = cdf_g[..., 1] - cdf_g[..., 0]
+ denom[
+ denom <
+ eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
+ # anyway, therefore any value for it is fine (set to 1 here)
+
+ samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (
+ bins_g[..., 1] - bins_g[..., 0])
+ return samples
+
+
+class ImportanceRendererfg_bg(ImportanceRenderer):
+ """
+ render foreground-background together, using nerfpp strategy.
+ """
+ def __init__(self):
+ super().__init__()
+
+ def forward_background(self, bg_planes, decoder, ray_origins,
+ ray_directions, rendering_options):
+ # ! no importance sampling here.
+
+ # # background depth
+ depths_coarse = self.sample_stratified(
+ ray_origins, 0, 1, rendering_options['bg_depth_resolution'],
+ rendering_options['disparity_space_sampling']).squeeze(
+ -1) # remove the last 1 dim, B N S here
+
+ batch_size, num_rays, samples_per_ray = depths_coarse.shape
+
+ sample_directions = ray_directions.unsqueeze(-2).expand(
+ -1, -1, samples_per_ray, -1)
+ sample_origins = ray_origins.unsqueeze(-2).expand(
+ -1, -1, samples_per_ray, -1)
+
+ bg_sample_coordinates, _ = depth2pts_outside(
+ sample_origins, sample_directions,
+ depths_coarse) # [..., N_samples, 4]
+
+ out = self.run_model(bg_planes, decoder, bg_sample_coordinates,
+ sample_directions.reshape(batch_size, -1, 3),
+ rendering_options)
+
+ colors_coarse = out['rgb']
+ densities_coarse = out['sigma']
+ colors_coarse = colors_coarse.reshape(batch_size, num_rays,
+ samples_per_ray,
+ colors_coarse.shape[-1])
+ densities_coarse = densities_coarse.reshape(batch_size, num_rays,
+ samples_per_ray, 1)
+
+ rgb_final, depth_final, _, weights = self.ray_marcher(
+ colors_coarse, densities_coarse, depths_coarse, rendering_options)
+
+ ret_dict = {
+ 'feature_samples': rgb_final,
+ 'depth_samples': depth_final,
+ 'weights_samples': weights.sum(2),
+ # 'visibility': visibility # T[..., -1]
+ }
+
+ return ret_dict
+
+ def forward(self,
+ planes,
+ decoder,
+ ray_origins,
+ ray_directions,
+ rendering_options,
+ return_meta=False):
+
+ fg_planes, bg_planes = torch.split(
+ planes, planes.shape[2] // 2,
+ dim=2) # concatenated on the Channel side
+
+ # ! composite fg/bg
+ fg_ret_dict = super().forward(fg_planes,
+ decoder,
+ ray_origins,
+ ray_directions,
+ rendering_options,
+ return_meta=False)
+
+ bg_ret_dict = self.forward_background(
+ bg_planes,
+ decoder,
+ ray_origins,
+ ray_directions,
+ rendering_options,
+ )
+
+ ret_dict = {**fg_ret_dict, 'bg_ret_dict': bg_ret_dict} # for compat
+
+ return ret_dict # will composite in the external triplane.py
diff --git a/scripts/__pycache__/vit_triplane_train_FFHQ.cpython-39.pyc b/scripts/__pycache__/vit_triplane_train_FFHQ.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dcb8ce773de90a81a790bf6b0a7d479d544428b
Binary files /dev/null and b/scripts/__pycache__/vit_triplane_train_FFHQ.cpython-39.pyc differ
diff --git a/scripts/lmdb_create.py b/scripts/lmdb_create.py
new file mode 100644
index 0000000000000000000000000000000000000000..62a7de85fb94c2e51b9b595424e81b698e53f62c
--- /dev/null
+++ b/scripts/lmdb_create.py
@@ -0,0 +1,414 @@
+"""
+Train a diffusion model on images.
+"""
+# import imageio
+import gzip
+import random
+import json
+import sys
+import os
+import lmdb
+from tqdm import tqdm
+sys.path.append('.')
+import torch.distributed as dist
+import pickle
+import traceback
+from PIL import Image
+import torch as th
+import torch.multiprocessing as mp
+import lzma
+import numpy as np
+
+from torch.utils.data import DataLoader, Dataset
+import imageio.v3 as iio
+
+import argparse
+import dnnlib
+from guided_diffusion import dist_util, logger
+from guided_diffusion.script_util import (
+ args_to_dict,
+ add_dict_to_argparser,
+)
+# from nsr.train_util import TrainLoop3DRec as TrainLoop
+from nsr.train_nv_util import TrainLoop3DRecNV, TrainLoop3DRec, TrainLoop3DRecNVPatch
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default
+# from datasets.shapenet import load_data, load_data_for_lmdb, load_eval_data, load_memory_data
+from nsr.losses.builder import E3DGELossClass
+from datasets.eg3d_dataset import init_dataset_kwargs
+
+from pdb import set_trace as st
+import bz2
+
+# th.backends.cuda.matmul.allow_tf32 = True # https://huggingface.co/docs/diffusers/optimization/fp16
+
+
+
+def training_loop(args):
+ # def training_loop(args):
+ dist_util.setup_dist(args)
+ # th.autograd.set_detect_anomaly(True) # type: ignore
+ th.autograd.set_detect_anomaly(False) # type: ignore
+ # https://blog.csdn.net/qq_41682740/article/details/126304613
+
+ SEED = args.seed
+
+ # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count())
+ logger.log(f"{args.local_rank=} init complete, seed={SEED}")
+ th.cuda.set_device(args.local_rank)
+ th.cuda.empty_cache()
+
+ # * deterministic algorithms flags
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+ random.seed(SEED)
+
+ # logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
+ logger.configure(dir=args.logdir)
+
+ logger.log("creating encoder and NSR decoder...")
+ # device = dist_util.dev()
+ device = th.device("cuda", args.local_rank)
+
+ # shared eg3d opts
+ opts = eg3d_options_default()
+
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ # auto_encoder = create_3DAE_model(
+ # **args_to_dict(args,
+ # encoder_and_nsr_defaults().keys()))
+ # auto_encoder.to(device)
+ # auto_encoder.train()
+
+ if args.objv_dataset:
+ from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data, load_data_for_lmdb
+ else: # shapenet
+ from datasets.shapenet import load_data, load_eval_data, load_memory_data, load_data_for_lmdb
+
+ logger.log("creating data loader...")
+ # data = load_data(
+ # st()
+ # if args.overfitting:
+ # data = load_memory_data(
+ # file_path=args.data_dir,
+ # batch_size=args.batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # # load_depth=args.depth_lambda > 0
+ # load_depth=True # for evaluation
+ # )
+ # else:
+ if args.cfg in ('afhq', 'ffhq'):
+ # ! load data
+ logger.log("creating eg3d data loader...")
+ training_set_kwargs, dataset_name = init_dataset_kwargs(data=args.data_dir,
+ class_name='datasets.eg3d_dataset.ImageFolderDatasetLMDB',
+ reso_gt=args.image_size) # only load pose here
+ # if args.cond and not training_set_kwargs.use_labels:
+ # raise Exception('check here')
+
+ # training_set_kwargs.use_labels = args.cond
+ training_set_kwargs.use_labels = True
+ training_set_kwargs.xflip = False
+ training_set_kwargs.random_seed = SEED
+ # training_set_kwargs.max_size = args.dataset_size
+ # desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
+
+ # * construct ffhq/afhq dataset
+ training_set = dnnlib.util.construct_class_by_name(
+ **training_set_kwargs) # subclass of training.dataset.Dataset
+ dataset_size = len(training_set)
+
+ # training_set_sampler = InfiniteSampler(
+ # dataset=training_set,
+ # rank=dist_util.get_rank(),
+ # num_replicas=dist_util.get_world_size(),
+ # seed=SEED)
+
+ data = DataLoader(
+ training_set,
+ shuffle=False,
+ batch_size=1,
+ num_workers=16,
+ drop_last=False,
+ # prefetch_factor=2,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+
+ else:
+ # data, dataset_name, dataset_size = load_data_for_lmdb(
+ data, dataset_name, dataset_size, _ = load_data_for_lmdb(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True,
+ preprocess=None,
+ dataset_size=args.dataset_size,
+ trainer_name=args.trainer_name
+ # load_depth=True # for evaluation
+ )
+ # if args.pose_warm_up_iter > 0:
+ # overfitting_dataset = load_memory_data(
+ # file_path=args.data_dir,
+ # batch_size=args.batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # # load_depth=args.depth_lambda > 0
+ # load_depth=True # for evaluation
+ # )
+ # data = [data, overfitting_dataset, args.pose_warm_up_iter]
+ # eval_data = load_eval_data(
+ # file_path=args.eval_data_dir,
+ # batch_size=args.eval_batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # load_depth=True, # for evaluation
+ # preprocess=auto_encoder.preprocess)
+ args.img_size = [args.image_size_encoder]
+ # try dry run
+ # batch = next(data)
+ # batch = None
+
+ # logger.log("creating model and diffusion...")
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
+
+ opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
+ # opt.max_depth, opt.min_depth = args.rendering_kwargs.ray_end, args.rendering_kwargs.ray_start
+ # loss_class = E3DGELossClass(device, opt).to(device)
+
+ # writer = SummaryWriter() # TODO, add log dir
+
+ logger.log("training...")
+
+ # TrainLoop = {
+ # 'input_rec': TrainLoop3DRec,
+ # 'nv_rec': TrainLoop3DRecNV,
+ # 'nv_rec_patch': TrainLoop3DRecNVPatch,
+ # }[args.trainer_name]
+
+ # TrainLoop(rec_model=auto_encoder,
+ # loss_class=loss_class,
+ # data=data,
+ # eval_data=eval_data,
+ # **vars(args)).run_loop() # ! overfitting
+
+
+ def convert_to_lmdb(dataset_loader, lmdb_path):
+ """
+ Convert a PyTorch dataset to LMDB format.
+
+ Parameters:
+ - dataset: PyTorch dataset
+ - lmdb_path: Path to store the LMDB database
+ """
+ env = lmdb.open(lmdb_path, map_size=1024 ** 4, readahead=False) # Adjust map_size based on your dataset size
+
+ with env.begin(write=True) as txn:
+ for idx, sample in enumerate(tqdm(dataset_loader)):
+ # remove the batch index of returned dict sample
+ sample = {
+ k:v.squeeze(0).cpu().numpy() if isinstance(v, th.Tensor) else v[0]
+ for k, v in sample.items()
+ }
+
+ # sample = dataset_loader[idx]
+ key = str(idx).encode('utf-8')
+ value = pickle.dumps(sample)
+ txn.put(key, value)
+
+ # txn.put("length".encode("utf-8"), f'{imgset_size}'.encode("utf-8")) # ! will incur bug in dataloading.
+ # txn.put("start_idx".encode("utf-8"), f'{start_idx}'.encode("utf-8"))
+ # txn.put("end_idx".encode("utf-8"), f'{end_idx}'.encode("utf-8"))
+
+ # env.close()
+
+ import zlib
+
+ # Function to encode and compress an image
+ # def encode_and_compress_image(image_path):
+ # def encode_and_compress_image(image):
+ # # Open and encode the image
+ # # with open(image_path, 'rb') as f:
+ # # image = Image.open(f)
+ # encoded_data = image.tobytes()
+
+ # # Compress the encoded data
+ # # Compress the image data using bz2
+ # compressed_data = gzip.compress(encoded_data)
+ # # compressed_data = bz2.compress(encoded_data)
+ # # compressed_data = lzma.compress(encoded_data)
+ # # compressed_data = zlib.compress(encoded_data)
+
+ # return compressed_data
+
+ # Function to compress an image using gzip
+ # def compress_image_gzip(image_path):
+ def encode_and_compress_image(inp_array, is_image=False, compress=True):
+ # Read the image using imageio
+ # image = imageio.v3.imread(image_path)
+
+ # Convert the image to bytes
+ # with io.BytesIO() as byte_buffer:
+ # imageio.imsave(byte_buffer, image, format="png")
+ # image_bytes = byte_buffer.getvalue()
+ if is_image:
+ inp_bytes = iio.imwrite("", inp_array, extension=".png")
+ else:
+ inp_bytes = inp_array.tobytes()
+
+ # Compress the image data using gzip
+ if compress:
+ compressed_data = gzip.compress(inp_bytes)
+ return compressed_data
+ else:
+ return inp_bytes
+
+
+
+ def convert_to_lmdb_compressed(dataset_loader, lmdb_path, dataset_size):
+ """
+ Convert a PyTorch dataset to LMDB format.
+
+ Parameters:
+ - dataset: PyTorch dataset
+ - lmdb_path: Path to store the LMDB database
+ """
+ env = lmdb.open(lmdb_path, map_size=1024 ** 4, readahead=False) # Adjust map_size based on your dataset size
+
+ # with env.begin(write=True) as txn:
+
+ with env.begin(write=True) as txn:
+ txn.put("length".encode("utf-8"), str(dataset_size).encode("utf-8"))
+ for idx, sample in enumerate(tqdm(dataset_loader)):
+ # remove the batch index of returned dict sample
+ sample = {
+ k:v.squeeze(0).cpu().numpy() if isinstance(v, th.Tensor) else v[0]
+ for k, v in sample.items()
+ }
+
+ # sample = dataset_loader[idx]
+ for k, v in sample.items():
+
+ # if idx == 0: # record data shape and type for decoding
+ # txn.put(f"{k}.shape".encode("utf-8"), str(v.shape).encode("utf-8"))
+ # txn.put(f"{k}.dtype".encode("utf-8"), str(v.dtype).encode("utf-8"))
+
+ key = f'{idx}-{k}'.encode('utf-8')
+ # value = pickle.dumps(sample)
+ # if 'depth' in k or 'img' in k:
+ if 'img' in k: # only bytes required? laod the 512 depth bytes only.
+ v = encode_and_compress_image(v, is_image=True, compress=False)
+ # elif 'depth' in k:
+ else: # regular bytes encoding
+ if type(v) != str:
+ v = v.astype(np.float32)
+ v = encode_and_compress_image(v, is_image=False, compress=False)
+ else:
+ v = v.encode("utf-8")
+ # else: # regular bytes encoding
+ # v = v.tobytes()
+
+ txn.put(key, v)
+
+
+ # txn.put("length".encode("utf-8"), f'{imgset_size}'.encode("utf-8")) # ! will incur bug in dataloading.
+ # txn.put("start_idx".encode("utf-8"), f'{start_idx}'.encode("utf-8"))
+ # txn.put("end_idx".encode("utf-8"), f'{end_idx}'.encode("utf-8"))
+
+ # env.close()
+
+
+ # convert_to_lmdb(data, os.path.join(logger.get_dir(), dataset_name)) convert_to_lmdb_compressed(data, os.path.join(logger.get_dir(), dataset_name))
+ convert_to_lmdb_compressed(data, os.path.join(logger.get_dir()), dataset_size)
+
+
+
+def create_argparser(**kwargs):
+ # defaults.update(model_and_diffusion_defaults())
+
+ defaults = dict(
+ seed=0,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ use_amp=False,
+ overfitting=False,
+ num_workers=4,
+ image_size=128,
+ image_size_encoder=224,
+ iterations=150000,
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ eval_batch_size=12,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ data_dir="",
+ eval_data_dir="",
+ # load_depth=False, # TODO
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ # test warm up pose sampling training
+ objv_dataset=False,
+ pose_warm_up_iter=-1,
+ )
+
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+ # os.environ[
+ # "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+ # os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
+ # os.environ["NCCL_DEBUG"]="INFO"
+
+ args = create_argparser().parse_args()
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ opts = args
+
+ args.rendering_kwargs = rendering_options_defaults(opts)
+
+ # print(args)
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ # Launch processes.
+ print('Launching processes...')
+
+ try:
+ training_loop(args)
+ # except KeyboardInterrupt as e:
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ dist_util.cleanup() # clean port and socket when ctrl+c
diff --git a/scripts/profile_dataloading.py b/scripts/profile_dataloading.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b27cfae91a43492aa20f2ce3217ca6a29b5c517
--- /dev/null
+++ b/scripts/profile_dataloading.py
@@ -0,0 +1,289 @@
+"""
+Train a diffusion model on images.
+"""
+import cv2
+from pathlib import Path
+import imageio
+import random
+import json
+import sys
+import os
+
+from tqdm import tqdm
+sys.path.append('.')
+import torch.distributed as dist
+
+import traceback
+
+import torch as th
+import torch.multiprocessing as mp
+import numpy as np
+
+import argparse
+import dnnlib
+from guided_diffusion import dist_util, logger
+from guided_diffusion.script_util import (
+ args_to_dict,
+ add_dict_to_argparser,
+)
+# from nsr.train_util import TrainLoop3DRec as TrainLoop
+from nsr.train_nv_util import TrainLoop3DRecNV, TrainLoop3DRec, TrainLoop3DRecNVPatch
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default
+# from datasets.shapenet import load_data, load_eval_data, load_memory_data, load_dataset
+from nsr.losses.builder import E3DGELossClass
+from datasets.eg3d_dataset import LMDBDataset_MV_Compressed_eg3d
+from dnnlib.util import EasyDict, InfiniteSampler
+
+from pdb import set_trace as st
+
+# th.backends.cuda.matmul.allow_tf32 = True # https://huggingface.co/docs/diffusers/optimization/fp16
+
+
+
+def training_loop(args):
+ # def training_loop(args):
+ dist_util.setup_dist(args)
+ # th.autograd.set_detect_anomaly(True) # type: ignore
+ th.autograd.set_detect_anomaly(False) # type: ignore
+ # https://blog.csdn.net/qq_41682740/article/details/126304613
+
+ SEED = args.seed
+
+ # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count())
+ # logger.log(f"{args.local_rank=} init complete, seed={SEED}")
+ th.cuda.set_device(args.local_rank)
+ th.cuda.empty_cache()
+
+ # * deterministic algorithms flags
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+ random.seed(SEED)
+
+ # logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
+ logger.configure(dir=args.logdir)
+
+ logger.log("creating encoder and NSR decoder...")
+ # device = dist_util.dev()
+ # device = th.device("cuda", args.local_rank)
+
+ # shared eg3d opts
+ opts = eg3d_options_default()
+
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ # auto_encoder = create_3DAE_model(
+ # **args_to_dict(args,
+ # encoder_and_nsr_defaults().keys()))
+ # auto_encoder.to(device)
+ # auto_encoder.train()
+
+ logger.log("creating data loader...")
+ # data = load_data(
+ # st()
+
+ # st()
+ if args.objv_dataset:
+ from datasets.g_buffer_objaverse import load_data, load_dataset, load_eval_data, load_memory_data
+ else: # shapenet
+ from datasets.shapenet import load_data, load_eval_data, load_memory_data, load_dataset
+
+ # st()
+ if args.overfitting:
+ data = load_memory_data(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ # load_depth=args.depth_lambda > 0
+ load_depth=True # for evaluation
+ )
+ else:
+ if args.cfg in ['ffhq' ]:
+ training_set = LMDBDataset_MV_Compressed_eg3d(
+ args.data_dir,
+ args.image_size,
+ args.image_size_encoder,
+ )
+ training_set_sampler = InfiniteSampler(
+ dataset=training_set,
+ rank=dist_util.get_rank(),
+ num_replicas=dist_util.get_world_size(),
+ seed=SEED)
+
+ data = iter(
+ th.utils.data.DataLoader(
+ dataset=training_set,
+ sampler=training_set_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ num_workers=args.num_workers,
+ persistent_workers=args.num_workers>0,
+ prefetch_factor=max(8//args.batch_size, 2),
+ ))
+
+ else:
+ # st()
+ # loader = load_data(
+ loader = load_dataset(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True,
+ preprocess=None,
+ dataset_size=args.dataset_size,
+ trainer_name=args.trainer_name,
+ use_lmdb=args.use_lmdb,
+ infi_sampler=False,
+ # infi_sampler=True,
+ # load_depth=True # for evaluation
+ )
+ if args.pose_warm_up_iter > 0:
+ overfitting_dataset = load_memory_data(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ # load_depth=args.depth_lambda > 0
+ load_depth=True # for evaluation
+ )
+ data = [data, overfitting_dataset, args.pose_warm_up_iter]
+ # eval_data = load_eval_data(
+ # file_path=args.eval_data_dir,
+ # batch_size=args.eval_batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # load_depth=True, # for evaluation
+ # preprocess=None,
+ args.img_size = [args.image_size_encoder]
+ # try dry run
+ # batch = next(data)
+ # batch = None
+
+ # logger.log("creating model and diffusion...")
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
+
+ opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
+ # opt.max_depth, opt.min_depth = args.rendering_kwargs.ray_end, args.rendering_kwargs.ray_start
+ # loss_class = E3DGELossClass(device, opt).to(device)
+
+ # writer = SummaryWriter() # TODO, add log dir
+
+ logger.log("training...")
+
+ # TrainLoop = {
+ # 'input_rec': TrainLoop3DRec,
+ # 'nv_rec': TrainLoop3DRecNV,
+ # 'nv_rec_patch': TrainLoop3DRecNVPatch,
+ # }[args.trainer_name]
+
+ # TrainLoop(rec_model=auto_encoder,
+ # loss_class=loss_class,
+ # data=data,
+ # eval_data=eval_data,
+ # **vars(args)).run_loop() # ! overfitting
+ number = 0
+ # tgt_dir = Path(f'/mnt/lustre/yslan/3D_Dataset/resized_for_fid/chair/{args.image_size}')
+ # tgt_dir = Path(f'/mnt/lustre/yslan/3D_Dataset/resized_for_fid/chair-new/{args.image_size}')
+ # tgt_dir.mkdir(parents=True, exist_ok=True)
+ for idx, batch in enumerate(tqdm(loader)):
+ # for idx in tqdm(len(loader)): # ! dataset here, direct reference
+ # batch = loader[idx]
+ # worker=3: 2.5it/s; worker=8: 1.47it/s; worker=4, 2.3it/s; worker=1, 1.45it/s
+ # ! save to target folder for FID/KID
+ # for idx in range(batch['img'].shape[0]):
+ # # imageio.v3.imwrite(tgt_dir / f'{number}.png' ,(127.5+127.5*batch['img'][idx].cpu().numpy()).astype(np.uint8))
+ # cv2.imwrite(str(tgt_dir / f'{number}.png') ,(127.5+127.5*batch['img'][idx].cpu().permute(1,2,0).numpy()).astype(np.uint8))
+ # number += 1
+
+ pass
+
+
+def create_argparser(**kwargs):
+ # defaults.update(model_and_diffusion_defaults())
+
+ defaults = dict(
+ seed=0,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ use_amp=False,
+ overfitting=False,
+ num_workers=4,
+ image_size=128,
+ image_size_encoder=224,
+ iterations=150000,
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ eval_batch_size=12,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ data_dir="",
+ eval_data_dir="",
+ # load_depth=False, # TODO
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ # test warm up pose sampling training
+ pose_warm_up_iter=-1,
+ use_lmdb=False,
+ objv_dataset=False,
+ )
+
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+ # os.environ[
+ # "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+ # os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
+ # os.environ["NCCL_DEBUG"]="INFO"
+
+ args = create_argparser().parse_args()
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ opts = args
+
+ args.rendering_kwargs = rendering_options_defaults(opts)
+
+ # print(args)
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ # Launch processes.
+ print('Launching processes...')
+
+ try:
+ training_loop(args)
+ # except KeyboardInterrupt as e:
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ dist_util.cleanup() # clean port and socket when ctrl+c
diff --git a/scripts/vit_triplane_cldm_train.py b/scripts/vit_triplane_cldm_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bae18e5c3d92903d6fea657d9779d434ade0e00
--- /dev/null
+++ b/scripts/vit_triplane_cldm_train.py
@@ -0,0 +1,367 @@
+"""
+Train a diffusion model on images.
+"""
+import json
+import sys
+import os
+
+sys.path.append('.')
+
+# from dnnlib import EasyDict
+import traceback
+
+import torch as th
+import torch.multiprocessing as mp
+import torch.distributed as dist
+import numpy as np
+
+import argparse
+import dnnlib
+from guided_diffusion import dist_util, logger
+from guided_diffusion.resample import create_named_schedule_sampler
+from guided_diffusion.script_util import (
+ args_to_dict,
+ add_dict_to_argparser,
+ continuous_diffusion_defaults,
+ model_and_diffusion_defaults,
+ create_model_and_diffusion,
+)
+from guided_diffusion.continuous_diffusion import make_diffusion as make_sde_diffusion
+import nsr
+import nsr.lsgm
+# from nsr.train_util_diffusion import TrainLoop3DDiffusion as TrainLoop
+
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default
+from datasets.shapenet import load_data, load_eval_data, load_memory_data
+from nsr.losses.builder import E3DGELossClass
+
+from utils.torch_utils import legacy, misc
+from torch.utils.data import Subset
+from pdb import set_trace as st
+
+from dnnlib.util import EasyDict, InfiniteSampler
+# from .vit_triplane_train_FFHQ import init_dataset_kwargs
+from datasets.eg3d_dataset import init_dataset_kwargs
+
+# from torch.utils.tensorboard import SummaryWriter
+
+SEED = 0
+
+
+def training_loop(args):
+ # def training_loop(args):
+ logger.log("dist setup...")
+
+ th.cuda.set_device(
+ args.local_rank) # set this line to avoid extra memory on rank 0
+ th.cuda.empty_cache()
+
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+
+ dist_util.setup_dist(args)
+
+ # st() # mark
+
+ # logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
+ logger.configure(dir=args.logdir)
+
+ logger.log("creating ViT encoder and NSR decoder...")
+ # st() # mark
+ device = dist_util.dev()
+
+ args.img_size = [args.image_size_encoder]
+
+ logger.log("creating model and diffusion...")
+ # * set denoise model args
+
+ if args.denoise_in_channels == -1:
+ args.diffusion_input_size = args.image_size_encoder
+ args.denoise_in_channels = args.out_chans
+ args.denoise_out_channels = args.out_chans
+ else:
+ assert args.denoise_out_channels != -1
+
+ # args.image_size = args.image_size_encoder # 224, follow the triplane size
+
+ # if args.diffusion_input_size == -1:
+ # else:
+ # args.image_size = args.diffusion_input_size
+
+ denoise_model, diffusion = create_model_and_diffusion(
+ **args_to_dict(args,
+ model_and_diffusion_defaults().keys()))
+ denoise_model.to(dist_util.dev())
+ denoise_model.train()
+
+ opts = eg3d_options_default()
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ logger.log("creating encoder and NSR decoder...")
+ auto_encoder = create_3DAE_model(
+ **args_to_dict(args,
+ encoder_and_nsr_defaults().keys()))
+
+ auto_encoder.to(device)
+ auto_encoder.eval()
+
+ # * load G_ema modules into autoencoder
+ # * clone G_ema.decoder to auto_encoder triplane
+ # logger.log("AE triplane decoder reuses G_ema decoder...")
+ # auto_encoder.decoder.register_buffer('w_avg', G_ema.backbone.mapping.w_avg)
+
+ # auto_encoder.decoder.triplane_decoder.decoder.load_state_dict( # type: ignore
+ # G_ema.decoder.state_dict()) # type: ignore
+
+ # set grad=False in this manner suppresses the DDP forward no grad error.
+ logger.log("freeze triplane decoder...")
+ for param in auto_encoder.decoder.triplane_decoder.parameters(
+ ): # type: ignore
+ # for param in auto_encoder.decoder.triplane_decoder.decoder.parameters(): # type: ignore
+ param.requires_grad_(False)
+
+ # if args.sr_training:
+
+ # logger.log("AE triplane decoder reuses G_ema SR module...")
+ # # auto_encoder.decoder.triplane_decoder.superresolution.load_state_dict( # type: ignore
+ # # G_ema.superresolution.state_dict()) # type: ignore
+
+ # # set grad=False in this manner suppresses the DDP forward no grad error.
+ # logger.log("freeze SR module...")
+ # for param in auto_encoder.decoder.superresolution.parameters(): # type: ignore
+ # param.requires_grad_(False)
+
+ # # del G_ema
+ # th.cuda.empty_cache()
+
+ if args.cfg in ('afhq', 'ffhq'):
+
+ if args.sr_training:
+
+ logger.log("AE triplane decoder reuses G_ema SR module...")
+ auto_encoder.decoder.triplane_decoder.superresolution.load_state_dict( # type: ignore
+ G_ema.superresolution.state_dict()) # type: ignore
+
+ # set grad=False in this manner suppresses the DDP forward no grad error.
+ for param in auto_encoder.decoder.triplane_decoder.superresolution.parameters(
+ ): # type: ignore
+ param.requires_grad_(False)
+
+ # ! load data
+ logger.log("creating eg3d data loader...")
+ training_set_kwargs, dataset_name = init_dataset_kwargs(
+ data=args.data_dir,
+ class_name='datasets.eg3d_dataset.ImageFolderDataset'
+ ) # only load pose here
+ # if args.cond and not training_set_kwargs.use_labels:
+ # raise Exception('check here')
+
+ # training_set_kwargs.use_labels = args.cond
+ training_set_kwargs.use_labels = True
+ training_set_kwargs.xflip = True
+ training_set_kwargs.random_seed = SEED
+ # desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
+
+ # * construct ffhq/afhq dataset
+ training_set = dnnlib.util.construct_class_by_name(
+ **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ training_set = dnnlib.util.construct_class_by_name(
+ **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ training_set_sampler = InfiniteSampler(
+ dataset=training_set,
+ rank=dist_util.get_rank(),
+ num_replicas=dist_util.get_world_size(),
+ seed=SEED)
+
+ data = iter(
+ th.utils.data.DataLoader(
+ dataset=training_set,
+ sampler=training_set_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ num_workers=args.num_workers,
+ ))
+ # prefetch_factor=2))
+
+ eval_data = th.utils.data.DataLoader(dataset=Subset(
+ training_set, np.arange(10)),
+ batch_size=args.eval_batch_size,
+ num_workers=1)
+
+ else:
+
+ logger.log("creating data loader...")
+ # TODO, load shapenet data
+ # data = load_data(
+ # st() mark
+ if args.overfitting:
+ logger.log("create overfitting memory dataset")
+ data = load_memory_data(
+ file_path=args.eval_data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True # for evaluation
+ )
+ else:
+ logger.log("create all instances dataset")
+ # st() mark
+ data = load_data(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True,
+ preprocess=auto_encoder.preprocess, # clip
+ dataset_size=args.dataset_size,
+ # load_depth=True # for evaluation
+ )
+ # st() mark
+ eval_data = load_eval_data(
+ file_path=args.eval_data_dir,
+ batch_size=args.eval_batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True # for evaluation
+ )
+
+ # let all processes sync up before starting with a new epoch of training
+
+ if dist_util.get_rank() == 0:
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ args.schedule_sampler = create_named_schedule_sampler(
+ args.schedule_sampler, diffusion)
+
+ opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
+ loss_class = E3DGELossClass(device, opt).to(device)
+
+ logger.log("training...")
+
+ TrainLoop = {
+ 'adm': nsr.TrainLoop3DDiffusion,
+ 'dit': nsr.TrainLoop3DDiffusionDiT,
+ 'ssd': nsr.TrainLoop3DDiffusionSingleStage,
+ # 'ssd_cvD': nsr.TrainLoop3DDiffusionSingleStagecvD,
+ 'ssd_cvD_sds': nsr.TrainLoop3DDiffusionSingleStagecvDSDS,
+ 'ssd_cvd_sds_no_separate_sds_step':
+ nsr.TrainLoop3DDiffusionSingleStagecvDSDS_sdswithrec,
+ 'vpsde_lsgm_noD': nsr.lsgm.TrainLoop3DDiffusionLSGM_noD, # use vpsde
+ # 'vpsde_lsgm': nsr.TrainLoop3DDiffusionLSGM, # use vpsde
+ # 'vpsde': nsr.TrainLoop3DDiffusion_vpsde,
+ }[args.trainer_name]
+
+ if 'vpsde' in args.trainer_name:
+ sde_diffusion = make_sde_diffusion(
+ dnnlib.EasyDict(
+ args_to_dict(args,
+ continuous_diffusion_defaults().keys())))
+ assert args.mixed_prediction, 'enable mixed_prediction by default'
+ logger.log('create VPSDE diffusion.')
+ else:
+ sde_diffusion = None
+
+ dist_util.synchronize()
+
+ TrainLoop(rec_model=auto_encoder,
+ denoise_model=denoise_model,
+ diffusion=diffusion,
+ sde_diffusion=sde_diffusion,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ **vars(args)).run_loop()
+
+
+def create_argparser(**kwargs):
+ # defaults.update(model_and_diffusion_defaults())
+
+ defaults = dict(
+ dataset_size=-1,
+ diffusion_input_size=-1,
+ trainer_name='adm',
+ use_amp=False,
+ triplane_scaling_divider=1.0, # divide by this value
+ overfitting=False,
+ num_workers=4,
+ image_size=128,
+ image_size_encoder=224,
+ iterations=150000,
+ schedule_sampler="uniform",
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ eval_batch_size=12,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ resume_checkpoint_EG3D="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ data_dir="",
+ eval_data_dir="",
+ # load_depth=False, # TODO
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ load_submodule_name='', # for loading pretrained auto_encoder model
+ ignore_resume_opt=False,
+ # freeze_ae=False,
+ denoised_ae=True,
+ )
+
+ defaults.update(model_and_diffusion_defaults())
+ defaults.update(continuous_diffusion_defaults())
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+ # os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
+ # os.environ["NCCL_DEBUG"] = "INFO"
+
+ os.environ[
+ "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+
+ args = create_argparser().parse_args()
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ # opts = dnnlib.EasyDict(vars(args)) # compatiable with triplane original settings
+ # opts = args
+ args.rendering_kwargs = rendering_options_defaults(args)
+
+ # Launch processes.
+ logger.log('Launching processes...')
+
+ logger.log('Available devices ', th.cuda.device_count())
+ logger.log('Current cuda device ', th.cuda.current_device())
+ # logger.log('GPU Device name:', th.cuda.get_device_name(th.cuda.current_device()))
+
+ try:
+ training_loop(args)
+ # except KeyboardInterrupt as e:
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ dist_util.cleanup() # clean port and socket when ctrl+c
diff --git a/scripts/vit_triplane_cvD_train.py b/scripts/vit_triplane_cvD_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..49e415d3fddb19e706fec95d97cc43e544b01972
--- /dev/null
+++ b/scripts/vit_triplane_cvD_train.py
@@ -0,0 +1,224 @@
+"""
+Train a diffusion model on images.
+"""
+import json
+import sys
+import os
+
+sys.path.append('.')
+import torch.distributed as dist
+
+import traceback
+
+import torch as th
+import torch.multiprocessing as mp
+import numpy as np
+
+import argparse
+import dnnlib
+from guided_diffusion import dist_util, logger
+from guided_diffusion.script_util import (
+ args_to_dict,
+ add_dict_to_argparser,
+)
+
+# from nsr.train_util import TrainLoop3DRec as TrainLoop
+
+import nsr
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default
+from datasets.shapenet import load_data, load_eval_data, load_memory_data
+from nsr.losses.builder import E3DGELossClass
+
+from pdb import set_trace as st
+
+import warnings
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+# th.backends.cuda.matmul.allow_tf32 = True # https://huggingface.co/docs/diffusers/optimization/fp16
+
+SEED = 0
+
+
+def training_loop(args):
+ # def training_loop(args):
+ dist_util.setup_dist(args)
+
+ # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count())
+ print(f"{args.local_rank=} init complete")
+ th.cuda.set_device(args.local_rank)
+ th.cuda.empty_cache()
+
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+
+ # logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
+ logger.configure(dir=args.logdir)
+
+ logger.log("creating encoder and NSR decoder...")
+ # device = dist_util.dev()
+ device = th.device("cuda", args.local_rank)
+
+ # shared eg3d opts
+ opts = eg3d_options_default()
+
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ auto_encoder = create_3DAE_model(
+ **args_to_dict(args,
+ encoder_and_nsr_defaults().keys()))
+ auto_encoder.to(device)
+ auto_encoder.train()
+
+ logger.log("creating data loader...")
+ # data = load_data(
+ if args.overfitting:
+ data = load_memory_data(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ # trainer_name=args.trainer_name,
+ # load_depth=args.depth_lambda > 0
+ load_depth=True # for evaluation
+ )
+ else:
+ data = load_data(
+ dataset_size=args.dataset_size,
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True,
+ preprocess=auto_encoder.preprocess, # clip
+ trainer_name=args.trainer_name,
+ use_lmdb=args.use_lmdb
+ # load_depth=True # for evaluation
+ )
+ eval_data = load_eval_data(
+ file_path=args.eval_data_dir,
+ batch_size=args.eval_batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=2,
+ load_depth=True, # for evaluation
+ preprocess=auto_encoder.preprocess)
+ args.img_size = [args.image_size_encoder]
+ # try dry run
+ # batch = next(data)
+ # batch = None
+
+ # logger.log("creating model and diffusion...")
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
+
+ opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
+ loss_class = E3DGELossClass(device, opt).to(device)
+
+ # writer = SummaryWriter() # TODO, add log dir
+
+ logger.log("training...")
+
+ TrainLoop = {
+ 'cvD': nsr.TrainLoop3DcvD,
+ 'nvsD': nsr.TrainLoop3DcvD_nvsD,
+ 'nvsD_nosr': nsr.TrainLoop3DcvD_nvsD_noSR,
+ 'cano_nvsD_nosr': nsr.TrainLoop3DcvD_nvsD_noSR,
+ 'cano_nvs_cvD': nsr.TrainLoop3DcvD_nvsD_canoD,
+ 'cano_nvs_cvD_nv': nsr.TrainLoop3DcvD_nvsD_canoD_multiview,
+ 'cvD_nvsD_canoD_canomask': nsr.TrainLoop3DcvD_nvsD_canoD_canomask,
+ 'canoD': nsr.TrainLoop3DcvD_canoD
+ }[args.trainer_name]
+
+ TrainLoop(rec_model=auto_encoder,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ **vars(args)).run_loop() # ! overfitting
+
+
+def create_argparser(**kwargs):
+ # defaults.update(model_and_diffusion_defaults())
+
+ defaults = dict(
+ dataset_size=-1,
+ trainer_name='cvD',
+ use_amp=False,
+ overfitting=False,
+ num_workers=4,
+ image_size=128,
+ image_size_encoder=224,
+ iterations=150000,
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ eval_batch_size=12,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ data_dir="",
+ eval_data_dir="",
+ # load_depth=False, # TODO
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ pose_warm_up_iter=-1,
+ use_lmdb=False,
+ )
+
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+ os.environ[
+ "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+ os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
+
+ # master_addr = '127.0.0.1'
+ # master_port = dist_util._find_free_port()
+ # master_port = 31323
+
+ args = create_argparser().parse_args()
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ opts = args
+
+ args.rendering_kwargs = rendering_options_defaults(opts)
+
+ # print(args)
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ # Launch processes.
+ print('Launching processes...')
+
+ try:
+ training_loop(args)
+ # except KeyboardInterrupt as e:
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ dist_util.cleanup() # clean port and socket when ctrl+c
diff --git a/scripts/vit_triplane_cvD_train_ffhq.py b/scripts/vit_triplane_cvD_train_ffhq.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc7a4684686be72c2301317c5e1290cb34c69e31
--- /dev/null
+++ b/scripts/vit_triplane_cvD_train_ffhq.py
@@ -0,0 +1,330 @@
+"""
+Train a diffusion model on images.
+"""
+import json
+import sys
+import os
+
+sys.path.append('.')
+import torch.distributed as dist
+
+import traceback
+
+import torch as th
+import torch.multiprocessing as mp
+import numpy as np
+
+import argparse
+import dnnlib
+from dnnlib.util import EasyDict, InfiniteSampler
+from guided_diffusion import dist_util, logger
+from guided_diffusion.script_util import (
+ args_to_dict,
+ add_dict_to_argparser,
+)
+
+# from nsr.train_util import TrainLoop3DRec as TrainLoop
+
+import nsr
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default
+from datasets.shapenet import load_data, load_eval_data, load_memory_data
+from nsr.losses.builder import E3DGELossClass
+from torch.utils.data import Subset
+from datasets.eg3d_dataset import init_dataset_kwargs
+from utils.torch_utils import legacy, misc
+
+from pdb import set_trace as st
+
+import warnings
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+# th.backends.cuda.matmul.allow_tf32 = True # https://huggingface.co/docs/diffusers/optimization/fp16
+
+SEED = 0
+
+
+def training_loop(args):
+ # def training_loop(args):
+ dist_util.setup_dist(args)
+
+ # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count())
+ print(f"{args.local_rank=} init complete")
+ th.cuda.set_device(args.local_rank)
+ th.cuda.empty_cache()
+
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+
+ # logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
+ logger.configure(dir=args.logdir)
+
+ logger.log("creating encoder and NSR decoder...")
+ # device = dist_util.dev()
+ device = th.device("cuda", args.local_rank)
+
+ # shared eg3d opts
+ opts = eg3d_options_default()
+
+ # if args.sr_training:
+ # args.sr_kwargs = dnnlib.EasyDict(
+ # channel_base=opts.cbase,
+ # channel_max=opts.cmax,
+ # fused_modconv_default='inference_only',
+ # use_noise=True
+ # ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ logger.log("creating data loader...")
+ # data = load_data(
+ # if args.overfitting:
+ # data = load_memory_data(
+ # file_path=args.data_dir,
+ # batch_size=args.batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # # load_depth=args.depth_lambda > 0
+ # load_depth=True # for evaluation
+ # )
+ # else:
+ # data = load_data(
+ # dataset_size=args.dataset_size,
+ # file_path=args.data_dir,
+ # batch_size=args.batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # load_depth=True,
+ # preprocess=auto_encoder.preprocess # clip
+ # # load_depth=True # for evaluation
+ # )
+ # eval_data = load_eval_data(
+ # file_path=args.eval_data_dir,
+ # batch_size=args.eval_batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=2,
+ # load_depth=True, # for evaluation
+ # preprocess=auto_encoder.preprocess)
+ # ! load pre-trained SR in G
+ common_kwargs = dict(c_dim=25, img_resolution=512, img_channels=3)
+
+ G_kwargs = EasyDict(class_name=None,
+ z_dim=512,
+ w_dim=512,
+ mapping_kwargs=EasyDict())
+ G_kwargs.channel_base = opts.cbase
+ G_kwargs.channel_max = opts.cmax
+ G_kwargs.mapping_kwargs.num_layers = opts.map_depth
+ G_kwargs.class_name = opts.g_class_name
+ G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
+ G_kwargs.rendering_kwargs = args.rendering_kwargs
+ G_kwargs.num_fp16_res = 0
+ G_kwargs.sr_num_fp16_res = 4
+
+ G_kwargs.sr_kwargs = EasyDict(channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True) # ! close noise injection? since noise_mode='none' in eg3d
+
+ G_kwargs.num_fp16_res = opts.g_num_fp16_res
+ G_kwargs.conv_clamp = 256 if opts.g_num_fp16_res > 0 else None
+
+ # creating G
+ resume_data = th.load(args.resume_checkpoint_EG3D, map_location='cuda:{}'.format(args.local_rank))
+ G_ema = dnnlib.util.construct_class_by_name(
+ **G_kwargs, **common_kwargs).train().requires_grad_(False).to(
+ dist_util.dev()) # subclass of th.nn.Module
+ for name, module in [
+ ('G_ema', G_ema),
+ # ('D', D),
+ ]:
+ misc.copy_params_and_buffers(
+ resume_data[name], # type: ignore
+ module,
+ require_all=True,
+ # load_except=d_load_except if name == 'D' else [],
+ )
+
+
+ G_ema.requires_grad_(False)
+ G_ema.eval()
+
+ if args.sr_training:
+ args.sr_kwargs = G_kwargs.sr_kwargs # uncomment if needs to train with SR module
+
+ auto_encoder = create_3DAE_model(
+ **args_to_dict(args,
+ encoder_and_nsr_defaults().keys()))
+ auto_encoder.to(device)
+ auto_encoder.train()
+
+ # * clone G_ema.decoder to auto_encoder triplane
+ logger.log("AE triplane decoder reuses G_ema decoder...")
+ auto_encoder.decoder.register_buffer('w_avg', G_ema.backbone.mapping.w_avg)
+
+ auto_encoder.decoder.triplane_decoder.decoder.load_state_dict( # type: ignore
+ G_ema.decoder.state_dict()) # type: ignore
+
+ # set grad=False in this manner suppresses the DDP forward no grad error.
+ for param in auto_encoder.decoder.triplane_decoder.decoder.parameters(): # type: ignore
+ param.requires_grad_(False)
+
+ if args.sr_training:
+ logger.log("AE triplane decoder reuses G_ema SR module...")
+ auto_encoder.decoder.triplane_decoder.superresolution.load_state_dict( # type: ignore
+ G_ema.superresolution.state_dict()) # type: ignore
+ # set grad=False in this manner suppresses the DDP forward no grad error.
+ for param in auto_encoder.decoder.triplane_decoder.superresolution.parameters(): # type: ignore
+ param.requires_grad_(False)
+
+ del resume_data, G_ema
+ th.cuda.empty_cache()
+
+ auto_encoder.to(dist_util.dev())
+ auto_encoder.train()
+
+ # ! load FFHQ/AFHQ
+ # Training set.
+ # training_set_kwargs, dataset_name = init_dataset_kwargs(data=args.data_dir, class_name='datasets.eg3d_dataset.ImageFolderDatasetPose') # only load pose here
+ training_set_kwargs, dataset_name = init_dataset_kwargs(data=args.data_dir, class_name='datasets.eg3d_dataset.ImageFolderDataset') # only load pose here
+ # if args.cond and not training_set_kwargs.use_labels:
+ # raise Exception('check here')
+
+ # training_set_kwargs.use_labels = args.cond
+ training_set_kwargs.use_labels = True
+ training_set_kwargs.xflip = False
+ training_set_kwargs.random_seed = SEED
+ # desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
+
+ # * construct ffhq/afhq dataset
+ training_set = dnnlib.util.construct_class_by_name(
+ **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ training_set = dnnlib.util.construct_class_by_name(
+ **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ training_set_sampler = InfiniteSampler(
+ dataset=training_set,
+ rank=dist_util.get_rank(),
+ num_replicas=dist_util.get_world_size(),
+ seed=SEED)
+
+ data = iter(
+ th.utils.data.DataLoader(dataset=training_set,
+ sampler=training_set_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ num_workers=args.num_workers,))
+ # prefetch_factor=2))
+
+ eval_data = th.utils.data.DataLoader(dataset=Subset(training_set, np.arange(10)),
+ batch_size=args.eval_batch_size,
+ num_workers=1)
+
+ args.img_size = [args.image_size_encoder]
+ # try dry run
+ # batch = next(data)
+ # batch = None
+
+ # logger.log("creating model and diffusion...")
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
+
+ opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
+ loss_class = E3DGELossClass(device, opt).to(device)
+
+ # writer = SummaryWriter() # TODO, add log dir
+
+ logger.log("training...")
+
+ TrainLoop = {
+ 'cvD': nsr.TrainLoop3DcvD,
+ 'nvsD': nsr.TrainLoop3DcvD_nvsD,
+ 'cano_nvs_cvD': nsr.TrainLoop3DcvD_nvsD_canoD,
+ 'canoD': nsr.TrainLoop3DcvD_canoD
+ }[args.trainer_name]
+
+ TrainLoop(rec_model=auto_encoder,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ **vars(args)).run_loop() # ! overfitting
+
+
+def create_argparser(**kwargs):
+ # defaults.update(model_and_diffusion_defaults())
+
+ defaults = dict(
+ dataset_size=-1,
+ trainer_name='cvD',
+ use_amp=False,
+ overfitting=False,
+ num_workers=4,
+ image_size=128,
+ image_size_encoder=224,
+ iterations=150000,
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ eval_batch_size=12,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ data_dir="",
+ eval_data_dir="",
+ # load_depth=False, # TODO
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ resume_checkpoint_EG3D="",
+ )
+
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+ os.environ[
+ "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+ os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
+
+ # master_addr = '127.0.0.1'
+ # master_port = dist_util._find_free_port()
+ # master_port = 31323
+
+ args = create_argparser().parse_args()
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ opts = args
+
+ args.rendering_kwargs = rendering_options_defaults(opts)
+
+ # print(args)
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ # Launch processes.
+ print('Launching processes...')
+
+ try:
+ training_loop(args)
+ # except KeyboardInterrupt as e:
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ dist_util.cleanup() # clean port and socket when ctrl+c
diff --git a/scripts/vit_triplane_diffusion_sample.py b/scripts/vit_triplane_diffusion_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..31782a18f22718ca35d7f75edfa94f59fe8a4023
--- /dev/null
+++ b/scripts/vit_triplane_diffusion_sample.py
@@ -0,0 +1,427 @@
+"""
+Generate a large batch of image samples from a model and save them as a large
+numpy array. This can be used to produce samples for FID evaluation.
+"""
+
+import argparse
+import json
+import sys
+import os
+
+sys.path.append('.')
+
+from pdb import set_trace as st
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.script_util import (
+ NUM_CLASSES,
+ model_and_diffusion_defaults,
+ create_model_and_diffusion,
+ add_dict_to_argparser,
+ args_to_dict,
+ continuous_diffusion_defaults,
+ control_net_defaults,
+)
+
+from pathlib import Path
+
+from tqdm import tqdm, trange
+import dnnlib
+from dnnlib.util import EasyDict, InfiniteSampler
+from nsr.train_util_diffusion import TrainLoop3DDiffusion as TrainLoop
+from guided_diffusion.continuous_diffusion import make_diffusion as make_sde_diffusion
+import nsr
+import nsr.lsgm
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, AE_with_Diffusion, rendering_options_defaults, eg3d_options_default, dataset_defaults
+
+from datasets.shapenet import load_eval_data
+from torch.utils.data import Subset
+from datasets.eg3d_dataset import init_dataset_kwargs
+from datasets.eg3d_dataset import LMDBDataset_MV_Compressed_eg3d
+
+SEED = 0
+
+
+def main(args):
+
+ # args.rendering_kwargs = rendering_options_defaults(args)
+
+ dist_util.setup_dist(args)
+ logger.configure(dir=args.logdir)
+
+ th.cuda.empty_cache()
+
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+
+ # * set denoise model args
+ logger.log("creating model and diffusion...")
+ args.img_size = [args.image_size_encoder]
+ # ! no longer required for LDM
+ # args.denoise_in_channels = args.out_chans
+ # args.denoise_out_channels = args.out_chans
+ args.image_size = args.image_size_encoder # 224, follow the triplane size
+
+ denoise_model, diffusion = create_model_and_diffusion(
+ **args_to_dict(args,
+ model_and_diffusion_defaults().keys()))
+
+ if 'cldm' in args.trainer_name:
+ assert isinstance(denoise_model, tuple)
+ denoise_model, controlNet = denoise_model
+
+ controlNet.to(dist_util.dev())
+ controlNet.train()
+ else:
+ controlNet = None
+
+ opts = eg3d_options_default()
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ # denoise_model.load_state_dict(
+ # dist_util.load_state_dict(args.ddpm_model_path, map_location="cpu"))
+ denoise_model.to(dist_util.dev())
+ if args.use_fp16:
+ denoise_model.convert_to_fp16()
+ denoise_model.eval()
+
+ # * auto-encoder reconstruction model
+ logger.log("creating 3DAE...")
+ auto_encoder = create_3DAE_model(
+ **args_to_dict(args,
+ encoder_and_nsr_defaults().keys()))
+
+ # logger.log("AE triplane decoder reuses G_ema decoder...")
+ # auto_encoder.decoder.register_buffer('w_avg', G_ema.backbone.mapping.w_avg)
+
+ # print(auto_encoder.decoder.w_avg.shape) # [512]
+
+ # auto_encoder.load_state_dict(
+ # dist_util.load_state_dict(args.rec_model_path, map_location="cpu"))
+
+ auto_encoder.to(dist_util.dev())
+ auto_encoder.eval()
+
+ # TODO, how to set the scale?
+ logger.log("create dataset")
+
+ # data = None
+
+ if args.objv_dataset:
+ from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data, load_wds_data
+ else: # shapenet
+ from datasets.shapenet import load_data, load_eval_data, load_memory_data
+
+ eval_data = None
+
+ # if args.cfg in ('afhq', 'ffhq'):
+ # # ! load data
+ # if args.use_lmdb:
+ # logger.log("creating LMDB eg3d data loader...")
+ # training_set = LMDBDataset_MV_Compressed_eg3d(
+ # args.data_dir,
+ # args.image_size,
+ # args.image_size_encoder,
+ # )
+ # else:
+
+ # logger.log("creating eg3d data loader...")
+ # training_set_kwargs, dataset_name = init_dataset_kwargs(
+ # data=args.data_dir,
+ # class_name='datasets.eg3d_dataset.ImageFolderDataset'
+ # ) # only load pose here
+ # # if args.cond and not training_set_kwargs.use_labels:
+ # # raise Exception('check here')
+
+ # # training_set_kwargs.use_labels = args.cond
+ # training_set_kwargs.use_labels = True
+ # training_set_kwargs.xflip = True
+ # training_set_kwargs.random_seed = SEED
+ # # desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
+
+ # # * construct ffhq/afhq dataset
+ # training_set = dnnlib.util.construct_class_by_name(
+ # **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ # training_set = dnnlib.util.construct_class_by_name(
+ # **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ # # training_set_sampler = InfiniteSampler(
+ # # dataset=training_set,
+ # # rank=dist_util.get_rank(),
+ # # num_replicas=dist_util.get_world_size(),
+ # # seed=SEED)
+
+ # # data = iter(
+ # # th.utils.data.DataLoader(dataset=training_set,
+ # # sampler=training_set_sampler,
+ # # batch_size=args.batch_size,
+ # # pin_memory=True,
+ # # num_workers=args.num_workers,))
+ # # # prefetch_factor=2))
+
+ # # training_set_sampler = InfiniteSampler(
+ # # dataset=training_set,
+ # # rank=dist_util.get_rank(),
+ # # num_replicas=dist_util.get_world_size(),
+ # # seed=SEED)
+
+ # # data = iter(
+ # # th.utils.data.DataLoader(
+ # # dataset=training_set,
+ # # sampler=training_set_sampler,
+ # # batch_size=args.batch_size,
+ # # pin_memory=True,
+ # # num_workers=args.num_workers,
+ # # persistent_workers=args.num_workers > 0,
+ # # # prefetch_factor=max(8//args.batch_size, 2),
+ # # ))
+
+ # eval_data = th.utils.data.DataLoader(dataset=Subset(
+ # training_set, np.arange(25)),
+ # batch_size=args.eval_batch_size,
+ # num_workers=1)
+
+ # else:
+
+ # logger.log("creating data loader...")
+
+ # # if args.objv_dataset:
+ # # from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data
+ # # else: # shapenet
+ # # from datasets.shapenet import load_data, load_eval_data, load_memory_data
+
+ # # eval_data = load_eval_data(
+ # # file_path=args.eval_data_dir,
+ # # batch_size=args.eval_batch_size,
+ # # reso=args.image_size,
+ # # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # # num_workers=args.num_workers,
+ # # load_depth=True, # for evaluation
+ # # interval=args.interval,
+ # # use_lmdb=args.use_lmdb,
+ # # )
+
+ # if args.use_wds:
+ # if args.eval_data_dir == 'NONE':
+ # with open(args.eval_shards_lst) as f:
+ # eval_shards_lst = [url.strip() for url in f.readlines()]
+ # else:
+ # eval_shards_lst = args.eval_data_dir # auto expanded
+
+ # eval_data = load_wds_data(
+ # eval_shards_lst, args.image_size, args.image_size_encoder,
+ # args.eval_batch_size, args.num_workers,
+ # **args_to_dict(args,
+ # dataset_defaults().keys()))
+
+ # else:
+ # eval_data = load_eval_data(
+ # file_path=args.eval_data_dir,
+ # batch_size=args.eval_batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # # load_depth=True, # for evaluation
+ # **args_to_dict(args,
+ # dataset_defaults().keys()))
+
+ TrainLoop = {
+ 'adm': nsr.TrainLoop3DDiffusion,
+ 'vpsde_crossattn': nsr.lsgm.TrainLoop3DDiffusionLSGM_crossattn,
+ }[args.trainer_name]
+
+ # continuous
+ if 'vpsde' in args.trainer_name:
+ sde_diffusion = make_sde_diffusion(
+ dnnlib.EasyDict(
+ args_to_dict(args,
+ continuous_diffusion_defaults().keys())))
+ assert args.mixed_prediction, 'enable mixed_prediction by default'
+ logger.log('create VPSDE diffusion.')
+ else:
+ sde_diffusion = None
+
+ # if 'cldm' in args.trainer_name:
+ # assert isinstance(denoise_model, tuple)
+ # denoise_model, controlNet = denoise_model
+
+ # controlNet.to(dist_util.dev())
+ # controlNet.train()
+ # else:
+ # controlNet = None
+
+ training_loop_class = TrainLoop(rec_model=auto_encoder,
+ denoise_model=denoise_model,
+ control_model=controlNet,
+ diffusion=diffusion,
+ sde_diffusion=sde_diffusion,
+ loss_class=None,
+ data=None,
+ eval_data=eval_data,
+ **vars(args))
+
+ logger.log("sampling...")
+ dist_util.synchronize()
+
+ # all_images = []
+ # all_labels = []
+ # while len(all_images) * args.batch_size < args.num_samples:
+
+ if dist_util.get_rank() == 0:
+
+ (Path(logger.get_dir()) / 'FID_Cals').mkdir(exist_ok=True,
+ parents=True)
+
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ # load eval pose
+ if args.cfg == 'ffhq':
+ camera = th.load('assets/ffhq_eval_pose.pt',
+ map_location=dist_util.dev())[:]
+ elif args.cfg == 'shapenet':
+ camera = th.load('assets/shapenet_eval_pose.pt',
+ map_location=dist_util.dev())[:]
+
+ for sample_idx in trange(args.num_samples):
+ model_kwargs = {}
+
+ # if args.class_cond:
+ # classes = th.randint(low=0,
+ # high=NUM_CLASSES,
+ # size=(args.batch_size, ),
+ # device=dist_util.dev())
+ # model_kwargs["y"] = classes
+ training_loop_class.step = sample_idx # save to different position
+ if args.create_controlnet or 'crossattn' in args.trainer_name:
+ training_loop_class.eval_cldm(
+ prompt=args.prompt,
+ unconditional_guidance_scale=args.
+ unconditional_guidance_scale,
+ use_ddim=args.use_ddim,
+ save_img=args.save_img,
+ use_train_trajectory=args.use_train_trajectory,
+ export_mesh=args.export_mesh,
+ camera=camera,
+ overwrite_diff_inp_size=args.overwrite_diff_inp_size,
+ # training_loop_class.rec_model,
+ # training_loop_class.ddpm_model
+ )
+ else:
+ # evaluate ldm
+ training_loop_class.eval_ddpm_sample(
+ training_loop_class.rec_model,
+ save_img=args.save_img,
+ use_train_trajectory=args.use_train_trajectory,
+ export_mesh=args.export_mesh,
+ camera=camera,
+ # training_loop_class.ddpm_model
+ )
+
+ dist.barrier()
+ logger.log("sampling complete")
+
+
+def create_argparser():
+ defaults = dict(
+ image_size_encoder=224,
+ triplane_scaling_divider=1.0, # divide by this value
+ diffusion_input_size=-1,
+ trainer_name='adm',
+ use_amp=False,
+ # triplane_scaling_divider=1.0, # divide by this value
+
+ # * sampling flags
+ clip_denoised=False,
+ num_samples=10,
+ use_ddim=False,
+ ddpm_model_path="",
+ cldm_model_path="",
+ rec_model_path="",
+
+ # * eval logging flags
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ data_dir="",
+ eval_data_dir="",
+ eval_batch_size=1,
+ num_workers=1,
+
+ # * training flags for loading TrainingLoop class
+ overfitting=False,
+ image_size=128,
+ iterations=150000,
+ schedule_sampler="uniform",
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ resume_cldm_checkpoint="",
+ resume_checkpoint_EG3D="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ load_submodule_name='', # for loading pretrained auto_encoder model
+ ignore_resume_opt=False,
+ freeze_ae=False,
+ denoised_ae=True,
+ # inference prompt
+ prompt="a red chair",
+ interval=1,
+ objv_dataset=False,
+ use_lmdb=False,
+ save_img=False,
+ use_train_trajectory=
+ False, # use train trajectory to sample images for fid calculation
+ unconditional_guidance_scale=1.0,
+ cond_key='img_sr',
+ use_eos_feature=False,
+ export_mesh=False,
+ overwrite_diff_inp_size=None,
+ )
+
+ defaults.update(model_and_diffusion_defaults())
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+ defaults.update(continuous_diffusion_defaults())
+ defaults.update(control_net_defaults())
+ defaults.update(dataset_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+
+ # os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
+ # os.environ["NCCL_DEBUG"] = "INFO"
+
+ os.environ[
+ "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+
+ args = create_argparser().parse_args()
+
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ args.rendering_kwargs = rendering_options_defaults(args)
+
+ main(args)
diff --git a/scripts/vit_triplane_diffusion_sample_objaverse.py b/scripts/vit_triplane_diffusion_sample_objaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..30b18be3b246cb9b1f3d878a47fac3e22affb81b
--- /dev/null
+++ b/scripts/vit_triplane_diffusion_sample_objaverse.py
@@ -0,0 +1,364 @@
+"""
+Generate a large batch of image samples from a model and save them as a large
+numpy array. This can be used to produce samples for FID evaluation.
+"""
+
+import argparse
+import json
+import sys
+import os
+
+sys.path.append('.')
+
+from pdb import set_trace as st
+import imageio
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.script_util import (
+ NUM_CLASSES,
+ model_and_diffusion_defaults,
+ create_model_and_diffusion,
+ add_dict_to_argparser,
+ args_to_dict,
+ continuous_diffusion_defaults,
+ control_net_defaults,
+)
+
+th.backends.cuda.matmul.allow_tf32 = True
+th.backends.cudnn.allow_tf32 = True
+th.backends.cudnn.enabled = True
+
+from pathlib import Path
+
+from tqdm import tqdm, trange
+import dnnlib
+from nsr.train_util_diffusion import TrainLoop3DDiffusion as TrainLoop
+from guided_diffusion.continuous_diffusion import make_diffusion as make_sde_diffusion
+import nsr
+import nsr.lsgm
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, AE_with_Diffusion, rendering_options_defaults, eg3d_options_default, dataset_defaults
+
+from datasets.shapenet import load_eval_data
+from torch.utils.data import Subset
+from datasets.eg3d_dataset import init_dataset_kwargs
+
+SEED = 0
+
+
+def main(args):
+
+ # args.rendering_kwargs = rendering_options_defaults(args)
+
+ dist_util.setup_dist(args)
+ logger.configure(dir=args.logdir)
+
+ th.cuda.empty_cache()
+
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+
+ # * set denoise model args
+ logger.log("creating model and diffusion...")
+ args.img_size = [args.image_size_encoder]
+ # ! no longer required for LDM
+ # args.denoise_in_channels = args.out_chans
+ # args.denoise_out_channels = args.out_chans
+ args.image_size = args.image_size_encoder # 224, follow the triplane size
+
+ denoise_model, diffusion = create_model_and_diffusion(
+ **args_to_dict(args,
+ model_and_diffusion_defaults().keys()))
+
+ if 'cldm' in args.trainer_name:
+ assert isinstance(denoise_model, tuple)
+ denoise_model, controlNet = denoise_model
+
+ controlNet.to(dist_util.dev())
+ controlNet.train()
+ else:
+ controlNet = None
+
+ opts = eg3d_options_default()
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ # denoise_model.load_state_dict(
+ # dist_util.load_state_dict(args.ddpm_model_path, map_location="cpu"))
+ denoise_model.to(dist_util.dev())
+ if args.use_fp16:
+ denoise_model.convert_to_fp16()
+ denoise_model.eval()
+
+ # * auto-encoder reconstruction model
+ logger.log("creating 3DAE...")
+ auto_encoder = create_3DAE_model(
+ **args_to_dict(args,
+ encoder_and_nsr_defaults().keys()))
+
+ # logger.log("AE triplane decoder reuses G_ema decoder...")
+ # auto_encoder.decoder.register_buffer('w_avg', G_ema.backbone.mapping.w_avg)
+
+ # print(auto_encoder.decoder.w_avg.shape) # [512]
+
+ # auto_encoder.load_state_dict(
+ # dist_util.load_state_dict(args.rec_model_path, map_location="cpu"))
+
+ auto_encoder.to(dist_util.dev())
+ auto_encoder.eval()
+
+ # TODO, how to set the scale?
+ logger.log("create dataset")
+
+ if args.objv_dataset:
+ from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data, load_wds_data
+ else: # shapenet
+ from datasets.shapenet import load_data, load_eval_data, load_memory_data
+
+ # if args.cfg in ('afhq', 'ffhq'):
+ # # ! load data
+ # logger.log("creating eg3d data loader...")
+ # training_set_kwargs, dataset_name = init_dataset_kwargs(
+ # data=args.data_dir,
+ # class_name='datasets.eg3d_dataset.ImageFolderDataset'
+ # ) # only load pose here
+ # # if args.cond and not training_set_kwargs.use_labels:
+ # # raise Exception('check here')
+
+ # # training_set_kwargs.use_labels = args.cond
+ # training_set_kwargs.use_labels = True
+ # training_set_kwargs.xflip = True
+ # training_set_kwargs.random_seed = SEED
+ # # desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
+
+ # # * construct ffhq/afhq dataset
+ # training_set = dnnlib.util.construct_class_by_name(
+ # **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ # training_set = dnnlib.util.construct_class_by_name(
+ # **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ # # training_set_sampler = InfiniteSampler(
+ # # dataset=training_set,
+ # # rank=dist_util.get_rank(),
+ # # num_replicas=dist_util.get_world_size(),
+ # # seed=SEED)
+
+ # # data = iter(
+ # # th.utils.data.DataLoader(dataset=training_set,
+ # # sampler=training_set_sampler,
+ # # batch_size=args.batch_size,
+ # # pin_memory=True,
+ # # num_workers=args.num_workers,))
+ # # # prefetch_factor=2))
+
+ # eval_data = th.utils.data.DataLoader(dataset=Subset(
+ # training_set, np.arange(25)),
+ # batch_size=args.eval_batch_size,
+ # num_workers=1)
+
+ # else:
+
+ # logger.log("creating data loader...")
+
+ # if args.use_wds:
+ # if args.eval_data_dir == 'NONE':
+ # with open(args.eval_shards_lst) as f:
+ # eval_shards_lst = [url.strip() for url in f.readlines()]
+ # else:
+ # eval_shards_lst = args.eval_data_dir # auto expanded
+
+ # eval_data = load_wds_data(
+ # eval_shards_lst, args.image_size, args.image_size_encoder,
+ # args.eval_batch_size, args.num_workers,
+ # **args_to_dict(args,
+ # dataset_defaults().keys()))
+
+ # else:
+ # eval_data = load_eval_data(
+ # file_path=args.eval_data_dir,
+ # batch_size=args.eval_batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # # load_depth=True, # for evaluation
+ # **args_to_dict(args,
+ # dataset_defaults().keys()))
+
+ TrainLoop = {
+ # 'adm': nsr.TrainLoop3DDiffusion,
+ # 'vpsde_ldm': nsr.lsgm.TrainLoop3D_LDM,
+ # 'dit': nsr.TrainLoop3DDiffusionDiT,
+ # lsgm
+ 'vpsde_crossattn': nsr.lsgm.TrainLoop3DDiffusionLSGM_crossattn,
+ 'vpsde_crossattn_objv': nsr.crossattn_cldm_objv.TrainLoop3DDiffusionLSGM_crossattn, # for api compat
+ }[args.trainer_name]
+
+ # continuous
+ if 'vpsde' in args.trainer_name:
+ sde_diffusion = make_sde_diffusion(
+ dnnlib.EasyDict(
+ args_to_dict(args,
+ continuous_diffusion_defaults().keys())))
+ # assert args.mixed_prediction, 'enable mixed_prediction by default'
+ logger.log('create VPSDE diffusion.')
+ else:
+ sde_diffusion = None
+
+ auto_encoder.decoder.rendering_kwargs = args.rendering_kwargs
+
+ training_loop_class = TrainLoop(rec_model=auto_encoder,
+ denoise_model=denoise_model,
+ control_model=controlNet,
+ diffusion=diffusion,
+ sde_diffusion=sde_diffusion,
+ loss_class=None,
+ data=None,
+ # eval_data=eval_data,
+ eval_data=None,
+ **vars(args))
+
+ logger.log("sampling...")
+ dist_util.synchronize()
+
+ # all_images = []
+ # all_labels = []
+ # while len(all_images) * args.batch_size < args.num_samples:
+
+ if dist_util.get_rank() == 0:
+
+ (Path(logger.get_dir()) / 'FID_Cals').mkdir(exist_ok=True,
+ parents=True)
+
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ # ! use pre-saved camera pose form g-buffer objaverse
+ camera = th.load('assets/objv_eval_pose.pt', map_location=dist_util.dev())[:]
+
+ if args.create_controlnet or 'crossattn' in args.trainer_name:
+ training_loop_class.eval_cldm(
+ prompt=args.prompt,
+ unconditional_guidance_scale=args.
+ unconditional_guidance_scale,
+ use_ddim=args.use_ddim,
+ save_img=args.save_img,
+ use_train_trajectory=args.use_train_trajectory,
+ camera=camera,
+ num_instances=args.num_instances,
+ num_samples=args.num_samples,
+ export_mesh=args.export_mesh,
+ # training_loop_class.rec_model,
+ # training_loop_class.ddpm_model
+ )
+ else:
+ # evaluate ldm
+ training_loop_class.eval_ddpm_sample(
+ training_loop_class.rec_model,
+ save_img=args.save_img,
+ use_train_trajectory=args.use_train_trajectory,
+ export_mesh=args.export_mesh,
+ # training_loop_class.ddpm_model
+ )
+
+ dist.barrier()
+ logger.log("sampling complete")
+
+
+def create_argparser():
+ defaults = dict(
+ image_size_encoder=224,
+ triplane_scaling_divider=1.0, # divide by this value
+ diffusion_input_size=-1,
+ trainer_name='adm',
+ use_amp=False,
+ # triplane_scaling_divider=1.0, # divide by this value
+
+ # * sampling flags
+ clip_denoised=False,
+ num_samples=10,
+ num_instances=10, # for i23d, loop different condition
+ use_ddim=False,
+ ddpm_model_path="",
+ cldm_model_path="",
+ rec_model_path="",
+
+ # * eval logging flags
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ data_dir="",
+ eval_data_dir="",
+ eval_batch_size=1,
+ num_workers=1,
+
+ # * training flags for loading TrainingLoop class
+ overfitting=False,
+ image_size=128,
+ iterations=150000,
+ schedule_sampler="uniform",
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ resume_cldm_checkpoint="",
+ resume_checkpoint_EG3D="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ load_submodule_name='', # for loading pretrained auto_encoder model
+ ignore_resume_opt=False,
+ freeze_ae=False,
+ denoised_ae=True,
+ # inference prompt
+ prompt="a red chair",
+ interval=1,
+ save_img=False,
+ use_train_trajectory=
+ False, # use train trajectory to sample images for fid calculation
+ unconditional_guidance_scale=1.0,
+ use_eos_feature=False,
+ export_mesh=False,
+ cond_key='caption',
+ )
+
+ defaults.update(model_and_diffusion_defaults())
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+ defaults.update(continuous_diffusion_defaults())
+ defaults.update(control_net_defaults())
+ defaults.update(dataset_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+
+ # os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
+ # os.environ["NCCL_DEBUG"] = "INFO"
+
+ os.environ[
+ "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+
+ args = create_argparser().parse_args()
+
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ args.rendering_kwargs = rendering_options_defaults(args)
+
+ main(args)
diff --git a/scripts/vit_triplane_diffusion_train.py b/scripts/vit_triplane_diffusion_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..7947d4272ed868f1d1a9f51094c6734e70f30e2b
--- /dev/null
+++ b/scripts/vit_triplane_diffusion_train.py
@@ -0,0 +1,426 @@
+"""
+Train a diffusion model on images.
+"""
+import json
+import sys
+import os
+
+sys.path.append('.')
+
+# from dnnlib import EasyDict
+import traceback
+
+import torch as th
+import torch.multiprocessing as mp
+import torch.distributed as dist
+import numpy as np
+
+import argparse
+import dnnlib
+from guided_diffusion import dist_util, logger
+from guided_diffusion.resample import create_named_schedule_sampler
+from guided_diffusion.script_util import (
+ args_to_dict,
+ add_dict_to_argparser,
+ continuous_diffusion_defaults,
+ control_net_defaults,
+ model_and_diffusion_defaults,
+ create_model_and_diffusion,
+)
+from guided_diffusion.continuous_diffusion import make_diffusion as make_sde_diffusion
+import nsr
+import nsr.lsgm
+# from nsr.train_util_diffusion import TrainLoop3DDiffusion as TrainLoop
+
+from datasets.eg3d_dataset import LMDBDataset_MV_Compressed_eg3d
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default
+from datasets.shapenet import load_data, load_eval_data, load_memory_data
+from nsr.losses.builder import E3DGELossClass
+
+from utils.torch_utils import legacy, misc
+from torch.utils.data import Subset
+from pdb import set_trace as st
+
+from dnnlib.util import EasyDict, InfiniteSampler
+# from .vit_triplane_train_FFHQ import init_dataset_kwargs
+from datasets.eg3d_dataset import init_dataset_kwargs
+
+# from torch.utils.tensorboard import SummaryWriter
+
+SEED = 0
+
+
+def training_loop(args):
+ # def training_loop(args):
+ logger.log("dist setup...")
+ # th.autograd.set_detect_anomaly(False) # type: ignore
+ th.autograd.set_detect_anomaly(True) # type: ignore
+
+ th.cuda.set_device(
+ args.local_rank) # set this line to avoid extra memory on rank 0
+ th.cuda.empty_cache()
+
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+
+ dist_util.setup_dist(args)
+
+ # st() # mark
+
+ # logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
+ logger.configure(dir=args.logdir)
+
+ logger.log("creating ViT encoder and NSR decoder...")
+ # st() # mark
+ device = dist_util.dev()
+
+ args.img_size = [args.image_size_encoder]
+
+ logger.log("creating model and diffusion...")
+ # * set denoise model args
+
+ if args.denoise_in_channels == -1:
+ args.diffusion_input_size = args.image_size_encoder
+ args.denoise_in_channels = args.out_chans
+ args.denoise_out_channels = args.out_chans
+ else:
+ assert args.denoise_out_channels != -1
+
+ # args.image_size = args.image_size_encoder # 224, follow the triplane size
+
+ # if args.diffusion_input_size == -1:
+ # else:
+ # args.image_size = args.diffusion_input_size
+
+ if args.pred_type == 'v': # for lsgm training
+ assert args.predict_v == True # for DDIM sampling
+
+ denoise_model, diffusion = create_model_and_diffusion(
+ **args_to_dict(args,
+ model_and_diffusion_defaults().keys()))
+
+ opts = eg3d_options_default()
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ logger.log("creating encoder and NSR decoder...")
+ auto_encoder = create_3DAE_model(
+ **args_to_dict(args,
+ encoder_and_nsr_defaults().keys()))
+
+ auto_encoder.to(device)
+ auto_encoder.eval()
+
+ # * load G_ema modules into autoencoder
+ # * clone G_ema.decoder to auto_encoder triplane
+ # logger.log("AE triplane decoder reuses G_ema decoder...")
+ # auto_encoder.decoder.register_buffer('w_avg', G_ema.backbone.mapping.w_avg)
+
+ # auto_encoder.decoder.triplane_decoder.decoder.load_state_dict( # type: ignore
+ # G_ema.decoder.state_dict()) # type: ignore
+
+ # set grad=False in this manner suppresses the DDP forward no grad error.
+
+ # if args.sr_training:
+
+ # logger.log("AE triplane decoder reuses G_ema SR module...")
+ # # auto_encoder.decoder.triplane_decoder.superresolution.load_state_dict( # type: ignore
+ # # G_ema.superresolution.state_dict()) # type: ignore
+
+ # # set grad=False in this manner suppresses the DDP forward no grad error.
+ # logger.log("freeze SR module...")
+ # for param in auto_encoder.decoder.superresolution.parameters(): # type: ignore
+ # param.requires_grad_(False)
+
+ # # del G_ema
+ # th.cuda.empty_cache()
+
+ if args.freeze_triplane_decoder:
+ logger.log("freeze triplane decoder...")
+ for param in auto_encoder.decoder.triplane_decoder.parameters(
+ ): # type: ignore
+ # for param in auto_encoder.decoder.triplane_decoder.decoder.parameters(): # type: ignore
+ param.requires_grad_(False)
+
+
+ if args.cfg in ('afhq', 'ffhq'):
+
+ if args.sr_training:
+
+ logger.log("AE triplane decoder reuses G_ema SR module...")
+ auto_encoder.decoder.triplane_decoder.superresolution.load_state_dict( # type: ignore
+ G_ema.superresolution.state_dict()) # type: ignore
+
+ # set grad=False in this manner suppresses the DDP forward no grad error.
+ for param in auto_encoder.decoder.triplane_decoder.superresolution.parameters(
+ ): # type: ignore
+ param.requires_grad_(False)
+
+ # ! load data
+ if args.use_lmdb:
+ logger.log("creating LMDB eg3d data loader...")
+ training_set = LMDBDataset_MV_Compressed_eg3d(
+ args.data_dir,
+ args.image_size,
+ args.image_size_encoder,
+ )
+ else:
+ logger.log("creating eg3d data loader...")
+
+ training_set_kwargs, dataset_name = init_dataset_kwargs(data=args.data_dir,
+ class_name='datasets.eg3d_dataset.ImageFolderDataset',
+ reso_gt=args.image_size) # only load pose here
+ # if args.cond and not training_set_kwargs.use_labels:
+ # raise Exception('check here')
+
+ # training_set_kwargs.use_labels = args.cond
+ training_set_kwargs.use_labels = True
+ training_set_kwargs.xflip = False
+ training_set_kwargs.random_seed = SEED
+ training_set_kwargs.max_size = args.dataset_size
+ # desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
+
+ # * construct ffhq/afhq dataset
+ training_set = dnnlib.util.construct_class_by_name(
+ **training_set_kwargs) # subclass of training.dataset.Dataset
+
+ training_set_sampler = InfiniteSampler(
+ dataset=training_set,
+ rank=dist_util.get_rank(),
+ num_replicas=dist_util.get_world_size(),
+ seed=SEED)
+
+ data = iter(
+ th.utils.data.DataLoader(
+ dataset=training_set,
+ sampler=training_set_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ num_workers=args.num_workers,
+ persistent_workers=args.num_workers>0,
+ # prefetch_factor=max(8//args.batch_size, 2),
+ ))
+ # prefetch_factor=2))
+
+ eval_data = th.utils.data.DataLoader(dataset=Subset(
+ training_set, np.arange(8)),
+ batch_size=args.eval_batch_size,
+ num_workers=1)
+
+ else:
+
+ logger.log("creating data loader...")
+
+ if args.objv_dataset:
+ from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data
+ else: # shapenet
+ from datasets.shapenet import load_data, load_eval_data, load_memory_data
+
+
+ # TODO, load shapenet data
+ # data = load_data(
+ # st() mark
+ if args.overfitting:
+ logger.log("create overfitting memory dataset")
+ data = load_memory_data(
+ file_path=args.eval_data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True # for evaluation
+ )
+ else:
+ logger.log("create all instances dataset")
+ # st() mark
+ data = load_data(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=args.load_depth,
+ preprocess=auto_encoder.preprocess, # clip
+ dataset_size=args.dataset_size,
+ use_lmdb=args.use_lmdb,
+ trainer_name=args.trainer_name,
+ # load_depth=True # for evaluation
+ )
+
+ eval_data = load_eval_data(
+ file_path=args.eval_data_dir,
+ batch_size=args.eval_batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True, # for evaluation
+ interval=args.interval,
+ # use_lmdb=args.use_lmdb,
+ )
+
+ # let all processes sync up before starting with a new epoch of training
+
+ if dist_util.get_rank() == 0:
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ args.schedule_sampler = create_named_schedule_sampler(
+ args.schedule_sampler, diffusion)
+
+ opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
+ loss_class = E3DGELossClass(device, opt).to(device)
+
+ logger.log("training...")
+
+ TrainLoop = {
+ 'adm': nsr.TrainLoop3DDiffusion,
+ 'dit': nsr.TrainLoop3DDiffusionDiT,
+ 'ssd': nsr.TrainLoop3DDiffusionSingleStage,
+ # 'ssd_cvD': nsr.TrainLoop3DDiffusionSingleStagecvD,
+ 'ssd_cvD_sds': nsr.TrainLoop3DDiffusionSingleStagecvDSDS,
+ 'ssd_cvd_sds_no_separate_sds_step':
+ nsr.TrainLoop3DDiffusionSingleStagecvDSDS_sdswithrec,
+ 'vpsde_lsgm_noD': nsr.lsgm.TrainLoop3DDiffusionLSGM_noD, # use vpsde
+ 'vpsde_TrainLoop3DDiffusionLSGM_cvD': nsr.lsgm.TrainLoop3DDiffusionLSGM_cvD,
+ 'vpsde_TrainLoop3DDiffusionLSGM_cvD_scaling': nsr.lsgm.TrainLoop3DDiffusionLSGM_cvD_scaling,
+ 'vpsde_TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm': nsr.lsgm.TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm,
+ 'vpsde_TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD': nsr.lsgm.TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD,
+ 'vpsde_TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_weightingv0': nsr.lsgm.TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_weightingv0,
+ 'vpsde_TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED': nsr.lsgm.TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED,
+ 'vpsde_TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED_nv': nsr.lsgm.TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED_nv,
+ 'vpsde_lsgm_joint_noD': nsr.lsgm.TrainLoop3DDiffusionLSGMJointnoD, # use vpsde
+ 'vpsde_lsgm_joint_noD_ponly': nsr.lsgm.TrainLoop3DDiffusionLSGMJointnoD_ponly, # use vpsde
+ # control
+ 'vpsde_cldm':nsr.lsgm.TrainLoop3DDiffusionLSGM_Control,
+ 'vpsde_crossattn': nsr.lsgm.TrainLoop3DDiffusionLSGM_crossattn,
+ 'vpsde_ldm': nsr.lsgm.TrainLoop3D_LDM,
+ }[args.trainer_name]
+
+ if 'vpsde' in args.trainer_name:
+ sde_diffusion = make_sde_diffusion(
+ dnnlib.EasyDict(
+ args_to_dict(args,
+ continuous_diffusion_defaults().keys())))
+ assert args.mixed_prediction, 'enable mixed_prediction by default'
+ logger.log('create VPSDE diffusion.')
+ else:
+ sde_diffusion = None
+
+
+ if 'cldm' in args.trainer_name:
+ assert isinstance(denoise_model, tuple)
+ denoise_model, controlNet = denoise_model
+
+ controlNet.to(dist_util.dev())
+ controlNet.train()
+ else:
+ controlNet = None
+
+ # st()
+ denoise_model.to(dist_util.dev())
+ denoise_model.train()
+
+ TrainLoop(rec_model=auto_encoder,
+ denoise_model=denoise_model,
+ control_model=controlNet,
+ diffusion=diffusion,
+ sde_diffusion=sde_diffusion,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ **vars(args)).run_loop()
+
+ dist_util.synchronize()
+
+def create_argparser(**kwargs):
+ # defaults.update(model_and_diffusion_defaults())
+
+ defaults = dict(
+ dataset_size=-1,
+ diffusion_input_size=-1,
+ trainer_name='adm',
+ use_amp=False,
+ triplane_scaling_divider=1.0, # divide by this value
+ overfitting=False,
+ num_workers=4,
+ image_size=128,
+ image_size_encoder=224,
+ iterations=150000,
+ schedule_sampler="uniform",
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ eval_batch_size=12,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ resume_checkpoint_EG3D="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ data_dir="",
+ eval_data_dir="",
+ load_depth=True, # TODO
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ load_submodule_name='', # for loading pretrained auto_encoder model
+ ignore_resume_opt=False,
+ # freeze_ae=False,
+ denoised_ae=True,
+ diffusion_ce_anneal=False,
+ use_lmdb=False,
+ interval=1,
+ freeze_triplane_decoder=False,
+ objv_dataset=False,
+ cond_key='img_sr',
+ )
+
+ defaults.update(model_and_diffusion_defaults())
+ defaults.update(continuous_diffusion_defaults())
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+ defaults.update(control_net_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+ # os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
+ # os.environ["NCCL_DEBUG"] = "INFO"
+
+ os.environ[
+ "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+
+ args = create_argparser().parse_args()
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ # opts = dnnlib.EasyDict(vars(args)) # compatiable with triplane original settings
+ # opts = args
+ args.rendering_kwargs = rendering_options_defaults(args)
+
+ # Launch processes.
+ logger.log('Launching processes...')
+
+ logger.log('Available devices ', th.cuda.device_count())
+ logger.log('Current cuda device ', th.cuda.current_device())
+ # logger.log('GPU Device name:', th.cuda.get_device_name(th.cuda.current_device()))
+
+ try:
+ training_loop(args)
+ # except KeyboardInterrupt as e:
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ dist_util.cleanup() # clean port and socket when ctrl+c
diff --git a/scripts/vit_triplane_train.py b/scripts/vit_triplane_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..704c9db9d4ab8d4ae7de1d10c80f0a009b78f825
--- /dev/null
+++ b/scripts/vit_triplane_train.py
@@ -0,0 +1,351 @@
+"""
+Train a diffusion model on images.
+"""
+import random
+import json
+import sys
+import os
+
+sys.path.append('.')
+import torch.distributed as dist
+
+import traceback
+
+import torch as th
+
+import torch.multiprocessing as mp
+import numpy as np
+
+import argparse
+import dnnlib
+from guided_diffusion import dist_util, logger
+from guided_diffusion.script_util import (
+ args_to_dict,
+ add_dict_to_argparser,
+)
+# from nsr.train_util import TrainLoop3DRec as TrainLoop
+from nsr.train_nv_util import TrainLoop3DRecNV, TrainLoop3DRec, TrainLoop3DRecNVPatch, TrainLoop3DRecNVPatchSingleForward, TrainLoop3DRecNVPatchSingleForwardMV, TrainLoop3DRecNVPatchSingleForwardMVAdvLoss
+
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default, dataset_defaults
+from nsr.losses.builder import E3DGELossClass, E3DGE_with_AdvLoss
+
+from pdb import set_trace as st
+
+# th.backends.cuda.matmul.allow_tf32 = True # https://huggingface.co/docs/diffusers/optimization/fp16
+# th.backends.cuda.matmul.allow_tf32 = True
+# th.backends.cudnn.allow_tf32 = True
+# th.backends.cudnn.enabled = True
+
+enable_tf32 = th.backends.cuda.matmul.allow_tf32 # requires A100
+
+th.backends.cuda.matmul.allow_tf32 = enable_tf32
+th.backends.cudnn.allow_tf32 = enable_tf32
+th.backends.cudnn.enabled = True
+
+
+def training_loop(args):
+ # def training_loop(args):
+ dist_util.setup_dist(args)
+ # th.autograd.set_detect_anomaly(True) # type: ignore
+ th.autograd.set_detect_anomaly(False) # type: ignore
+ # https://blog.csdn.net/qq_41682740/article/details/126304613
+
+ SEED = args.seed
+
+ # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count())
+ logger.log(f"{args.local_rank=} init complete, seed={SEED}")
+ th.cuda.set_device(args.local_rank)
+ th.cuda.empty_cache()
+
+ # * deterministic algorithms flags
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+ random.seed(SEED)
+
+ # logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
+ logger.configure(dir=args.logdir)
+
+ logger.log("creating encoder and NSR decoder...")
+ # device = dist_util.dev()
+ device = th.device("cuda", args.local_rank)
+
+ # shared eg3d opts
+ opts = eg3d_options_default()
+
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+ auto_encoder = create_3DAE_model(
+ **args_to_dict(args,
+ encoder_and_nsr_defaults().keys()))
+ auto_encoder.to(device)
+ auto_encoder.train()
+
+ logger.log("creating data loader...")
+ # data = load_data(
+ # st()
+ if args.objv_dataset:
+ from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data, load_wds_data
+ else: # shapenet
+ from datasets.shapenet import load_data, load_eval_data, load_memory_data
+
+ if args.overfitting:
+ data = load_memory_data(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ # load_depth=args.depth_lambda > 0
+ # load_depth=True, # for evaluation
+ **args_to_dict(args,
+ dataset_defaults().keys()))
+ eval_data = None
+ else:
+ if args.use_wds:
+ # st()
+ if args.data_dir == 'NONE':
+ with open(args.shards_lst) as f:
+ shards_lst = [url.strip() for url in f.readlines()]
+ data = load_wds_data(
+ shards_lst, # type: ignore
+ args.image_size,
+ args.image_size_encoder,
+ args.batch_size,
+ args.num_workers,
+ # plucker_embedding=args.plucker_embedding,
+ # mv_input=args.mv_input,
+ # split_chunk_input=args.split_chunk_input,
+ **args_to_dict(args,
+ dataset_defaults().keys()))
+
+ elif not args.inference:
+ data = load_wds_data(args.data_dir,
+ args.image_size,
+ args.image_size_encoder,
+ args.batch_size,
+ args.num_workers,
+ plucker_embedding=args.plucker_embedding,
+ mv_input=args.mv_input,
+ split_chunk_input=args.split_chunk_input)
+ else:
+ data = None
+ # ! load eval
+
+ if args.eval_data_dir == 'NONE':
+ with open(args.eval_shards_lst) as f:
+ eval_shards_lst = [url.strip() for url in f.readlines()]
+ else:
+ eval_shards_lst = args.eval_data_dir # auto expanded
+
+ eval_data = load_wds_data(
+ eval_shards_lst, # type: ignore
+ args.image_size,
+ args.image_size_encoder,
+ args.eval_batch_size,
+ args.num_workers,
+ # decode_encode_img_only=args.decode_encode_img_only,
+ # plucker_embedding=args.plucker_embedding,
+ # load_wds_diff=False,
+ # mv_input=args.mv_input,
+ # split_chunk_input=args.split_chunk_input,
+ **args_to_dict(args,
+ dataset_defaults().keys()))
+ # load_instance=True) # TODO
+
+ else:
+
+ if args.inference:
+ data = None
+ else:
+ data = load_data(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True,
+ preprocess=auto_encoder.preprocess, # clip
+ dataset_size=args.dataset_size,
+ trainer_name=args.trainer_name,
+ use_lmdb=args.use_lmdb,
+ use_wds=args.use_wds,
+ use_lmdb_compressed=args.use_lmdb_compressed,
+ plucker_embedding=args.plucker_embedding
+ # load_depth=True # for evaluation
+ )
+
+ if args.pose_warm_up_iter > 0:
+ overfitting_dataset = load_memory_data(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ # load_depth=args.depth_lambda > 0
+ # load_depth=True # for evaluation
+ **args_to_dict(args,
+ dataset_defaults().keys()))
+ data = [data, overfitting_dataset, args.pose_warm_up_iter]
+
+ eval_data = load_eval_data(
+ file_path=args.eval_data_dir,
+ batch_size=args.eval_batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True, # for evaluation
+ preprocess=auto_encoder.preprocess,
+ # interval=args.interval,
+ # use_lmdb=args.use_lmdb,
+ # plucker_embedding=args.plucker_embedding,
+ # load_real=args.load_real,
+ # four_view_for_latent=args.four_view_for_latent,
+ # load_extra_36_view=args.load_extra_36_view,
+ # shuffle_across_cls=args.shuffle_across_cls,
+ **args_to_dict(args,
+ dataset_defaults().keys()))
+
+ logger.log("creating data loader done...")
+
+ args.img_size = [args.image_size_encoder]
+ # try dry run
+ # batch = next(data)
+ # batch = None
+
+ # logger.log("creating model and diffusion...")
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
+
+ opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
+ # opt.max_depth, opt.min_depth = args.rendering_kwargs.ray_end, args.rendering_kwargs.ray_start
+ if 'disc' in args.trainer_name:
+ loss_class = E3DGE_with_AdvLoss(
+ device,
+ opt,
+ # disc_weight=args.patchgan_disc, # rec_cvD_lambda
+ disc_factor=args.patchgan_disc_factor, # reduce D update speed
+ disc_weight=args.patchgan_disc_g_weight).to(device)
+ else:
+ loss_class = E3DGELossClass(device, opt).to(device)
+
+ # writer = SummaryWriter() # TODO, add log dir
+
+ logger.log("training...")
+
+ TrainLoop = {
+ 'input_rec': TrainLoop3DRec,
+ 'nv_rec': TrainLoop3DRecNV,
+ # 'nv_rec_patch': TrainLoop3DRecNVPatch,
+ 'nv_rec_patch': TrainLoop3DRecNVPatchSingleForward,
+ 'nv_rec_patch_mvE': TrainLoop3DRecNVPatchSingleForwardMV,
+ 'nv_rec_patch_mvE_disc': TrainLoop3DRecNVPatchSingleForwardMVAdvLoss, # default for objaverse
+ }[args.trainer_name]
+
+ logger.log("creating TrainLoop done...")
+
+ # th._dynamo.config.verbose=True # th212 required
+ # th._dynamo.config.suppress_errors = True
+ auto_encoder.decoder.rendering_kwargs = args.rendering_kwargs
+ train_loop = TrainLoop(
+ rec_model=auto_encoder,
+ loss_class=loss_class,
+ data=data,
+ eval_data=eval_data,
+ # compile=args.compile,
+ **vars(args))
+
+ if args.inference:
+ camera = th.load('assets/objv_eval_pose.pt', map_location=dist_util.dev())
+ train_loop.eval_novelview_loop(camera=camera,
+ save_latent=args.save_latent)
+ else:
+ train_loop.run_loop()
+
+
+def create_argparser(**kwargs):
+ # defaults.update(model_and_diffusion_defaults())
+
+ defaults = dict(
+ seed=0,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ use_amp=False,
+ overfitting=False,
+ num_workers=4,
+ image_size=128,
+ image_size_encoder=224,
+ iterations=150000,
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ eval_batch_size=12,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ data_dir="",
+ eval_data_dir="",
+ # load_depth=False, # TODO
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ # test warm up pose sampling training
+ pose_warm_up_iter=-1,
+ inference=False,
+ export_latent=False,
+ save_latent=False,
+ )
+
+ defaults.update(dataset_defaults()) # type: ignore
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+ # os.environ[
+ # "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+ # os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
+ # os.environ["NCCL_DEBUG"]="INFO"
+
+ args = create_argparser().parse_args()
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ # if os.environ['WORLD_SIZE'] > 1:
+ # args.global_rank = int(os.environ["RANK"])
+ args.gpus = th.cuda.device_count()
+
+ opts = args
+
+ args.rendering_kwargs = rendering_options_defaults(opts)
+
+ # print(args)
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ # Launch processes.
+ print('Launching processes...')
+
+ try:
+ training_loop(args)
+ # except KeyboardInterrupt as e:
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ dist_util.cleanup() # clean port and socket when ctrl+c
diff --git a/scripts/wds_create.py b/scripts/wds_create.py
new file mode 100644
index 0000000000000000000000000000000000000000..8725cc2fb607639b9d487c59d7cb4cdf84528053
--- /dev/null
+++ b/scripts/wds_create.py
@@ -0,0 +1,392 @@
+"""
+Train a diffusion model on images.
+"""
+# import imageio
+import gzip
+import random
+import json
+import sys
+import os
+import lmdb
+from tqdm import tqdm
+sys.path.append('.')
+import torch.distributed as dist
+import pickle
+import traceback
+from PIL import Image
+import torch as th
+import torch.multiprocessing as mp
+import lzma
+import webdataset as wds
+import numpy as np
+
+from torch.utils.data import DataLoader, Dataset
+import imageio.v3 as iio
+
+import argparse
+import dnnlib
+from guided_diffusion import dist_util, logger
+from guided_diffusion.script_util import (
+ args_to_dict,
+ add_dict_to_argparser,
+)
+# from nsr.train_util import TrainLoop3DRec as TrainLoop
+from nsr.train_nv_util import TrainLoop3DRecNV, TrainLoop3DRec, TrainLoop3DRecNVPatch
+from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default
+from datasets.shapenet import load_data, load_data_for_lmdb, load_eval_data, load_memory_data
+from nsr.losses.builder import E3DGELossClass
+from datasets.eg3d_dataset import init_dataset_kwargs
+
+# from .lmdb_create import encode_and_compress_image
+
+def encode_and_compress_image(inp_array, is_image=False, compress=True):
+ # Read the image using imageio
+ # image = imageio.v3.imread(image_path)
+
+ # Convert the image to bytes
+ # with io.BytesIO() as byte_buffer:
+ # imageio.imsave(byte_buffer, image, format="png")
+ # image_bytes = byte_buffer.getvalue()
+ if is_image:
+ inp_bytes = iio.imwrite("", inp_array, extension=".png")
+ else:
+ inp_bytes = inp_array.tobytes()
+
+ # Compress the image data using gzip
+ if compress:
+ compressed_data = gzip.compress(inp_bytes)
+ return compressed_data
+ else:
+ return inp_bytes
+
+
+
+
+from pdb import set_trace as st
+import bz2
+
+# th.backends.cuda.matmul.allow_tf32 = True # https://huggingface.co/docs/diffusers/optimization/fp16
+
+
+
+def training_loop(args):
+ # def training_loop(args):
+ dist_util.setup_dist(args)
+ # th.autograd.set_detect_anomaly(True) # type: ignore
+ th.autograd.set_detect_anomaly(False) # type: ignore
+ # https://blog.csdn.net/qq_41682740/article/details/126304613
+
+ SEED = args.seed
+
+ # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count())
+ logger.log(f"{args.local_rank=} init complete, seed={SEED}")
+ th.cuda.set_device(args.local_rank)
+ th.cuda.empty_cache()
+
+ # * deterministic algorithms flags
+ th.cuda.manual_seed_all(SEED)
+ np.random.seed(SEED)
+ random.seed(SEED)
+
+ # logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
+ logger.configure(dir=args.logdir)
+
+ logger.log("creating encoder and NSR decoder...")
+ # device = dist_util.dev()
+ device = th.device("cuda", args.local_rank)
+
+ # shared eg3d opts
+ opts = eg3d_options_default()
+
+ if args.sr_training:
+ args.sr_kwargs = dnnlib.EasyDict(
+ channel_base=opts.cbase,
+ channel_max=opts.cmax,
+ fused_modconv_default='inference_only',
+ use_noise=True
+ ) # ! close noise injection? since noise_mode='none' in eg3d
+
+
+ if args.objv_dataset:
+ from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data, load_data_for_lmdb
+ else: # shapenet
+ from datasets.shapenet import load_data, load_eval_data, load_memory_data, load_data_for_lmdb
+
+ # auto_encoder = create_3DAE_model(
+ # **args_to_dict(args,
+ # encoder_and_nsr_defaults().keys()))
+ # auto_encoder.to(device)
+ # auto_encoder.train()
+
+ logger.log("creating data loader...")
+ # data = load_data(
+ # st()
+ # if args.overfitting:
+ # data = load_memory_data(
+ # file_path=args.data_dir,
+ # batch_size=args.batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # # load_depth=args.depth_lambda > 0
+ # load_depth=True # for evaluation
+ # )
+ # else:
+ if args.cfg in ('afhq', 'ffhq'):
+ # ! load data
+ logger.log("creating eg3d data loader...")
+ training_set_kwargs, dataset_name = init_dataset_kwargs(data=args.data_dir,
+ class_name='datasets.eg3d_dataset.ImageFolderDatasetLMDB',
+ reso_gt=args.image_size) # only load pose here
+ # if args.cond and not training_set_kwargs.use_labels:
+ # raise Exception('check here')
+
+ # training_set_kwargs.use_labels = args.cond
+ training_set_kwargs.use_labels = True
+ training_set_kwargs.xflip = False
+ training_set_kwargs.random_seed = SEED
+ # training_set_kwargs.max_size = args.dataset_size
+ # desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
+
+ # * construct ffhq/afhq dataset
+ training_set = dnnlib.util.construct_class_by_name(
+ **training_set_kwargs) # subclass of training.dataset.Dataset
+ dataset_size = len(training_set)
+
+ # training_set_sampler = InfiniteSampler(
+ # dataset=training_set,
+ # rank=dist_util.get_rank(),
+ # num_replicas=dist_util.get_world_size(),
+ # seed=SEED)
+
+ data = DataLoader(
+ training_set,
+ shuffle=False,
+ batch_size=1,
+ num_workers=16,
+ drop_last=False,
+ # prefetch_factor=2,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+
+ else:
+ data, dataset_name, dataset_size, dataset = load_data_for_lmdb(
+ file_path=args.data_dir,
+ batch_size=args.batch_size,
+ reso=args.image_size,
+ reso_encoder=args.image_size_encoder, # 224 -> 128
+ num_workers=args.num_workers,
+ load_depth=True,
+ preprocess=None,
+ dataset_size=args.dataset_size,
+ trainer_name=args.trainer_name,
+ wds_output_path=os.path.join(logger.get_dir(), f'wds-%06d.tar')
+ # load_depth=True # for evaluation
+ )
+ # if args.pose_warm_up_iter > 0:
+ # overfitting_dataset = load_memory_data(
+ # file_path=args.data_dir,
+ # batch_size=args.batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # # load_depth=args.depth_lambda > 0
+ # load_depth=True # for evaluation
+ # )
+ # data = [data, overfitting_dataset, args.pose_warm_up_iter]
+ # eval_data = load_eval_data(
+ # file_path=args.eval_data_dir,
+ # batch_size=args.eval_batch_size,
+ # reso=args.image_size,
+ # reso_encoder=args.image_size_encoder, # 224 -> 128
+ # num_workers=args.num_workers,
+ # load_depth=True, # for evaluation
+ # preprocess=auto_encoder.preprocess)
+ args.img_size = [args.image_size_encoder]
+ # try dry run
+ # batch = next(data)
+ # batch = None
+
+ # logger.log("creating model and diffusion...")
+
+ # let all processes sync up before starting with a new epoch of training
+ dist_util.synchronize()
+
+ # schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
+
+ opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
+ # opt.max_depth, opt.min_depth = args.rendering_kwargs.ray_end, args.rendering_kwargs.ray_start
+ # loss_class = E3DGELossClass(device, opt).to(device)
+
+ # writer = SummaryWriter() # TODO, add log dir
+
+ logger.log("training...")
+
+ # TrainLoop = {
+ # 'input_rec': TrainLoop3DRec,
+ # 'nv_rec': TrainLoop3DRecNV,
+ # 'nv_rec_patch': TrainLoop3DRecNVPatch,
+ # }[args.trainer_name]
+
+ # TrainLoop(rec_model=auto_encoder,
+ # loss_class=loss_class,
+ # data=data,
+ # eval_data=eval_data,
+ # **vars(args)).run_loop() # ! overfitting
+
+
+ # Function to compress an image using gzip
+ # def compress_image_gzip(image_path):
+ # def encode_and_compress_image(inp_array, is_image=False):
+ # # Read the image using imageio
+ # # image = imageio.v3.imread(image_path)
+
+ # # Convert the image to bytes
+ # # with io.BytesIO() as byte_buffer:
+ # # imageio.imsave(byte_buffer, image, format="png")
+ # # image_bytes = byte_buffer.getvalue()
+ # if is_image:
+ # inp_bytes = iio.imwrite("", inp_array, extension=".png")
+ # else:
+ # inp_bytes = inp_array.tobytes()
+
+ # # Compress the image data using gzip
+ # compressed_data = gzip.compress(inp_bytes)
+
+ # return compressed_data
+
+
+ def convert_to_wds_compressed(dataset,dataset_loader, dataset_size, lmdb_path):
+ """
+ Convert a PyTorch dataset to LMDB format.
+
+ Parameters:
+ - dataset: PyTorch dataset
+ - lmdb_path: Path to store the LMDB database
+ """
+ # env = lmdb.open(lmdb_path, map_size=1024 ** 4, readahead=False) # Adjust map_size based on your dataset size
+ sink = wds.ShardWriter(lmdb_path)
+
+ # with env.begin(write=True) as txn:
+
+ # with env.begin(write=True) as txn:
+ # txn.put("length".encode("utf-8"), str(dataset_size).encode("utf-8"))
+
+ for idx, sample in enumerate(tqdm(dataset_loader)):
+ pass
+ # remove the batch index of returned dict sample
+
+ sample = {
+ k:v.squeeze(0).cpu().numpy() if isinstance(v, th.Tensor) else v[0] for k, v in sample.items()
+ # k:v.cpu().numpy() if isinstance(v, torch.Tensor) else v for k, v in sample.items()
+ }
+
+ # sample = dataset_loader[idx]
+ compressed_sample = {}
+ for k, v in sample.items():
+
+ # key = f'{idx}-{k}'.encode('utf-8')
+
+ if 'img' in k: # only bytes required? laod the 512 depth bytes only.
+ v = encode_and_compress_image(v, is_image=True, compress=False)
+ # elif 'depth' in k:
+ elif isinstance(v, str):
+ v = v.encode('utf-8') # caption
+ else: # regular bytes encoding
+ v = encode_and_compress_image(v.astype(np.float32), is_image=False, compress=False)
+
+ compressed_sample[k] = v
+
+ sink.write({
+ "__key__": "sample%08d" % idx,
+ # **{f'{k}.pyd': v for k, v in compressed_sample.items()}, # store as pickle, already compressed
+ 'sample.pyd': compressed_sample
+ })
+
+ # break
+ if idx > 100:
+ break
+
+ sink.close()
+
+
+ # convert_to_lmdb(data, os.path.join(logger.get_dir(), dataset_name)) convert_to_lmdb_compressed(data, os.path.join(logger.get_dir(), dataset_name))
+ # convert_to_lmdb_compressed(data, os.path.join(logger.get_dir()), dataset_size)
+ convert_to_wds_compressed(dataset, data, dataset_size, os.path.join(logger.get_dir(), f'wds-%06d.tar'))
+
+
+
+def create_argparser(**kwargs):
+ # defaults.update(model_and_diffusion_defaults())
+
+ defaults = dict(
+ seed=0,
+ dataset_size=-1,
+ trainer_name='input_rec',
+ use_amp=False,
+ overfitting=False,
+ num_workers=4,
+ image_size=128,
+ image_size_encoder=224,
+ iterations=150000,
+ anneal_lr=False,
+ lr=5e-5,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ batch_size=1,
+ eval_batch_size=12,
+ microbatch=-1, # -1 disables microbatches
+ ema_rate="0.9999", # comma-separated list of EMA values
+ log_interval=50,
+ eval_interval=2500,
+ save_interval=10000,
+ resume_checkpoint="",
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ data_dir="",
+ eval_data_dir="",
+ # load_depth=False, # TODO
+ logdir="/mnt/lustre/yslan/logs/nips23/",
+ # test warm up pose sampling training
+ objv_dataset=False,
+ pose_warm_up_iter=-1,
+ )
+
+ defaults.update(encoder_and_nsr_defaults()) # type: ignore
+ defaults.update(loss_defaults())
+
+ parser = argparse.ArgumentParser()
+ add_dict_to_argparser(parser, defaults)
+
+ return parser
+
+
+if __name__ == "__main__":
+ # os.environ[
+ # "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
+ # os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
+ # os.environ["NCCL_DEBUG"]="INFO"
+
+ args = create_argparser().parse_args()
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.gpus = th.cuda.device_count()
+
+ opts = args
+
+ args.rendering_kwargs = rendering_options_defaults(opts)
+
+ # print(args)
+ with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
+ json.dump(vars(args), f, indent=2)
+
+ # Launch processes.
+ print('Launching processes...')
+
+ try:
+ training_loop(args)
+ # except KeyboardInterrupt as e:
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ dist_util.cleanup() # clean port and socket when ctrl+c
diff --git a/shell_scripts/final_release/evaluation/.gitkeep b/shell_scripts/final_release/evaluation/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/shell_scripts/final_release/inference/sample_ffhq_t23d.sh b/shell_scripts/final_release/inference/sample_ffhq_t23d.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8ab576ed0d4265e98fa4dd3feebc6e314631b3b8
--- /dev/null
+++ b/shell_scripts/final_release/inference/sample_ffhq_t23d.sh
@@ -0,0 +1,167 @@
+set -x
+
+lpips_lambda=0.8
+
+image_size=128
+image_size_encoder=224
+
+patch_size=14
+
+batch_size=1
+num_samples=1
+
+dataset_name=ffhq
+
+
+DATASET_FLAGS="
+ --data_dir /mnt/yslan/datasets/cache/lmdb_debug/${dataset_name} \
+"
+
+lr=2e-5 # for improved-diffusion unet
+kl_lambda=0
+vit_lr=1e-5 # for improved-diffusion unet
+
+encoder_lr=$vit_lr
+vit_decoder_lr=$vit_lr
+conv_lr=0.0005
+triplane_decoder_lr=$conv_lr
+super_resolution_lr=$conv_lr
+
+scale_clip_encoding=18.4
+triplane_scaling_divider=1
+
+CKPT_FLAGS="
+--resume_checkpoint checkpoints/ffhq/model_joint_denoise_rec_model1580000.pt \
+"
+
+LR_FLAGS="--encoder_lr $encoder_lr \
+--vit_decoder_lr $vit_decoder_lr \
+--triplane_decoder_lr $triplane_decoder_lr \
+--super_resolution_lr $super_resolution_lr \
+--lr $lr"
+
+TRAIN_FLAGS="--iterations 10001 --anneal_lr False \
+ --batch_size $batch_size --save_interval 10000 \
+ --image_size_encoder $image_size_encoder \
+ --image_size $image_size \
+ --dino_version v2 \
+ --sr_training False \
+ --cls_token False \
+ --weight_decay 0.05 \
+ --image_size $image_size \
+ --kl_lambda ${kl_lambda} \
+ --no_dim_up_mlp True \
+ --uvit_skip_encoder True \
+ --fg_mse True \
+ --bg_lamdba 0.01 \
+ "
+# --vae_p 1 \
+
+
+DDPM_MODEL_FLAGS="
+--learn_sigma False \
+--num_heads 8 \
+--num_res_blocks 2 \
+--num_channels 320 \
+--attention_resolutions "4,2,1" \
+--use_spatial_transformer True \
+--transformer_depth 1 \
+--context_dim 768 \
+"
+
+
+DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear \
+--use_kl False \
+--use_amp False \
+--triplane_scaling_divider ${triplane_scaling_divider} \
+--trainer_name vpsde_crossattn \
+--mixed_prediction True \
+--denoise_in_channels 12 \
+--denoise_out_channels 12 \
+--diffusion_input_size 32 \
+--p_rendering_loss False \
+--pred_type v \
+--predict_v True \
+"
+
+DDIM_FLAGS="
+--timestep_respacing ddim250 \
+--use_ddim True \
+--unconditional_guidance_scale 6.5 \
+"
+
+# not used here
+CONTROL_FLAGS="
+--train_vae False \
+--create_controlnet False \
+--control_key img_sr \
+"
+
+prompt="a middle aged woman with brown hair, wearing glasses."
+
+logdir="./logs/LSGM/inference/t23d/${dataset_name}/crossattn-v1-ddim250/T23D_test/woman_glass-newcls"
+
+SR_TRAIN_FLAGS_v1_2XC="
+--decoder_in_chans 32 \
+--out_chans 96 \
+--alpha_lambda 1 \
+--logdir $logdir \
+--arch_encoder vits \
+--arch_decoder vitb \
+--vit_decoder_wd 0.001 \
+--encoder_weight_decay 0.001 \
+--color_criterion mse \
+--triplane_in_chans 32 \
+--decoder_output_dim 32 \
+--ae_classname vit.vit_triplane.VAE_LDM_V4_vit3D_v3_conv3D_depth2_xformer_mha_PEinit_2d_sincos_uvit_RodinRollOutConv_4x4_lite_mlp_unshuffle_4XC_final \
+"
+
+SR_TRAIN_FLAGS=${SR_TRAIN_FLAGS_v1_2XC}
+
+NUM_GPUS=1
+
+rm -rf "$logdir"/runs
+mkdir -p "$logdir"/
+cp "$0" "$logdir"/
+
+export OMP_NUM_THREADS=12
+export NCCL_ASYNC_ERROR_HANDLING=1
+export CUDA_VISIBLE_DEVICES=6
+
+torchrun --nproc_per_node=$NUM_GPUS \
+ --master_port=0 \
+ --rdzv_backend=c10d \
+ --rdzv-endpoint=localhost:33385 \
+ --nnodes 1 \
+ scripts/vit_triplane_diffusion_sample.py \
+ --num_workers 4 \
+ --depth_lambda 0 \
+ ${TRAIN_FLAGS} \
+ ${SR_TRAIN_FLAGS} \
+ ${DIFFUSION_FLAGS} \
+ ${CONTROL_FLAGS} \
+ ${DDPM_MODEL_FLAGS} \
+ ${DATASET_FLAGS} \
+ ${CKPT_FLAGS} \
+ ${LR_FLAGS} \
+ --lpips_lambda $lpips_lambda \
+ --overfitting False \
+ --load_pretrain_encoder True \
+ --iterations 5000001 \
+ --save_interval 10000 \
+ --eval_interval 2500 \
+ --decomposed True \
+ --logdir $logdir \
+ --cfg ffhq \
+ --patch_size ${patch_size} \
+ --eval_batch_size ${batch_size} \
+ --prompt "$prompt" \
+ --interval 5 \
+ --save_img True \
+ --num_samples ${num_samples} \
+ --use_train_trajectory False \
+ --normalize_clip_encoding True \
+ --scale_clip_encoding ${scale_clip_encoding} \
+ --overwrite_diff_inp_size 16 \
+ --use_lmdb True \
+ ${DDIM_FLAGS} \
\ No newline at end of file
diff --git a/shell_scripts/final_release/inference/sample_obajverse.sh b/shell_scripts/final_release/inference/sample_obajverse.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1a338cad41c2ab9d4f78ce177310a4f03861b95f
--- /dev/null
+++ b/shell_scripts/final_release/inference/sample_obajverse.sh
@@ -0,0 +1,217 @@
+set -x
+
+lpips_lambda=0.8
+
+image_size=128 # final rendered resolution
+image_size_encoder=256
+
+patch_size=14
+
+
+# ! 29GB -> 37GB
+
+batch_size=4 # BS=256 is enough
+microbatch=${batch_size}
+
+num_samples=$((50/${batch_size})) # follow ssdnerf and functa
+
+cfg_dropout_prob=0.1 # SD config
+
+unconditional_guidance_scale=6.5
+
+num_workers=0
+
+eval_data_dir="NONE"
+shards_lst=/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/diff_shards_lst_ani.txt
+eval_shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_animals_lst.txt"
+
+data_dir="NONE"
+DATASET_FLAGS="
+ --data_dir ${data_dir} \
+ --eval_shards_lst ${eval_shards_lst} \
+ --shards_lst ${shards_lst} \
+"
+
+lr=2e-5 # for official DiT, lr=1e-4 for BS=256
+kl_lambda=0
+vit_lr=1e-5 # for improved-diffusion unet
+ce_lambda=0.5 # ?
+conv_lr=5e-5
+alpha_lambda=1
+scale_clip_encoding=1
+
+triplane_scaling_divider=0.88
+
+# prompt="A blue plastic chair."
+prompt="A sailboat with mast."
+
+# * above the best lr config
+
+LR_FLAGS="--encoder_lr $vit_lr \
+ --vit_decoder_lr $vit_lr \
+ --lpips_lambda $lpips_lambda \
+ --triplane_decoder_lr $conv_lr \
+ --super_resolution_lr $conv_lr \
+ --lr $lr \
+ --kl_lambda ${kl_lambda} \
+ --bg_lamdba 0.01 \
+ --alpha_lambda ${alpha_lambda} \
+"
+
+TRAIN_FLAGS="--iterations 10001 --anneal_lr False \
+ --batch_size $batch_size --save_interval 10000 \
+ --microbatch ${microbatch} \
+ --image_size_encoder $image_size_encoder \
+ --image_size $image_size \
+ --dino_version mv-sd-dit \
+ --sr_training False \
+ --encoder_cls_token False \
+ --decoder_cls_token False \
+ --cls_token False \
+ --weight_decay 0.05 \
+ --no_dim_up_mlp True \
+ --uvit_skip_encoder True \
+ --decoder_load_pretrained False \
+ --fg_mse False \
+ --vae_p 2 \
+ --plucker_embedding True \
+ --encoder_in_channels 9 \
+ --arch_dit_decoder DiT2-B/2 \
+ --sd_E_ch 64 \
+ --sd_E_num_res_blocks 1 \
+ --lrm_decoder False \
+ --resume_checkpoint /home/yslan/Repo/open-source/data/model_joint_denoise_rec_model2310000.pt \
+ "
+
+
+
+DDPM_MODEL_FLAGS="
+--learn_sigma False \
+--num_heads 8 \
+--num_res_blocks 2 \
+--num_channels 320 \
+--attention_resolutions "4,2,1" \
+--use_spatial_transformer True \
+--transformer_depth 1 \
+--context_dim 768 \
+"
+# --pred_type x0 \
+# --iw_sample_p drop_all_uniform \
+# --loss_type x0 \
+
+# ! diffusion steps and noise schedule not used, since the continuous diffusion is adopted.
+DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear \
+--use_kl False \
+--use_amp False \
+--triplane_scaling_divider ${triplane_scaling_divider} \
+--trainer_name vpsde_crossattn_objv \
+--mixed_prediction False \
+--train_vae False \
+--denoise_in_channels 4 \
+--denoise_out_channels 4 \
+--diffusion_input_size 32 \
+--diffusion_ce_anneal True \
+--create_controlnet False \
+--p_rendering_loss False \
+--pred_type v \
+--predict_v True \
+--create_dit False \
+--train_vae False \
+--use_eos_feature False \
+--roll_out True \
+"
+
+DDIM_FLAGS="
+--timestep_respacing ddim250 \
+--use_ddim True \
+--unconditional_guidance_scale ${unconditional_guidance_scale} \
+"
+
+
+logdir=./logs/LSGM/inference/t23d/Objaverse/cfg=${unconditional_guidance_scale}/fixing-DDIM/231w/mast3
+
+SR_TRAIN_FLAGS_v1_2XC="
+--decoder_in_chans 32 \
+--out_chans 96 \
+--ae_classname vit.vit_triplane.RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder \
+--logdir $logdir \
+--arch_encoder vits \
+--arch_decoder vitb \
+--vit_decoder_wd 0.001 \
+--encoder_weight_decay 0.001 \
+--color_criterion mse \
+--triplane_in_chans 32 \
+--decoder_output_dim 3 \
+"
+# --resume_checkpoint /mnt/lustre/yslan/logs/nips23/LSGM/ssd/chair/scaling/entropy/kl0_ema0.9999_vpsde_TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_weightingv0_lsgm_unfreezeD_0.01_gradclip_nocesquare_clipH@0_noallAMP_dataset500/model_joint_denoise_rec_model0910000.pt \
+
+
+
+SR_TRAIN_FLAGS=${SR_TRAIN_FLAGS_v1_2XC}
+
+NUM_GPUS=1
+
+rm -rf "$logdir"/runs
+mkdir -p "$logdir"/
+cp "$0" "$logdir"/
+
+export OMP_NUM_THREADS=12
+export NCCL_ASYNC_ERROR_HANDLING=1
+export OPENCV_IO_ENABLE_OPENEXR=1
+export NCCL_IB_GID_INDEX=3 # https://github.com/huggingface/accelerate/issues/314#issuecomment-1821973930
+# export CUDA_VISIBLE_DEVICES=0,1,2
+
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
+# export CUDA_VISIBLE_DEVICES=7
+# export CUDA_VISIBLE_DEVICES=3,7
+# export CUDA_VISIBLE_DEVICES=3,4,5
+# export CUDA_VISIBLE_DEVICES=0,1,2,3
+export CUDA_VISIBLE_DEVICES=0
+# export CUDA_VISIBLE_DEVICES=4,5,6
+# export CUDA_VISIBLE_DEVICES=6,7
+# export CUDA_VISIBLE_DEVICES=7
+
+torchrun --nproc_per_node=$NUM_GPUS \
+ --nnodes 1 \
+ --rdzv-endpoint=localhost:24369 \
+ scripts/vit_triplane_diffusion_sample_objaverse.py \
+ --num_workers ${num_workers} \
+ --eval_data_dir $eval_data_dir \
+ --depth_lambda 0 \
+ ${TRAIN_FLAGS} \
+ ${SR_TRAIN_FLAGS} \
+ ${DATASET_FLAGS} \
+ ${DIFFUSION_FLAGS} \
+ ${DDPM_MODEL_FLAGS} \
+ ${DDIM_FLAGS} \
+ --overfitting False \
+ --load_pretrain_encoder False \
+ --iterations 5000001 \
+ --save_interval 10000 \
+ --eval_interval 5000 \
+ --decomposed True \
+ --logdir $logdir \
+ --cfg objverse_tuneray_aug_resolution_64_64_auto \
+ --patch_size ${patch_size} \
+ --eval_batch_size 1 \
+ ${LR_FLAGS} \
+ --ce_lambda ${ce_lambda} \
+ --negative_entropy_lambda ${ce_lambda} \
+ --triplane_fg_bg False \
+ --grad_clip True \
+ --interval 5 \
+ --normalize_clip_encoding True \
+ --scale_clip_encoding ${scale_clip_encoding} \
+ --objv_dataset True \
+ --cfg_dropout_prob ${cfg_dropout_prob} \
+ --cond_key caption \
+ --enable_mixing_normal False \
+ --use_lmdb_compressed False \
+ --use_lmdb False \
+ --load_wds_diff True \
+ --mv_input True \
+ --compile False \
+ --prompt "$prompt" \
+ --num_samples ${num_samples} \
+ --use_wds False \
\ No newline at end of file
diff --git a/shell_scripts/final_release/inference/sample_shapenet_car_t23d.sh b/shell_scripts/final_release/inference/sample_shapenet_car_t23d.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3185d40fde11fdde6671b061eef70b800f2013cf
--- /dev/null
+++ b/shell_scripts/final_release/inference/sample_shapenet_car_t23d.sh
@@ -0,0 +1,175 @@
+set -x
+
+lpips_lambda=0.8
+
+image_size=128
+image_size_encoder=224
+
+patch_size=14
+
+batch_size=4
+num_samples=40
+
+dataset_name=car
+
+eval_data_dir=/mnt/lustre/yslan/3D_Dataset/get3d/${dataset_name}_test
+DATASET_FLAGS="
+ --data_dir /mnt/cache/yslan/get3d/lmdb_debug/${dataset_name} \
+ --eval_data_dir $eval_data_dir \
+"
+
+lr=2e-5 # for improved-diffusion unet
+kl_lambda=0
+vit_lr=1e-5 # for improved-diffusion unet
+
+encoder_lr=$vit_lr
+vit_decoder_lr=$vit_lr
+conv_lr=0.0005
+triplane_decoder_lr=$conv_lr
+super_resolution_lr=$conv_lr
+
+scale_clip_encoding=18.4
+unconditional_guidance_scale=1.0
+triplane_scaling_divider=1
+
+CKPT_FLAGS="
+--resume_checkpoint checkpoints/shapenet/car/model_joint_denoise_rec_model1700000.pt
+"
+
+
+# * above the best lr config
+
+LR_FLAGS="--encoder_lr $encoder_lr \
+--vit_decoder_lr $vit_decoder_lr \
+--triplane_decoder_lr $triplane_decoder_lr \
+--super_resolution_lr $super_resolution_lr \
+--lr $lr"
+
+
+
+TRAIN_FLAGS="--iterations 10001 --anneal_lr False \
+ --batch_size $batch_size --save_interval 10000 \
+ --image_size_encoder $image_size_encoder \
+ --image_size $image_size \
+ --dino_version v2 \
+ --sr_training False \
+ --cls_token False \
+ --weight_decay 0.05 \
+ --image_size $image_size \
+ --kl_lambda ${kl_lambda} \
+ --no_dim_up_mlp True \
+ --uvit_skip_encoder True \
+ --fg_mse True \
+ --vae_p 2 \
+ --bg_lamdba 0.01 \
+ "
+
+
+DDPM_MODEL_FLAGS="
+--learn_sigma False \
+--num_heads 8 \
+--num_res_blocks 2 \
+--num_channels 320 \
+--attention_resolutions "4,2,1" \
+--use_spatial_transformer True \
+--transformer_depth 1 \
+--context_dim 768 \
+"
+
+
+DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear \
+--use_kl False \
+--use_amp False \
+--triplane_scaling_divider ${triplane_scaling_divider} \
+--trainer_name vpsde_crossattn \
+--mixed_prediction True \
+--denoise_in_channels 12 \
+--denoise_out_channels 12 \
+--diffusion_input_size 32 \
+--p_rendering_loss False \
+--pred_type v \
+--predict_v True \
+"
+
+DDIM_FLAGS="
+--timestep_respacing ddim250 \
+--use_ddim True \
+--unconditional_guidance_scale ${unconditional_guidance_scale} \
+"
+
+# not used here
+CONTROL_FLAGS="
+--train_vae False \
+--create_controlnet False \
+--control_key img_sr \
+"
+
+
+prompt="a SUV car"
+# logdir="/mnt/lustre/yslan/logs/nips23/LSGM/cldm/inference/t23d/${dataset_name}/cfg=${unconditional_guidance_scale}/yellow_bus-4/"
+logdir="./logs/LSGM/inference/t23d/${dataset_name}/cfg=${unconditional_guidance_scale}/a-suv-car-mesh/"
+
+SR_TRAIN_FLAGS_v1_2XC="
+--decoder_in_chans 32 \
+--out_chans 96 \
+--alpha_lambda 1 \
+--logdir $logdir \
+--arch_encoder vits \
+--arch_decoder vitb \
+--vit_decoder_wd 0.001 \
+--encoder_weight_decay 0.001 \
+--color_criterion mse \
+--decoder_output_dim 32 \
+--ae_classname vit.vit_triplane.RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn \
+"
+
+SR_TRAIN_FLAGS=${SR_TRAIN_FLAGS_v1_2XC}
+
+NUM_GPUS=1
+
+rm -rf "$logdir"/runs
+mkdir -p "$logdir"/
+cp "$0" "$logdir"/
+
+export OMP_NUM_THREADS=12
+export NCCL_ASYNC_ERROR_HANDLING=1
+export OMP_NUM_THREADS=12
+export CUDA_VISIBLE_DEVICES=4
+
+torchrun --nproc_per_node=$NUM_GPUS \
+ --master_port=0 \
+ --rdzv_backend=c10d \
+ --nnodes 1 \
+ scripts/vit_triplane_diffusion_sample.py \
+ --num_workers 4 \
+ --depth_lambda 0 \
+ ${TRAIN_FLAGS} \
+ ${SR_TRAIN_FLAGS} \
+ ${DIFFUSION_FLAGS} \
+ ${CONTROL_FLAGS} \
+ ${DDPM_MODEL_FLAGS} \
+ ${DATASET_FLAGS} \
+ ${CKPT_FLAGS} \
+ ${LR_FLAGS} \
+ ${DDIM_FLAGS} \
+ --lpips_lambda $lpips_lambda \
+ --overfitting False \
+ --load_pretrain_encoder True \
+ --iterations 5000001 \
+ --save_interval 10000 \
+ --eval_interval 2500 \
+ --decomposed True \
+ --logdir $logdir \
+ --cfg shapenet_tuneray_aug_resolution_64_64_nearestSR \
+ --ray_start 0.6 \
+ --ray_end 1.8 \
+ --patch_size ${patch_size} \
+ --eval_batch_size ${batch_size} \
+ --prompt "$prompt" \
+ --interval 5 \
+ --save_img True \
+ --num_samples ${num_samples} \
+ --use_train_trajectory False \
+ --normalize_clip_encoding True \
+ --export_mesh True \
+ --scale_clip_encoding ${scale_clip_encoding} \
diff --git a/shell_scripts/final_release/inference/sample_shapenet_chair_t23d.sh b/shell_scripts/final_release/inference/sample_shapenet_chair_t23d.sh
new file mode 100644
index 0000000000000000000000000000000000000000..125cf4ac751de2ba7467f0a333a2c08bc78931e2
--- /dev/null
+++ b/shell_scripts/final_release/inference/sample_shapenet_chair_t23d.sh
@@ -0,0 +1,180 @@
+set -x
+
+lpips_lambda=0.8
+
+image_size=128
+# image_size=64
+image_size_encoder=224
+
+patch_size=14
+
+batch_size=4
+# batch_size=2
+# batch_size=1
+# num_samples=$((1024/${batch_size})) # follow ssdnerf and functa
+num_samples=50
+# batch_size=80 # 13GB
+# batch_size=80 # 13GB
+
+dataset_name=chair
+
+eval_data_dir=/mnt/lustre/yslan/3D_Dataset/get3d/${dataset_name}_test
+DATASET_FLAGS="
+ --data_dir /mnt/cache/yslan/get3d/lmdb_debug/${dataset_name} \
+ --eval_data_dir $eval_data_dir \
+"
+
+lr=2e-5 # for improved-diffusion unet
+kl_lambda=0
+vit_lr=1e-5 # for improved-diffusion unet
+
+encoder_lr=$vit_lr
+vit_decoder_lr=$vit_lr
+conv_lr=0.0005
+triplane_decoder_lr=$conv_lr
+super_resolution_lr=$conv_lr
+
+scale_clip_encoding=18.4
+triplane_scaling_divider=1
+unconditional_guidance_scale=1.0
+
+CKPT_FLAGS="
+--resume_checkpoint checkpoints/shapenet/chair/model_joint_denoise_rec_model2030000.pt \
+"
+
+
+# * above the best lr config
+
+LR_FLAGS="--encoder_lr $encoder_lr \
+--vit_decoder_lr $vit_decoder_lr \
+--triplane_decoder_lr $triplane_decoder_lr \
+--super_resolution_lr $super_resolution_lr \
+--lr $lr"
+
+
+TRAIN_FLAGS="--iterations 10001 --anneal_lr False \
+ --batch_size $batch_size --save_interval 10000 \
+ --image_size_encoder $image_size_encoder \
+ --image_size $image_size \
+ --dino_version v2 \
+ --sr_training False \
+ --cls_token False \
+ --weight_decay 0.05 \
+ --image_size $image_size \
+ --kl_lambda ${kl_lambda} \
+ --no_dim_up_mlp True \
+ --uvit_skip_encoder True \
+ --fg_mse True \
+ --vae_p 2 \
+ --bg_lamdba 0.01 \
+ "
+
+
+DDPM_MODEL_FLAGS="
+--learn_sigma False \
+--num_heads 8 \
+--num_res_blocks 2 \
+--num_channels 320 \
+--attention_resolutions "4,2,1" \
+--use_spatial_transformer True \
+--transformer_depth 1 \
+--context_dim 768 \
+"
+
+
+DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear \
+--use_kl False \
+--use_amp False \
+--triplane_scaling_divider ${triplane_scaling_divider} \
+--trainer_name vpsde_crossattn \
+--mixed_prediction True \
+--denoise_in_channels 12 \
+--denoise_out_channels 12 \
+--diffusion_input_size 32 \
+--p_rendering_loss False \
+--pred_type v \
+--predict_v True \
+"
+
+DDIM_FLAGS="
+--timestep_respacing ddim250 \
+--use_ddim True \
+--unconditional_guidance_scale ${unconditional_guidance_scale} \
+"
+
+# not used here
+CONTROL_FLAGS="
+--train_vae False \
+--create_controlnet False \
+--control_key img_sr \
+"
+
+
+prompt="a gaming chair"
+
+logdir="./logs/LSGM/inference/t23d/${dataset_name}/cfg=${unconditional_guidance_scale}/gaming_chair/"
+
+SR_TRAIN_FLAGS_v1_2XC="
+--decoder_in_chans 32 \
+--out_chans 96 \
+--alpha_lambda 1 \
+--logdir $logdir \
+--arch_encoder vits \
+--arch_decoder vitb \
+--vit_decoder_wd 0.001 \
+--encoder_weight_decay 0.001 \
+--color_criterion mse \
+--decoder_output_dim 32 \
+--ae_classname vit.vit_triplane.RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn \
+"
+
+SR_TRAIN_FLAGS=${SR_TRAIN_FLAGS_v1_2XC}
+
+NUM_GPUS=1
+
+rm -rf "$logdir"/runs
+mkdir -p "$logdir"/
+cp "$0" "$logdir"/
+
+export OMP_NUM_THREADS=12
+export NCCL_ASYNC_ERROR_HANDLING=1
+export OMP_NUM_THREADS=12
+export CUDA_VISIBLE_DEVICES=1
+
+torchrun --nproc_per_node=$NUM_GPUS \
+ --master_port=0 \
+ --rdzv_backend=c10d \
+ --rdzv-endpoint=localhost:23325 \
+ --nnodes 1 \
+ scripts/vit_triplane_diffusion_sample.py \
+ --num_workers 4 \
+ --depth_lambda 0 \
+ ${TRAIN_FLAGS} \
+ ${SR_TRAIN_FLAGS} \
+ ${DIFFUSION_FLAGS} \
+ ${CONTROL_FLAGS} \
+ ${DDPM_MODEL_FLAGS} \
+ ${DATASET_FLAGS} \
+ ${CKPT_FLAGS} \
+ ${LR_FLAGS} \
+ ${DDIM_FLAGS} \
+ --lpips_lambda $lpips_lambda \
+ --overfitting False \
+ --load_pretrain_encoder True \
+ --iterations 5000001 \
+ --save_interval 10000 \
+ --eval_interval 2500 \
+ --decomposed True \
+ --logdir $logdir \
+ --cfg shapenet_tuneray_aug_resolution_64_64_nearestSR \
+ --ray_start 0.6 \
+ --ray_end 1.8 \
+ --patch_size ${patch_size} \
+ --eval_batch_size 50 \
+ --prompt "$prompt" \
+ --interval 5 \
+ --save_img True \
+ --num_samples ${num_samples} \
+ --export_mesh True \
+ --normalize_clip_encoding True \
+ --scale_clip_encoding ${scale_clip_encoding} \
diff --git a/shell_scripts/final_release/inference/sample_shapenet_plane_t23d.sh b/shell_scripts/final_release/inference/sample_shapenet_plane_t23d.sh
new file mode 100644
index 0000000000000000000000000000000000000000..16eb16af5364230ee8612c65269bb3156893da11
--- /dev/null
+++ b/shell_scripts/final_release/inference/sample_shapenet_plane_t23d.sh
@@ -0,0 +1,176 @@
+set -x
+
+lpips_lambda=0.8
+
+image_size=128
+image_size_encoder=224
+
+patch_size=14
+
+batch_size=4
+num_samples=40
+
+
+dataset_name=plane
+
+eval_data_dir=/mnt/lustre/yslan/3D_Dataset/get3d/chair_test
+DATASET_FLAGS="
+ --data_dir /mnt/cache/yslan/get3d/lmdb_debug/${dataset_name} \
+ --eval_data_dir $eval_data_dir \
+"
+
+lr=2e-5 # for improved-diffusion unet
+kl_lambda=0
+vit_lr=1e-5 # for improved-diffusion unet
+
+encoder_lr=$vit_lr
+vit_decoder_lr=$vit_lr
+conv_lr=0.0005
+triplane_decoder_lr=$conv_lr
+super_resolution_lr=$conv_lr
+
+scale_clip_encoding=18.4
+triplane_scaling_divider=1
+unconditional_guidance_scale=1.0
+
+
+CKPT_FLAGS="
+--resume_checkpoint checkpoints/shapenet/plane/model_joint_denoise_rec_model1770000.pt \
+"
+
+
+# * above the best lr config
+
+LR_FLAGS="--encoder_lr $encoder_lr \
+--vit_decoder_lr $vit_decoder_lr \
+--triplane_decoder_lr $triplane_decoder_lr \
+--super_resolution_lr $super_resolution_lr \
+--lr $lr"
+
+
+
+TRAIN_FLAGS="--iterations 10001 --anneal_lr False \
+ --batch_size $batch_size --save_interval 10000 \
+ --image_size_encoder $image_size_encoder \
+ --image_size $image_size \
+ --dino_version v2 \
+ --sr_training False \
+ --cls_token False \
+ --weight_decay 0.05 \
+ --image_size $image_size \
+ --kl_lambda ${kl_lambda} \
+ --no_dim_up_mlp True \
+ --uvit_skip_encoder True \
+ --fg_mse True \
+ --vae_p 2 \
+ --bg_lamdba 0.01 \
+ "
+
+
+DDPM_MODEL_FLAGS="
+--learn_sigma False \
+--num_heads 8 \
+--num_res_blocks 2 \
+--num_channels 320 \
+--attention_resolutions "4,2,1" \
+--use_spatial_transformer True \
+--transformer_depth 1 \
+--context_dim 768 \
+"
+
+
+DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear \
+--use_kl False \
+--use_amp False \
+--triplane_scaling_divider ${triplane_scaling_divider} \
+--trainer_name vpsde_crossattn \
+--mixed_prediction True \
+--denoise_in_channels 12 \
+--denoise_out_channels 12 \
+--diffusion_input_size 32 \
+--p_rendering_loss False \
+--pred_type v \
+--predict_v True \
+"
+
+DDIM_FLAGS="
+--timestep_respacing ddim250 \
+--use_ddim True \
+--unconditional_guidance_scale ${unconditional_guidance_scale} \
+"
+
+# not used here
+CONTROL_FLAGS="
+--train_vae False \
+--create_controlnet False \
+--control_key img_sr \
+"
+
+prompt="a star war Tie Fighter"
+
+logdir="./logs/LSGM/inference/t23d/${dataset_name}/cfg=${unconditional_guidance_scale}/star_war_fighter-reproduce/"
+
+SR_TRAIN_FLAGS_v1_2XC="
+--decoder_in_chans 32 \
+--out_chans 96 \
+--alpha_lambda 1 \
+--logdir $logdir \
+--arch_encoder vits \
+--arch_decoder vitb \
+--vit_decoder_wd 0.001 \
+--encoder_weight_decay 0.001 \
+--color_criterion mse \
+--decoder_output_dim 32 \
+--ae_classname vit.vit_triplane.RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn \
+"
+
+SR_TRAIN_FLAGS=${SR_TRAIN_FLAGS_v1_2XC}
+
+NUM_GPUS=1
+
+rm -rf "$logdir"/runs
+mkdir -p "$logdir"/
+cp "$0" "$logdir"/
+
+export OMP_NUM_THREADS=12
+export NCCL_ASYNC_ERROR_HANDLING=1
+export OMP_NUM_THREADS=12
+export CUDA_VISIBLE_DEVICES=6
+
+torchrun --nproc_per_node=$NUM_GPUS \
+ --master_port=0 \
+ --rdzv_backend=c10d \
+ --rdzv-endpoint=localhost:23525 \
+ --nnodes 1 \
+ scripts/vit_triplane_diffusion_sample.py \
+ --num_workers 4 \
+ --depth_lambda 0 \
+ ${TRAIN_FLAGS} \
+ ${SR_TRAIN_FLAGS} \
+ ${DIFFUSION_FLAGS} \
+ ${CONTROL_FLAGS} \
+ ${DDPM_MODEL_FLAGS} \
+ ${DATASET_FLAGS} \
+ ${CKPT_FLAGS} \
+ ${LR_FLAGS} \
+ ${DDIM_FLAGS} \
+ --lpips_lambda $lpips_lambda \
+ --overfitting False \
+ --load_pretrain_encoder True \
+ --iterations 5000001 \
+ --save_interval 10000 \
+ --eval_interval 2500 \
+ --decomposed True \
+ --logdir $logdir \
+ --cfg shapenet_tuneray_aug_resolution_64_64_nearestSR \
+ --ray_start 0.6 \
+ --ray_end 1.8 \
+ --patch_size ${patch_size} \
+ --eval_batch_size ${batch_size} \
+ --prompt "$prompt" \
+ --interval 5 \
+ --save_img True \
+ --num_samples ${num_samples} \
+ --export_mesh True \
+ --normalize_clip_encoding True \
+ --scale_clip_encoding ${scale_clip_encoding} \
diff --git a/shell_scripts/final_release/inference/vae_reconstruction.sh b/shell_scripts/final_release/inference/vae_reconstruction.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5dbc7c9b24dc42cc7addef0670f9dd4c9f46da7e
--- /dev/null
+++ b/shell_scripts/final_release/inference/vae_reconstruction.sh
@@ -0,0 +1,154 @@
+set -x
+# vit_decoder_lr=1.001
+
+# lpips_lambda=0.8
+# lpips_lambda=2.0 # ! lrm
+lpips_lambda=2.0
+# lpips_lambda=0.0
+ssim_lambda=0.
+l1_lambda=0. # following gaussian splatting
+l2_lambda=1 # ! use_conf_map
+
+NUM_GPUS=1
+
+
+image_size=128 # final rendered resolution
+
+num_workers=3 # for eval only
+image_size_encoder=256
+patch_size=14
+kl_lambda=1.0e-06
+patch_rendering_resolution=56 #
+batch_size=4 #
+microbatch=4 #
+
+
+# use g-buffer Objaverse data path here. check readme for more details.
+data_dir=./assets/Objaverse/
+
+
+DATASET_FLAGS="
+ --data_dir "NONE" \
+ --eval_data_dir ${data_dir} \
+"
+
+conv_lr=2e-4
+lr=1e-4 #
+
+vit_decoder_lr=$lr
+encoder_lr=${conv_lr} # scaling version , could be larger when multi-nodes
+triplane_decoder_lr=$conv_lr
+super_resolution_lr=$conv_lr
+
+# * above the best lr config
+
+LR_FLAGS="--encoder_lr $encoder_lr \
+--vit_decoder_lr $vit_decoder_lr \
+--triplane_decoder_lr $triplane_decoder_lr \
+--super_resolution_lr $super_resolution_lr \
+--lr $lr"
+
+TRAIN_FLAGS="--iterations 10001 --anneal_lr False \
+ --batch_size $batch_size --save_interval 10000 \
+ --microbatch ${microbatch} \
+ --image_size_encoder $image_size_encoder \
+ --dino_version mv-sd-dit \
+ --sr_training False \
+ --cls_token False \
+ --weight_decay 0.05 \
+ --image_size $image_size \
+ --kl_lambda ${kl_lambda} \
+ --no_dim_up_mlp True \
+ --uvit_skip_encoder False \
+ --fg_mse True \
+ --bg_lamdba 1.0 \
+ --lpips_delay_iter 100 \
+ --sr_delay_iter 25000 \
+ --kl_anneal True \
+ --symmetry_loss False \
+ --vae_p 2 \
+ --plucker_embedding True \
+ --encoder_in_channels 10 \
+ --arch_dit_decoder DiT2-B/2 \
+ --sd_E_ch 64 \
+ --sd_E_num_res_blocks 1 \
+ --lrm_decoder False \
+ --resume_checkpoint checkpoints/objaverse/model_rec1680000.pt \
+ "
+
+# the path to save the extracted latents.
+logdir="./logs/vae-reconstruction/objav/vae/infer-latents"
+
+SR_TRAIN_FLAGS_v1_2XC="
+--decoder_in_chans 32 \
+--out_chans 96 \
+--alpha_lambda 1.0 \
+--logdir $logdir \
+--arch_encoder vits \
+--arch_decoder vitb \
+--vit_decoder_wd 0.001 \
+--encoder_weight_decay 0.001 \
+--color_criterion mse \
+--decoder_output_dim 3 \
+--ae_classname vit.vit_triplane.RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder_S \
+"
+
+SR_TRAIN_FLAGS=${SR_TRAIN_FLAGS_v1_2XC}
+
+
+rm -rf "$logdir"/runs
+mkdir -p "$logdir"/
+cp "$0" "$logdir"/
+
+# localedef -c -f UTF-8 -i en_US en_US.UTF-8
+export LC_ALL=en_US.UTF-8
+
+export OPENCV_IO_ENABLE_OPENEXR=1
+export OMP_NUM_THREADS=12
+export NCCL_ASYNC_ERROR_HANDLING=1
+export NCCL_IB_GID_INDEX=3 # https://github.com/huggingface/accelerate/issues/314#issuecomment-1821973930
+export CUDA_VISIBLE_DEVICES=0
+
+
+torchrun --nproc_per_node=$NUM_GPUS \
+ --nnodes=1 \
+ --rdzv-endpoint=${HOST_NODE_ADDR} \
+ --rdzv_backend=c10d \
+ scripts/vit_triplane_train.py \
+ --trainer_name nv_rec_patch_mvE \
+ --num_workers ${num_workers} \
+ ${TRAIN_FLAGS} \
+ ${SR_TRAIN_FLAGS} \
+ ${DATASET_FLAGS} \
+ --lpips_lambda $lpips_lambda \
+ --overfitting False \
+ --load_pretrain_encoder False \
+ --iterations 5000001 \
+ --save_interval 10000 \
+ --eval_interval 250000000 \
+ --decomposed True \
+ --logdir $logdir \
+ --decoder_load_pretrained False \
+ --cfg objverse_tuneray_aug_resolution_64_64_auto \
+ --patch_size ${patch_size} \
+ --use_amp False \
+ --eval_batch_size 4 \
+ ${LR_FLAGS} \
+ --l1_lambda ${l1_lambda} \
+ --l2_lambda ${l2_lambda} \
+ --ssim_lambda ${ssim_lambda} \
+ --depth_smoothness_lambda 0 \
+ --use_conf_map False \
+ --objv_dataset True \
+ --depth_lambda 0.5 \
+ --patch_rendering_resolution ${patch_rendering_resolution} \
+ --use_lmdb_compressed False \
+ --use_lmdb False \
+ --mv_input True \
+ --inference True \
+ --split_chunk_input False \
+ --use_wds False \
+ --four_view_for_latent True \
+ --append_depth True \
+ --save_latent True \
+ --shuffle_across_cls True \
diff --git a/shell_scripts/final_release/train/stage-1-vae/Objaverse/mv-75k-addDepth_disc.sh b/shell_scripts/final_release/train/stage-1-vae/Objaverse/mv-75k-addDepth_disc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..58f6aa83e9350f2c2e8d065b2e5ed9d31f4d273a
--- /dev/null
+++ b/shell_scripts/final_release/train/stage-1-vae/Objaverse/mv-75k-addDepth_disc.sh
@@ -0,0 +1,234 @@
+set -x
+# vit_decoder_lr=1.001
+
+# lpips_lambda=0.8
+# lpips_lambda=2.0 # ! lrm
+lpips_lambda=2.0
+# lpips_lambda=0.0
+ssim_lambda=0.
+l1_lambda=0. # following gaussian splatting
+l2_lambda=1 # ! use_conf_map
+# patchgan_disc=0.002
+
+patchgan_disc_factor=0.1
+patchgan_disc_g_weight=0.02
+
+# NUM_GPUS=4
+# NUM_GPUS=8
+# NUM_GPUS=3
+NUM_GPUS=8
+# NUM_GPUS=1
+# NUM_GPUS=7
+# NUM_GPUS=5
+# NUM_GPUS=6
+
+image_size=128 # final rendered resolution
+# image_size=64 # final rendered resolution
+# image_size=80 # ! to alleviate memory issue
+# image_size=128 # final rendered resolution
+
+num_workers=0 # much faster, why?
+# num_workers=3 #
+image_size_encoder=256
+patch_size=14
+kl_lambda=1.0e-06
+
+# patch_rendering_resolution=64
+# patch_rendering_resolution=68
+# patch_rendering_resolution=58
+# patch_rendering_resolution=48
+patch_rendering_resolution=64 # ! render 8 views each, 64 crops given bs=4, 80gib
+# patch_rendering_resolution=56 # ! render 8 views each, 64 crops given bs=4, 80gib
+# patch_rendering_resolution=58 # ! OOM when BS=10
+
+batch_size=4 # ! actuall BS will double
+microbatch=32 # grad acc
+
+# batch_size=1 # ! actuall BS will double
+# microbatch=8 # grad acc
+
+# data_dir=/cpfs01/user/yangpeiqing.p/yslan/data/Objaverse
+# data_dir=/cpfs01/shared/V2V/V2V_hdd/yslan/Objaverse # hdd version for debug
+# data_dir="/cpfs01/user/yangpeiqing.p/yslan/data/Furnitures_compressed_lz4/wds-{000000..000004}.tar"
+# DATASET_FLAGS="
+# --data_dir ${data_dir} \
+# "
+# eval_data_dir=/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/download_unzipped/Furnitures # to update later
+
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_lst_4w.txt"
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_lst.txt"
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_animals_lst.txt"
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_ani_trans_lst.txt"
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_furnitures_lst.txt"
+# eval_data_dir=/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/download_unzipped/Furnitures # to update later
+# eval_data_dir="/cpfs01/user/lanyushi.p/data/Furnitures_compressed_lz4/wds-{000000..000004}.tar"
+
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_ani_trans_lst.txt"
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_lst_subset_shuffle.txt"
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_lst_subset_shuffle.txt"
+shards_lst="shell_scripts/shards_list/shards_lst_75K_shuffle.txt"
+
+
+DATASET_FLAGS="
+ --data_dir "NONE" \
+ --shards_lst ${shards_lst} \
+ --eval_data_dir "NONE" \
+ --eval_shards_lst ${shards_lst} \
+"
+
+conv_lr=2e-4
+lr=1e-4 # 4e-4 for BS=256, we have 80 here.
+
+vit_decoder_lr=$lr
+encoder_lr=${conv_lr} # scaling version , could be larger when multi-nodes
+triplane_decoder_lr=$conv_lr
+super_resolution_lr=$conv_lr
+
+# * above the best lr config
+
+LR_FLAGS="--encoder_lr $encoder_lr \
+--vit_decoder_lr $vit_decoder_lr \
+--triplane_decoder_lr $triplane_decoder_lr \
+--super_resolution_lr $super_resolution_lr \
+--lr $lr"
+
+TRAIN_FLAGS="--iterations 10001 --anneal_lr False \
+ --batch_size $batch_size --save_interval 10000 \
+ --microbatch ${microbatch} \
+ --image_size_encoder $image_size_encoder \
+ --dino_version mv-sd-dit \
+ --sr_training False \
+ --cls_token False \
+ --weight_decay 0.05 \
+ --image_size $image_size \
+ --kl_lambda ${kl_lambda} \
+ --no_dim_up_mlp True \
+ --uvit_skip_encoder False \
+ --fg_mse True \
+ --bg_lamdba 1.0 \
+ --lpips_delay_iter 100 \
+ --sr_delay_iter 25000 \
+ --kl_anneal True \
+ --symmetry_loss False \
+ --vae_p 2 \
+ --plucker_embedding True \
+ --encoder_in_channels 10 \
+ --arch_dit_decoder DiT2-B/2 \
+ --sd_E_ch 64 \
+ --sd_E_num_res_blocks 1 \
+ --lrm_decoder False \
+ --resume_checkpoint /cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/75K/add-depth/adv-0.002/bs4-gpu8/model_rec1620000.pt \
+ "
+
+
+# --resume_checkpoint /cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/Ani-Trans-Furni/V=4/5-gpu8-lr2e-4vitlr-1e-4-plucker-patch46-ctd/model_rec0740000.pt \
+
+
+
+
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/Ani-Trans/V=4/MV/${batch_size}-gpu${NUM_GPUS}-lr${encoder_lr}vitlr-${vit_decoder_lr}-plucker-patch${patch_rendering_resolution}-V8-vtd-scaleinvDepth-noLRMDecoder-clipDepth-fullset_shuffle-fixdepth-shuffleInp/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/75K/bs${batch_size}-gpu${NUM_GPUS}-lr${encoder_lr}vitlr-${vit_decoder_lr}-patch${patch_rendering_resolution}
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/75K/bs${batch_size}-gpu${NUM_GPUS}-lr${encoder_lr}vitlr-${vit_decoder_lr}-patch${patch_rendering_resolution}-ctd
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/75K/add-depth/adv-${patchgan_disc}/bs${batch_size}-gpu${NUM_GPUS}
+logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/75K/add-depth/adv-${patchgan_disc_factor}-${patchgan_disc_g_weight}/bs${batch_size}-gpu${NUM_GPUS}
+
+SR_TRAIN_FLAGS_v1_2XC="
+--decoder_in_chans 32 \
+--out_chans 96 \
+--alpha_lambda 1.0 \
+--logdir $logdir \
+--arch_encoder vits \
+--arch_decoder vitb \
+--vit_decoder_wd 0.001 \
+--encoder_weight_decay 0.001 \
+--color_criterion mse \
+--decoder_output_dim 3 \
+--ae_classname vit.vit_triplane.RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder_S \
+"
+
+SR_TRAIN_FLAGS=${SR_TRAIN_FLAGS_v1_2XC}
+
+
+rm -rf "$logdir"/runs
+mkdir -p "$logdir"/
+cp "$0" "$logdir"/
+
+# localedef -c -f UTF-8 -i en_US en_US.UTF-8
+export LC_ALL=en_US.UTF-8
+
+export OPENCV_IO_ENABLE_OPENEXR=1
+export OMP_NUM_THREADS=12
+export NCCL_ASYNC_ERROR_HANDLING=1
+export NCCL_IB_GID_INDEX=3 # https://github.com/huggingface/accelerate/issues/314#issuecomment-1821973930
+# export CUDA_VISIBLE_DEVICES=4,5,6,7
+# export CUDA_VISIBLE_DEVICES=4,5,6
+# export CUDA_VISIBLE_DEVICES=1,2,3,4
+# export CUDA_VISIBLE_DEVICES=0,1,2,3
+# export CUDA_VISIBLE_DEVICES=3,0,1
+# export CUDA_VISIBLE_DEVICES=0,1,2,3
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# export CUDA_VISIBLE_DEVICES=1
+# export CUDA_VISIBLE_DEVICES=5,6,7
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
+
+
+ # --master_addr=${MASTER_ADDR} \
+ # --node_rank=${RANK} \
+
+torchrun --nproc_per_node=$NUM_GPUS \
+ --nnodes=1 \
+ --rdzv-endpoint=localhost:15368 \
+ --rdzv_backend=c10d \
+ scripts/vit_triplane_train.py \
+ --trainer_name nv_rec_patch_mvE_disc \
+ --num_workers ${num_workers} \
+ ${TRAIN_FLAGS} \
+ ${SR_TRAIN_FLAGS} \
+ ${DATASET_FLAGS} \
+ --lpips_lambda $lpips_lambda \
+ --overfitting False \
+ --load_pretrain_encoder False \
+ --iterations 5000001 \
+ --save_interval 10000 \
+ --eval_interval 250000000 \
+ --decomposed True \
+ --logdir $logdir \
+ --decoder_load_pretrained False \
+ --cfg objverse_tuneray_aug_resolution_64_64_auto \
+ --patch_size ${patch_size} \
+ --use_amp False \
+ --eval_batch_size 1 \
+ ${LR_FLAGS} \
+ --l1_lambda ${l1_lambda} \
+ --l2_lambda ${l2_lambda} \
+ --ssim_lambda ${ssim_lambda} \
+ --depth_smoothness_lambda 0 \
+ --use_conf_map False \
+ --objv_dataset True \
+ --depth_lambda 0.5 \
+ --patch_rendering_resolution ${patch_rendering_resolution} \
+ --use_lmdb_compressed False \
+ --use_lmdb False \
+ --mv_input True \
+ --split_chunk_input True \
+ --append_depth True \
+ --patchgan_disc_factor ${patchgan_disc_factor} \
+ --patchgan_disc_g_weight ${patchgan_disc_g_weight} \
+ --use_wds True \
+
+# --inference True \
+
+# --dataset_size 1000 \
+
+# --trainer_name nv_rec_patch \
+# --cfg shapenet_tuneray_aug_resolution_64_64_nearestSR_patch \
+
+# --use_conf_map True
+
+# seed=0 fails to converge at the beginning
+
+# scripts/vit_triplane_train.py \
+
+# --rec_cvD_lambda 0.05 \
+# --nvs_cvD_lambda 0.2 \
\ No newline at end of file
diff --git a/shell_scripts/final_release/train/stage-1-vae/ShapeNet/.gitkeep b/shell_scripts/final_release/train/stage-1-vae/ShapeNet/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/shell_scripts/final_release/train/stage-2-diffusion/objaverse-ldm.sh b/shell_scripts/final_release/train/stage-2-diffusion/objaverse-ldm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8faef0e316dca9bc7b4c77645d189065faf4fac1
--- /dev/null
+++ b/shell_scripts/final_release/train/stage-2-diffusion/objaverse-ldm.sh
@@ -0,0 +1,312 @@
+# mv_latent_noMixing_75K_sgm_legacy
+
+set -x
+
+lpips_lambda=0.8
+
+image_size=128 # final rendered resolution
+image_size_encoder=256
+
+patch_size=14
+
+# batch_size=32 # 4*32
+# ! 29GB -> 37GB
+# batch_size=8 # 128 when 3?
+# batch_size=1 # for debug
+
+# batch_size=48
+# batch_size=96 # BS=480 on 5GPU
+
+# batch_size=36 # BS=256 is enough
+# batch_size=16 # BS=256 is enough
+
+# batch_size=80 # BS=480 on 5GPU
+# batch_size=18 # 126 in total
+# microbatch=72
+
+# batch_size=48 #
+# batch_size=40 #
+batch_size=36 # 8GPU here
+# batch_size=85 #
+# batch_size=96 #
+# batch_size=24 # 128 in total
+# batch_size=36 # 128 in total
+# batch_size=40 # 128 in total
+# batch_size=96 # 128 in total
+# batch_size=64 # 128 in total
+# batch_size=80 # 128 in total
+microbatch=${batch_size}
+
+cfg_dropout_prob=0.1 # SD config
+
+# dataset_size=10000
+# dataset_name=Ani-Trans-Furni
+dataset_name="75K"
+# num_workers=12
+# num_workers=7
+# num_workers=12
+num_workers=0
+
+
+# NUM_GPUS=4
+# NUM_GPUS=7
+# NUM_GPUS=3
+# NUM_GPUS=2
+NUM_GPUS=8
+# NUM_GPUS=7
+
+# shards_lst=/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/diff_shards_lst_3w.txt
+# shards_lst="/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/shards_list/diff_mv_latent_132w_3cls.txt"
+# eval_shards_lst=/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_lst.txt
+# DATASET_FLAGS="
+# --data_dir "NONE" \
+# --shards_lst ${shards_lst} \
+# --eval_data_dir "NONE" \
+# --eval_shards_lst ${eval_shards_lst} \
+# "
+
+# shards_lst=/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/diff_shards_lst_3w.txt
+# shards_lst="shell_scripts/baselines/reconstruction/sr/final_mv/diff_shards_lst_3w_shuffle.txt"
+shards_lst="shell_scripts/shards_list/diff_singleview_shards_75K.txt"
+eval_shards_lst=/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/baselines/reconstruction/sr/final_mv/shards_lst.txt
+DATASET_FLAGS="
+ --data_dir "NONE" \
+ --shards_lst ${shards_lst} \
+ --eval_data_dir "NONE" \
+ --eval_shards_lst ${eval_shards_lst} \
+"
+
+
+
+
+
+# eval_data_dir=/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/download_unzipped/Furnitures # to update later
+# eval_data_dir=${data_dir}
+
+# --dataset_size ${dataset_size} \
+
+# lr=2e-5 # for official DiT, lr=1e-4 for BS=256
+# lr=4e-5 # for official DiT, lr=1e-4 for BS=256
+lr=1e-4 # for LDM base learning rate
+# lr=3e-5 # for official DiT, lr=1e-4 for BS=256
+kl_lambda=0
+vit_lr=1e-5 # for improved-diffusion unet
+ce_lambda=0.5 # ?
+conv_lr=5e-5
+alpha_lambda=1
+scale_clip_encoding=1
+
+# triplane_scaling_divider=0.8918
+# triplane_scaling_divider=0.857916
+# triplane_scaling_divider=0.883637
+# triplane_scaling_divider=0.89247337
+# triplane_scaling_divider=0.82
+# triplane_scaling_divider=0.88
+# triplane_scaling_divider=0.89
+triplane_scaling_divider=0.90
+
+# * above the best lr config
+
+LR_FLAGS="--encoder_lr $vit_lr \
+ --vit_decoder_lr $vit_lr \
+ --lpips_lambda $lpips_lambda \
+ --triplane_decoder_lr $conv_lr \
+ --super_resolution_lr $conv_lr \
+ --lr $lr \
+ --kl_lambda ${kl_lambda} \
+ --bg_lamdba 0.01 \
+ --alpha_lambda ${alpha_lambda} \
+"
+
+TRAIN_FLAGS="--iterations 10001 --anneal_lr False \
+ --batch_size $batch_size --save_interval 10000 \
+ --microbatch ${microbatch} \
+ --image_size_encoder $image_size_encoder \
+ --image_size $image_size \
+ --dino_version mv-sd-dit \
+ --sr_training False \
+ --encoder_cls_token False \
+ --decoder_cls_token False \
+ --cls_token False \
+ --weight_decay 0.05 \
+ --no_dim_up_mlp True \
+ --uvit_skip_encoder True \
+ --decoder_load_pretrained True \
+ --fg_mse False \
+ --vae_p 2 \
+ --plucker_embedding True \
+ --encoder_in_channels 10 \
+ --arch_dit_decoder DiT2-B/2 \
+ --sd_E_ch 64 \
+ --sd_E_num_res_blocks 1 \
+ --lrm_decoder False \
+ "
+
+# --resume_checkpoint /cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/Ani-Trans-Furni-rollout-12-lr3e-5-divide0.82/model_joint_denoise_rec_model0960000.pt \
+
+DDPM_MODEL_FLAGS="
+--learn_sigma False \
+--num_heads 8 \
+--num_res_blocks 2 \
+--num_channels 320 \
+--attention_resolutions "4,2,1" \
+--use_spatial_transformer True \
+--transformer_depth 1 \
+--context_dim 768 \
+"
+# --pred_type x0 \
+# --iw_sample_p drop_all_uniform \
+# --loss_type x0 \
+
+# ! diffusion steps and noise schedule not used, since the continuous diffusion is adopted.
+DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear \
+--use_kl False \
+--use_amp False \
+--triplane_scaling_divider ${triplane_scaling_divider} \
+--trainer_name sgm_legacy \
+--mixed_prediction False \
+--train_vae False \
+--denoise_in_channels 4 \
+--denoise_out_channels 4 \
+--diffusion_input_size 32 \
+--diffusion_ce_anneal True \
+--create_controlnet False \
+--p_rendering_loss False \
+--pred_type x_start \
+--predict_v False \
+--create_dit False \
+--train_vae False \
+--use_eos_feature False \
+--roll_out True \
+"
+
+# --dit_model_arch DiT-L/2 \
+
+# --trainer_name vpsde_TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED \
+
+# --predict_xstart True \
+
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-dit/${dataset_name}/cond_abla-rollout
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/${dataset_name}-rollout-${batch_size}-lr${lr}/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/${dataset_name}-rollout-${batch_size}-lr${lr}-ctd-smallBS/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/${dataset_name}-rollout-${batch_size}-lr${lr}-ctd-smallBS/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/${dataset_name}-rollout-${batch_size}-lr${lr}-ctd-smallBS-divide${triplane_scaling_divider}/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/${dataset_name}-rollout-${batch_size}-lr${lr}-ctd-smallBS-divide${triplane_scaling_divider}-mv/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/${dataset_name}-rollout-${batch_size}-lr${lr}-divide${triplane_scaling_divider}/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/${dataset_name}-rollout-${batch_size}-lr${lr}-divide${triplane_scaling_divider}-ctd/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/${dataset_name}-rollout-${batch_size}-lr${lr}-divide${triplane_scaling_divider}-ctd/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/3cls-138wLatent-gpu${NUM_GPUS}-batch${batch_size}/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-load3clsPT/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}-newview/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}-newview-fixingbug/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}-newview-removeImgCond-clipgrad0.5/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}-newview-sgm_legacy/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}-newview-sgm_legacy-newview-addTop-clipgrad0.4/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}-newview-sgm_legacy-newview-addTop-clipgrad0.4-debug/
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}-newview-sgm_legacy-newview-addTop-clipgrad0.4-ctd/
+logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu${NUM_GPUS}-batch${batch_size}-noPT-lr${lr}-newview-sgm_legacy-newview-addTop-clipgrad0.4-ctd-largeLR/
+# crossattn/TextEmbed/cfgDrop${cfg_dropout_prob}-gpu${NUM_GPUS}-batch${batch_size}
+
+# logdir=/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/Furnitures/SD/1000/SD-Encoder-F=8-D128/${batch_size}-gpu${NUM_GPUS}-patch45-32to128-heavy-final-noUpsample-wds-lr${encoder_lr}-lpips2-128-k=4ctd/
+
+SR_TRAIN_FLAGS_v1_2XC="
+--decoder_in_chans 32 \
+--out_chans 96 \
+--ae_classname vit.vit_triplane.RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder \
+--logdir $logdir \
+--arch_encoder vits \
+--arch_decoder vitb \
+--vit_decoder_wd 0.001 \
+--encoder_weight_decay 0.001 \
+--color_criterion mse \
+--triplane_in_chans 32 \
+--decoder_output_dim 3 \
+--resume_checkpoint /cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu8-batch36-noPT-lr3e-5-newview-sgm_legacy-newview-addTop-clipgrad0.4-ctd/model_joint_denoise_rec_model2070000.pt \
+"
+# --resume_checkpoint /cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu7-batch48-noPT-lr1e-4/model_joint_denoise_rec_model1770000.pt \
+
+# --resume_checkpoint /cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/3cls-138wLatent-gpu-batch96/model_joint_denoise_rec_model1780000.pt \
+
+# --resume_checkpoint /cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/LSGM/cldm-unet/mv/t23d/75K-168wLatent-gpu-batch64/model_joint_denoise_rec_model1690000.pt \
+
+
+
+SR_TRAIN_FLAGS=${SR_TRAIN_FLAGS_v1_2XC}
+
+
+rm -rf "$logdir"/runs
+mkdir -p "$logdir"/
+cp "$0" "$logdir"/
+
+export OMP_NUM_THREADS=12
+export LC_ALL=en_US.UTF-8 # save caption txt bug
+export NCCL_ASYNC_ERROR_HANDLING=1
+export OPENCV_IO_ENABLE_OPENEXR=1
+export NCCL_IB_GID_INDEX=3 # https://github.com/huggingface/accelerate/issues/314#issuecomment-1821973930
+# export CUDA_VISIBLE_DEVICES=0,1,2
+
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
+# export CUDA_VISIBLE_DEVICES=7
+# export CUDA_VISIBLE_DEVICES=3,7
+# export CUDA_VISIBLE_DEVICES=3,4,5
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# export CUDA_VISIBLE_DEVICES=1,2
+# export CUDA_VISIBLE_DEVICES=0,1,2,3
+# export CUDA_VISIBLE_DEVICES=0,3
+# export CUDA_VISIBLE_DEVICES=0,3
+# export CUDA_VISIBLE_DEVICES=4,5,6,7
+# export CUDA_VISIBLE_DEVICES=4,5,6
+
+torchrun --nproc_per_node=$NUM_GPUS \
+ --nnodes 1 \
+ --rdzv-endpoint=localhost:23371 \
+ scripts/vit_triplane_diffusion_train.py \
+ --num_workers ${num_workers} \
+ --depth_lambda 0 \
+ ${TRAIN_FLAGS} \
+ ${SR_TRAIN_FLAGS} \
+ ${DATASET_FLAGS} \
+ ${DIFFUSION_FLAGS} \
+ ${DDPM_MODEL_FLAGS} \
+ --overfitting False \
+ --load_pretrain_encoder False \
+ --iterations 5000001 \
+ --save_interval 10000 \
+ --eval_interval 5000000 \
+ --decomposed True \
+ --logdir $logdir \
+ --cfg objverse_tuneray_aug_resolution_64_64_auto \
+ --patch_size ${patch_size} \
+ --eval_batch_size 1 \
+ ${LR_FLAGS} \
+ --ce_lambda ${ce_lambda} \
+ --negative_entropy_lambda ${ce_lambda} \
+ --triplane_fg_bg False \
+ --grad_clip True \
+ --interval 5 \
+ --normalize_clip_encoding True \
+ --scale_clip_encoding ${scale_clip_encoding} \
+ --mixing_logit_init 10000 \
+ --objv_dataset True \
+ --cfg_dropout_prob ${cfg_dropout_prob} \
+ --cond_key caption \
+ --use_lmdb_compressed False \
+ --use_lmdb False \
+ --load_wds_diff True \
+ --load_wds_latent False \
+ --compile False \
+ --split_chunk_input True \
+ --append_depth True \
+ --mv_input True \
+ --duplicate_sample False \
+ --enable_mixing_normal False \
+ --use_wds True \
+ --clip_grad_throld 0.4 \
+ --mv_latent_dir /cpfs01/user/lanyushi.p/data/latent_dir/168w-3class-withImg-newview-addTop/latent_dir \
+# --mv_latent_dir /cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/75K/infer-latents/168w-3class-withImg/latent_dir \
+
+
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e48ae6c4020b200c9d4b2046fb3a2758820bfbfc
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1 @@
+# helper functions copied from 3dgs
\ No newline at end of file
diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba459d05d34381e3918707ccc14e2ecf32851358
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/__pycache__/general_utils.cpython-39.pyc b/utils/__pycache__/general_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b73c379f51f766f20675fa84fdd867d6a568ba7a
Binary files /dev/null and b/utils/__pycache__/general_utils.cpython-39.pyc differ
diff --git a/utils/dust3r/__init__.py b/utils/dust3r/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/dust3r/__pycache__/__init__.cpython-39.pyc b/utils/dust3r/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75b6a8592685552c4086407b0b2d37e0d1ffefe5
Binary files /dev/null and b/utils/dust3r/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/dust3r/__pycache__/dpt_block.cpython-39.pyc b/utils/dust3r/__pycache__/dpt_block.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18d710986d2c6f11581a98118b3932a8ff175214
Binary files /dev/null and b/utils/dust3r/__pycache__/dpt_block.cpython-39.pyc differ
diff --git a/utils/dust3r/dpt_block.py b/utils/dust3r/dpt_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb79ab51b0102c32644c0ff35909aa052648069
--- /dev/null
+++ b/utils/dust3r/dpt_block.py
@@ -0,0 +1,462 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+# --------------------------------------------------------
+# DPT head for ViTs
+# --------------------------------------------------------
+# References:
+# https://github.com/isl-org/DPT
+# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from typing import Union, Tuple, Iterable, List, Optional, Dict
+
+from pdb import set_trace as st
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+def make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand == True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0],
+ out_shape1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1],
+ out_shape2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2],
+ out_shape3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3],
+ out_shape4,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+
+ scratch.layer_rn = nn.ModuleList([
+ scratch.layer1_rn,
+ scratch.layer2_rn,
+ scratch.layer3_rn,
+ scratch.layer4_rn,
+ ])
+
+ return scratch
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups = 1
+
+ self.conv1 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+
+ self.conv2 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+
+ if self.bn == True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn == True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn == True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ width_ratio=1,
+ ):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+ self.width_ratio = width_ratio
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups = 1
+
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ groups=1,
+ )
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ if self.width_ratio != 1:
+ res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
+
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ if self.width_ratio != 1:
+ # and output.shape[3] < self.width_ratio * output.shape[2]
+ #size=(image.shape[])
+ if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
+ shape = 3 * output.shape[3]
+ else:
+ shape = int(self.width_ratio * 2 * output.shape[2])
+ output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
+ else:
+ output = nn.functional.interpolate(output, scale_factor=2,
+ mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+ return output
+
+def make_fusion_block(features, use_bn, width_ratio=1):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ width_ratio=width_ratio,
+ )
+
+class Interpolate(nn.Module):
+ """Interpolation module."""
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x,
+ scale_factor=self.scale_factor,
+ mode=self.mode,
+ align_corners=self.align_corners,
+ )
+
+ return x
+
+class DPTOutputAdapter(nn.Module):
+ """DPT output adapter.
+
+ :param num_cahnnels: Number of output channels
+ :param stride_level: tride level compared to the full-sized image.
+ E.g. 4 for 1/4th the size of the image.
+ :param patch_size_full: Int or tuple of the patch size over the full image size.
+ Patch size for smaller inputs will be computed accordingly.
+ :param hooks: Index of intermediate layers
+ :param layer_dims: Dimension of intermediate layers
+ :param feature_dim: Feature dimension
+ :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
+ :param use_bn: If set to True, activates batch norm
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
+ """
+
+ def __init__(self,
+ num_channels: int = 1,
+ stride_level: int = 1,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ main_tasks: Iterable[str] = ('rgb',),
+ hooks: List[int] = [2, 5, 8, 11],
+ layer_dims: List[int] = [96, 192, 384, 768],
+ feature_dim: int = 256,
+ last_dim: int = 32,
+ use_bn: bool = False,
+ dim_tokens_enc: Optional[int] = None,
+ head_type: str = 'regression',
+ output_width_ratio=1,
+ **kwargs):
+ super().__init__()
+ self.num_channels = num_channels
+ self.stride_level = stride_level
+ self.patch_size = pair(patch_size)
+ self.main_tasks = main_tasks
+ self.hooks = hooks
+ self.layer_dims = layer_dims
+ self.feature_dim = feature_dim
+ self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
+ self.head_type = head_type
+
+ # Actual patch height and width, taking into account stride of input
+ self.P_H = max(1, self.patch_size[0] // stride_level)
+ self.P_W = max(1, self.patch_size[1] // stride_level)
+
+ self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
+
+ self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+ self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+
+ if self.head_type == 'regression':
+ # The "DPTDepthModel" head
+ self.head = nn.Sequential(
+ nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
+ )
+ elif self.head_type == 'regression_gs': # avoid upsampling here
+ self.head = nn.Sequential(
+ nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
+ # Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.ReLU(True),
+ # nn.Dropout(0.1, False),
+ nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
+ )
+ elif self.head_type == 'semseg':
+ # The "DPTSegmentationModel" head
+ self.head = nn.Sequential(
+ nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
+ nn.ReLU(True),
+ nn.Dropout(0.1, False),
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ )
+ else:
+ raise ValueError('DPT head_type must be "regression" or "semseg".')
+
+ if self.dim_tokens_enc is not None:
+ self.init(dim_tokens_enc=dim_tokens_enc)
+
+ def init(self, dim_tokens_enc=768):
+ """
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
+ Should be called when setting up MultiMAE.
+
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
+ """
+ #print(dim_tokens_enc)
+
+ # Set up activation postprocessing layers
+ if isinstance(dim_tokens_enc, int):
+ dim_tokens_enc = 4 * [dim_tokens_enc]
+
+ self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
+
+ self.act_1_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[0],
+ out_channels=self.layer_dims[0],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=self.layer_dims[0],
+ out_channels=self.layer_dims[0],
+ kernel_size=4, stride=4, padding=0,
+ bias=True, dilation=1, groups=1,
+ )
+ )
+
+ self.act_2_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[1],
+ out_channels=self.layer_dims[1],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=self.layer_dims[1],
+ out_channels=self.layer_dims[1],
+ kernel_size=2, stride=2, padding=0,
+ bias=True, dilation=1, groups=1,
+ )
+ )
+
+ self.act_3_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[2],
+ out_channels=self.layer_dims[2],
+ kernel_size=1, stride=1, padding=0,
+ )
+ )
+
+ self.act_4_postprocess = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.dim_tokens_enc[3],
+ out_channels=self.layer_dims[3],
+ kernel_size=1, stride=1, padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=self.layer_dims[3],
+ out_channels=self.layer_dims[3],
+ kernel_size=3, stride=2, padding=1,
+ )
+ )
+
+ self.act_postprocess = nn.ModuleList([
+ self.act_1_postprocess,
+ self.act_2_postprocess,
+ self.act_3_postprocess,
+ self.act_4_postprocess
+ ])
+
+ def adapt_tokens(self, encoder_tokens):
+ # Adapt tokens
+ x = []
+ x.append(encoder_tokens[:, :])
+ x = torch.cat(x, dim=-1)
+ return x
+
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size):
+ #input_info: Dict):
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+ H, W = image_size
+
+ # Number of patches in height and width
+ N_H = H // (self.stride_level * self.P_H)
+ N_W = W // (self.stride_level * self.P_W)
+
+ # Hook decoder onto 4 layers from specified ViT layers
+ layers = [encoder_tokens[hook] for hook in self.hooks]
+
+ # Extract only task-relevant tokens and ignore global tokens.
+ layers = [self.adapt_tokens(l) for l in layers]
+
+ # Reshape tokens to spatial representation
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+ # Project layers to chosen feature dim
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+
+ # Fuse layers using refinement stages
+ path_4 = self.scratch.refinenet4(layers[3])
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+ # Output head
+ out = self.head(path_1)
+
+ return out
diff --git a/utils/dust3r/heads/__init__.py b/utils/dust3r/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53d0aa5610cae95f34f96bdb3ff9e835a2d6208e
--- /dev/null
+++ b/utils/dust3r/heads/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# head factory
+# --------------------------------------------------------
+from .linear_head import LinearPts3d
+from .dpt_head import create_dpt_head
+
+
+def head_factory(head_type, output_mode, net, has_conf=False):
+ """" build a prediction head for the decoder
+ """
+ if head_type == 'linear' and output_mode == 'pts3d':
+ return LinearPts3d(net, has_conf)
+ elif head_type == 'dpt' and output_mode == 'pts3d':
+ return create_dpt_head(net, has_conf=has_conf)
+ else:
+ raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
diff --git a/utils/dust3r/heads/__pycache__/__init__.cpython-39.pyc b/utils/dust3r/heads/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..801c80beb9b23494d4e2370da819b57e52969c5a
Binary files /dev/null and b/utils/dust3r/heads/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/dust3r/heads/__pycache__/dpt_head.cpython-39.pyc b/utils/dust3r/heads/__pycache__/dpt_head.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e043846aaa98a23505e9a7ac82af6203c49f4b69
Binary files /dev/null and b/utils/dust3r/heads/__pycache__/dpt_head.cpython-39.pyc differ
diff --git a/utils/dust3r/heads/__pycache__/linear_head.cpython-39.pyc b/utils/dust3r/heads/__pycache__/linear_head.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed51cff92c2aba36687cb5df464dbc4e4d2d436e
Binary files /dev/null and b/utils/dust3r/heads/__pycache__/linear_head.cpython-39.pyc differ
diff --git a/utils/dust3r/heads/__pycache__/postprocess.cpython-39.pyc b/utils/dust3r/heads/__pycache__/postprocess.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1132097f274d5961d0842248318d9db88a8c04a
Binary files /dev/null and b/utils/dust3r/heads/__pycache__/postprocess.cpython-39.pyc differ
diff --git a/utils/dust3r/heads/dpt_head.py b/utils/dust3r/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..09f5f83a732a43250061c605bcf1ae5299330391
--- /dev/null
+++ b/utils/dust3r/heads/dpt_head.py
@@ -0,0 +1,167 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# dpt head implementation for DUST3R
+# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
+# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
+# the forward function also takes as input a dictionnary img_info with key "height" and "width"
+# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
+# --------------------------------------------------------
+from einops import rearrange
+from typing import List
+import torch
+import torch.nn as nn
+from utils.dust3r.heads.postprocess import postprocess
+# import utils.dust3r.utils.path_to_croco # noqa: F401
+from utils.dust3r.dpt_block import DPTOutputAdapter # noqa
+
+from pdb import set_trace as st
+
+
+class DPTOutputAdapter_fix(DPTOutputAdapter):
+ """
+ Adapt croco's DPTOutputAdapter implementation for dust3r:
+ remove duplicated weigths, and fix forward for dust3r
+ """
+
+ def init(self, dim_tokens_enc=768):
+ super().init(dim_tokens_enc)
+ # these are duplicated weights
+ del self.act_1_postprocess
+ del self.act_2_postprocess
+ del self.act_3_postprocess
+ del self.act_4_postprocess
+
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+ # H, W = input_info['image_size']
+ image_size = self.image_size if image_size is None else image_size
+ H, W = image_size
+ # Number of patches in height and width
+ N_H = H // (self.stride_level * self.P_H)
+ N_W = W // (self.stride_level * self.P_W)
+
+ # Hook decoder onto 4 layers from specified ViT layers
+ layers = [encoder_tokens[hook] for hook in self.hooks]
+
+ # Extract only task-relevant tokens and ignore global tokens.
+ layers = [self.adapt_tokens(l) for l in layers]
+
+ # Reshape tokens to spatial representation
+ layers = [
+ rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W)
+ for l in layers
+ ]
+ # st()
+
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+ # Project layers to chosen feature dim
+ layers = [
+ self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)
+ ]
+
+ # Fuse layers using refinement stages
+ path_4 = self.scratch.refinenet4(
+ layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+ # Output head
+ out = self.head(path_1)
+
+ return out
+
+
+class PixelwiseTaskWithDPT(nn.Module):
+ """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
+
+ def __init__(self,
+ *,
+ n_cls_token=0,
+ hooks_idx=None,
+ dim_tokens=None,
+ output_width_ratio=1,
+ num_channels=1,
+ postprocess=None,
+ depth_mode=None,
+ conf_mode=None,
+ **kwargs):
+ super(PixelwiseTaskWithDPT, self).__init__()
+ self.return_all_layers = True # backbone needs to return all layers
+ self.postprocess = postprocess
+ self.depth_mode = depth_mode
+ self.conf_mode = conf_mode
+
+ assert n_cls_token == 0, "Not implemented"
+ dpt_args = dict(output_width_ratio=output_width_ratio,
+ num_channels=num_channels,
+ **kwargs)
+ if hooks_idx is not None:
+ dpt_args.update(hooks=hooks_idx)
+ self.dpt = DPTOutputAdapter_fix(**dpt_args)
+ dpt_init_args = {} if dim_tokens is None else {
+ 'dim_tokens_enc': dim_tokens
+ }
+ self.dpt.init(**dpt_init_args)
+
+ # ! remove unused param
+ del self.dpt.scratch.refinenet4.resConfUnit1
+
+ def forward(self, x, img_info):
+ out = self.dpt(x, image_size=(img_info[0], img_info[1]))
+ if self.postprocess:
+ out = self.postprocess(out, self.depth_mode, self.conf_mode)
+ return out
+
+
+def create_dpt_head(net, has_conf=False):
+ """
+ return PixelwiseTaskWithDPT for given net params
+ """
+ assert net.dec_depth > 9
+ l2 = net.dec_depth
+ feature_dim = 256
+ last_dim = feature_dim // 2
+ out_nchan = 3
+ ed = net.enc_embed_dim
+ dd = net.dec_embed_dim
+ return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
+ feature_dim=feature_dim,
+ last_dim=last_dim,
+ hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
+ dim_tokens=[ed, dd, dd, dd],
+ postprocess=postprocess,
+ # postprocess=None,
+ depth_mode=net.depth_mode,
+ conf_mode=net.conf_mode,
+ head_type='regression')
+
+
+# def create_dpt_head_ln3diff(net, has_conf=False):
+def create_dpt_head_ln3diff(out_nchan, feature_dim, l2, dec_embed_dim,
+ patch_size=2, has_conf=False):
+ """
+ return PixelwiseTaskWithDPT for given net params
+ """
+ # assert net.dec_depth > 9
+ # l2 = net.dec_depth
+ # feature_dim = 256
+ last_dim = feature_dim // 2
+ # out_nchan = 3
+ # ed = net.enc_embed_dim
+ # dd = net.dec_embed_dim
+ dd = dec_embed_dim
+ return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
+ feature_dim=feature_dim,
+ last_dim=last_dim,
+ patch_size=patch_size,
+ hooks_idx=[(l2 * 1 // 4)-1, (l2 * 2 // 4)-1, (l2 * 3 // 4)-1, l2-1],
+ # dim_tokens=[ed, dd, dd, dd],
+ dim_tokens=[dd, dd, dd, dd],
+ # postprocess=postprocess,
+ postprocess=None,
+ # depth_mode=net.depth_mode,
+ # conf_mode=net.conf_mode,
+ head_type='regression_gs')
diff --git a/utils/dust3r/heads/linear_head.py b/utils/dust3r/heads/linear_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..27c5678d551033cc576798626b7ba59b1e7b20cc
--- /dev/null
+++ b/utils/dust3r/heads/linear_head.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# linear head implementation for DUST3R
+# --------------------------------------------------------
+import torch.nn as nn
+import torch.nn.functional as F
+from .postprocess import postprocess
+
+
+class LinearPts3d (nn.Module):
+ """
+ Linear head for dust3r
+ Each token outputs: - 16x16 3D points (+ confidence)
+ """
+
+ def __init__(self, net, has_conf=False):
+ super().__init__()
+ self.patch_size = net.patch_embed.patch_size[0]
+ self.depth_mode = net.depth_mode
+ self.conf_mode = net.conf_mode
+ self.has_conf = has_conf
+
+ self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
+
+ def setup(self, croconet):
+ pass
+
+ def forward(self, decout, img_shape):
+ H, W = img_shape
+ tokens = decout[-1]
+ B, S, D = tokens.shape
+
+ # extract 3D points
+ feat = self.proj(tokens) # B,S,D
+ feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
+ feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
+
+ # permute + norm depth
+ return postprocess(feat, self.depth_mode, self.conf_mode)
diff --git a/utils/dust3r/heads/postprocess.py b/utils/dust3r/heads/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd68a90d89b8dcd7d8a4b4ea06ef8b17eb5da093
--- /dev/null
+++ b/utils/dust3r/heads/postprocess.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# post process function for all heads: extract 3D points/confidence from output
+# --------------------------------------------------------
+import torch
+
+
+def postprocess(out, depth_mode, conf_mode):
+ """
+ extract 3D points/confidence from prediction head output
+ """
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,3
+ res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
+
+ if conf_mode is not None:
+ res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
+ return res
+
+
+def reg_dense_depth(xyz, mode):
+ """
+ extract 3D points from prediction head output
+ """
+ mode, vmin, vmax = mode
+
+ no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
+ assert no_bounds
+
+ if mode == 'linear':
+ if no_bounds:
+ return xyz # [-inf, +inf]
+ return xyz.clip(min=vmin, max=vmax)
+
+ # distance to origin
+ d = xyz.norm(dim=-1, keepdim=True)
+ xyz = xyz / d.clip(min=1e-8)
+
+ if mode == 'square':
+ return xyz * d.square()
+
+ if mode == 'exp':
+ return xyz * torch.expm1(d)
+
+ raise ValueError(f'bad {mode=}')
+
+
+def reg_dense_conf(x, mode):
+ """
+ extract confidence from prediction head output
+ """
+ mode, vmin, vmax = mode
+ if mode == 'exp':
+ return vmin + x.exp().clip(max=vmax-vmin)
+ if mode == 'sigmoid':
+ return (vmax - vmin) * torch.sigmoid(x) + vmin
+ raise ValueError(f'bad {mode=}')
diff --git a/utils/general_utils.py b/utils/general_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..29abf43a6635c28c66aaee897fbba6b1d778d350
--- /dev/null
+++ b/utils/general_utils.py
@@ -0,0 +1,211 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import sys
+from datetime import datetime
+import numpy as np
+import random
+
+def inverse_sigmoid(x):
+ return torch.log(x/(1-x))
+
+def PILtoTorch(pil_image, resolution):
+ resized_image_PIL = pil_image.resize(resolution)
+ resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+ if len(resized_image.shape) == 3:
+ return resized_image.permute(2, 0, 1)
+ else:
+ return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
+
+def get_expon_lr_func(
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+):
+ """
+ Copied from Plenoxels
+
+ Continuous learning rate decay function. Adapted from JaxNeRF
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+ function of lr_delay_mult, such that the initial learning rate is
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+ to the normal learning rate when steps>lr_delay_steps.
+ :param conf: config subtree 'lr' or similar
+ :param max_steps: int, the number of steps during optimization.
+ :return HoF which takes step as input
+ """
+
+ def helper(step):
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
+ # Disable this parameter
+ return 0.0
+ if lr_delay_steps > 0:
+ # A kind of reverse cosine decay.
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
+ )
+ else:
+ delay_rate = 1.0
+ t = np.clip(step / max_steps, 0, 1)
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
+ return delay_rate * log_lerp
+
+ return helper
+
+def strip_lowerdiag(L):
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
+
+ uncertainty[:, 0] = L[:, 0, 0]
+ uncertainty[:, 1] = L[:, 0, 1]
+ uncertainty[:, 2] = L[:, 0, 2]
+ uncertainty[:, 3] = L[:, 1, 1]
+ uncertainty[:, 4] = L[:, 1, 2]
+ uncertainty[:, 5] = L[:, 2, 2]
+ return uncertainty
+
+def strip_symmetric(sym):
+ return strip_lowerdiag(sym)
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ From Pytorch3d
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """
+ From Pytorch3d
+ Multiply two quaternions.
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ a: Quaternions as tensor of shape (..., 4), real part first.
+ b: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of a and b, a tensor of quaternions shape (..., 4).
+ """
+ aw, ax, ay, az = torch.unbind(a, -1)
+ bw, bx, by, bz = torch.unbind(b, -1)
+ ow = aw * bw - ax * bx - ay * by - az * bz
+ ox = aw * bx + ax * bw + ay * bz - az * by
+ oy = aw * by - ax * bz + ay * bw + az * bx
+ oz = aw * bz + ax * by - ay * bx + az * bw
+ return torch.stack((ow, ox, oy, oz), -1)
+
+# Matrix to quaternion does not come under NVIDIA Copyright
+# Written by Stan Szymanowicz 2023
+def matrix_to_quaternion(M: torch.Tensor) -> torch.Tensor:
+ """
+ Matrix-to-quaternion conversion method. Equation taken from
+ https://www.euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/index.htm
+ Args:
+ M: rotation matrices, (3 x 3)
+ Returns:
+ q: quaternion of shape (4)
+ """
+ tr = 1 + M[ 0, 0] + M[ 1, 1] + M[ 2, 2]
+
+ if tr > 0:
+ r = torch.sqrt(tr) / 2.0
+ x = ( M[ 2, 1] - M[ 1, 2] ) / ( 4 * r )
+ y = ( M[ 0, 2] - M[ 2, 0] ) / ( 4 * r )
+ z = ( M[ 1, 0] - M[ 0, 1] ) / ( 4 * r )
+ elif ( M[ 0, 0] > M[ 1, 1]) and (M[ 0, 0] > M[ 2, 2]):
+ S = torch.sqrt(1.0 + M[ 0, 0] - M[ 1, 1] - M[ 2, 2]) * 2 # S=4*qx
+ r = (M[ 2, 1] - M[ 1, 2]) / S
+ x = 0.25 * S
+ y = (M[ 0, 1] + M[ 1, 0]) / S
+ z = (M[ 0, 2] + M[ 2, 0]) / S
+ elif M[ 1, 1] > M[ 2, 2]:
+ S = torch.sqrt(1.0 + M[ 1, 1] - M[ 0, 0] - M[ 2, 2]) * 2 # S=4*qy
+ r = (M[ 0, 2] - M[ 2, 0]) / S
+ x = (M[ 0, 1] + M[ 1, 0]) / S
+ y = 0.25 * S
+ z = (M[ 1, 2] + M[ 2, 1]) / S
+ else:
+ S = torch.sqrt(1.0 + M[ 2, 2] - M[ 0, 0] - M[ 1, 1]) * 2 # S=4*qz
+ r = (M[ 1, 0] - M[ 0, 1]) / S
+ x = (M[ 0, 2] + M[ 2, 0]) / S
+ y = (M[ 1, 2] + M[ 2, 1]) / S
+ z = 0.25 * S
+
+ return torch.stack([r, x, y, z], dim=-1)
+
+def build_rotation(r):
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
+
+ q = r / norm[:, None]
+
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
+
+ r = q[:, 0]
+ x = q[:, 1]
+ y = q[:, 2]
+ z = q[:, 3]
+
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
+ R[:, 0, 1] = 2 * (x*y - r*z)
+ R[:, 0, 2] = 2 * (x*z + r*y)
+ R[:, 1, 0] = 2 * (x*y + r*z)
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
+ R[:, 1, 2] = 2 * (y*z - r*x)
+ R[:, 2, 0] = 2 * (x*z - r*y)
+ R[:, 2, 1] = 2 * (y*z + r*x)
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
+ return R
+
+def build_scaling_rotation(s, r):
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+ R = build_rotation(r)
+
+ L[:,0,0] = s[:,0]
+ L[:,1,1] = s[:,1]
+ L[:,2,2] = s[:,2]
+
+ L = R @ L
+ return L
+
+def safe_state(cfg, silent=False):
+ old_f = sys.stdout
+ class F:
+ def __init__(self, silent):
+ self.silent = silent
+
+ def write(self, x):
+ if not self.silent:
+ if x.endswith("\n"):
+ old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
+ else:
+ old_f.write(x)
+
+ def flush(self):
+ old_f.flush()
+
+ sys.stdout = F(silent)
+
+ random.seed(cfg.general.random_seed)
+ np.random.seed(cfg.general.random_seed)
+ torch.manual_seed(cfg.general.random_seed)
+ device = torch.device("cuda:{}".format(cfg.general.device))
+ torch.cuda.set_device(device)
+
+ return device
diff --git a/utils/gs_utils/__pycache__/graphics_utils.cpython-39.pyc b/utils/gs_utils/__pycache__/graphics_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de5bbc7930628bae7188b1fb9d7c099da4267c84
Binary files /dev/null and b/utils/gs_utils/__pycache__/graphics_utils.cpython-39.pyc differ
diff --git a/utils/gs_utils/camera_utils.py b/utils/gs_utils/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a54d0ada0361997109c462cde1e088ea5da9ff2
--- /dev/null
+++ b/utils/gs_utils/camera_utils.py
@@ -0,0 +1,82 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+from scene.cameras import Camera
+import numpy as np
+from utils.general_utils import PILtoTorch
+from utils.graphics_utils import fov2focal
+
+WARNED = False
+
+def loadCam(args, id, cam_info, resolution_scale):
+ orig_w, orig_h = cam_info.image.size
+
+ if args.resolution in [1, 2, 4, 8]:
+ resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
+ else: # should be a type that converts to float
+ if args.resolution == -1:
+ if orig_w > 1600:
+ global WARNED
+ if not WARNED:
+ print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
+ "If this is not desired, please explicitly specify '--resolution/-r' as 1")
+ WARNED = True
+ global_down = orig_w / 1600
+ else:
+ global_down = 1
+ else:
+ global_down = orig_w / args.resolution
+
+ scale = float(global_down) * float(resolution_scale)
+ resolution = (int(orig_w / scale), int(orig_h / scale))
+
+ resized_image_rgb = PILtoTorch(cam_info.image, resolution)
+
+ gt_image = resized_image_rgb[:3, ...]
+ loaded_mask = None
+
+ if resized_image_rgb.shape[1] == 4:
+ loaded_mask = resized_image_rgb[3:4, ...]
+
+ return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
+ FoVx=cam_info.FovX, FoVy=cam_info.FovY,
+ image=gt_image, gt_alpha_mask=loaded_mask,
+ image_name=cam_info.image_name, uid=id, data_device=args.data_device)
+
+def cameraList_from_camInfos(cam_infos, resolution_scale, args):
+ camera_list = []
+
+ for id, c in enumerate(cam_infos):
+ camera_list.append(loadCam(args, id, c, resolution_scale))
+
+ return camera_list
+
+def camera_to_JSON(id, camera : Camera):
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = camera.R.transpose()
+ Rt[:3, 3] = camera.T
+ Rt[3, 3] = 1.0
+
+ W2C = np.linalg.inv(Rt)
+ pos = W2C[:3, 3]
+ rot = W2C[:3, :3]
+ serializable_array_2d = [x.tolist() for x in rot]
+ camera_entry = {
+ 'id' : id,
+ 'img_name' : camera.image_name,
+ 'width' : camera.width,
+ 'height' : camera.height,
+ 'position': pos.tolist(),
+ 'rotation': serializable_array_2d,
+ 'fy' : fov2focal(camera.FovY, camera.height),
+ 'fx' : fov2focal(camera.FovX, camera.width)
+ }
+ return camera_entry
diff --git a/utils/gs_utils/general_utils.py b/utils/gs_utils/general_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..541c0825229a2d86e84460b765879f86f724a59d
--- /dev/null
+++ b/utils/gs_utils/general_utils.py
@@ -0,0 +1,133 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import sys
+from datetime import datetime
+import numpy as np
+import random
+
+def inverse_sigmoid(x):
+ return torch.log(x/(1-x))
+
+def PILtoTorch(pil_image, resolution):
+ resized_image_PIL = pil_image.resize(resolution)
+ resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+ if len(resized_image.shape) == 3:
+ return resized_image.permute(2, 0, 1)
+ else:
+ return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
+
+def get_expon_lr_func(
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+):
+ """
+ Copied from Plenoxels
+
+ Continuous learning rate decay function. Adapted from JaxNeRF
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+ function of lr_delay_mult, such that the initial learning rate is
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+ to the normal learning rate when steps>lr_delay_steps.
+ :param conf: config subtree 'lr' or similar
+ :param max_steps: int, the number of steps during optimization.
+ :return HoF which takes step as input
+ """
+
+ def helper(step):
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
+ # Disable this parameter
+ return 0.0
+ if lr_delay_steps > 0:
+ # A kind of reverse cosine decay.
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
+ )
+ else:
+ delay_rate = 1.0
+ t = np.clip(step / max_steps, 0, 1)
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
+ return delay_rate * log_lerp
+
+ return helper
+
+def strip_lowerdiag(L):
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
+
+ uncertainty[:, 0] = L[:, 0, 0]
+ uncertainty[:, 1] = L[:, 0, 1]
+ uncertainty[:, 2] = L[:, 0, 2]
+ uncertainty[:, 3] = L[:, 1, 1]
+ uncertainty[:, 4] = L[:, 1, 2]
+ uncertainty[:, 5] = L[:, 2, 2]
+ return uncertainty
+
+def strip_symmetric(sym):
+ return strip_lowerdiag(sym)
+
+def build_rotation(r):
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
+
+ q = r / norm[:, None]
+
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
+
+ r = q[:, 0]
+ x = q[:, 1]
+ y = q[:, 2]
+ z = q[:, 3]
+
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
+ R[:, 0, 1] = 2 * (x*y - r*z)
+ R[:, 0, 2] = 2 * (x*z + r*y)
+ R[:, 1, 0] = 2 * (x*y + r*z)
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
+ R[:, 1, 2] = 2 * (y*z - r*x)
+ R[:, 2, 0] = 2 * (x*z - r*y)
+ R[:, 2, 1] = 2 * (y*z + r*x)
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
+ return R
+
+def build_scaling_rotation(s, r):
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+ R = build_rotation(r)
+
+ L[:,0,0] = s[:,0]
+ L[:,1,1] = s[:,1]
+ L[:,2,2] = s[:,2]
+
+ L = R @ L
+ return L
+
+def safe_state(silent):
+ old_f = sys.stdout
+ class F:
+ def __init__(self, silent):
+ self.silent = silent
+
+ def write(self, x):
+ if not self.silent:
+ if x.endswith("\n"):
+ old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
+ else:
+ old_f.write(x)
+
+ def flush(self):
+ old_f.flush()
+
+ sys.stdout = F(silent)
+
+ random.seed(0)
+ np.random.seed(0)
+ torch.manual_seed(0)
+ torch.cuda.set_device(torch.device("cuda:0"))
diff --git a/utils/gs_utils/graphics_utils.py b/utils/gs_utils/graphics_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb2659360987236de470acda9ec581cd8d53456
--- /dev/null
+++ b/utils/gs_utils/graphics_utils.py
@@ -0,0 +1,91 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import math
+import numpy as np
+from typing import NamedTuple
+
+class BasicPointCloud(NamedTuple):
+ points : np.array
+ colors : np.array
+ normals : np.array
+
+def geom_transform_points(points, transf_matrix):
+ P, _ = points.shape
+ ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
+ points_hom = torch.cat([points, ones], dim=1)
+ points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
+
+ denom = points_out[..., 3:] + 0.0000001
+ return (points_out[..., :3] / denom).squeeze(dim=0)
+
+def getWorld2View(R, t):
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = R.transpose()
+ Rt[:3, 3] = t
+ Rt[3, 3] = 1.0
+ return np.float32(Rt)
+
+def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = R.transpose()
+ Rt[:3, 3] = t
+ Rt[3, 3] = 1.0
+
+ C2W = np.linalg.inv(Rt)
+ cam_center = C2W[:3, 3]
+ cam_center = (cam_center + translate) * scale
+ C2W[:3, 3] = cam_center
+ Rt = np.linalg.inv(C2W)
+ return np.float32(Rt)
+
+def getView2World(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = R.transpose()
+ Rt[:3, 3] = t
+ Rt[3, 3] = 1.0
+
+ C2W = np.linalg.inv(Rt)
+ cam_center = C2W[:3, 3]
+ cam_center = (cam_center + translate) * scale
+ C2W[:3, 3] = cam_center
+ Rt = C2W
+ return np.float32(Rt)
+
+
+def getProjectionMatrix(znear, zfar, fovX, fovY):
+ tanHalfFovY = math.tan((fovY / 2))
+ tanHalfFovX = math.tan((fovX / 2))
+
+ top = tanHalfFovY * znear
+ bottom = -top
+ right = tanHalfFovX * znear
+ left = -right
+
+ P = torch.zeros(4, 4)
+
+ z_sign = 1.0
+
+ P[0, 0] = 2.0 * znear / (right - left)
+ P[1, 1] = 2.0 * znear / (top - bottom)
+ P[0, 2] = (right + left) / (right - left)
+ P[1, 2] = (top + bottom) / (top - bottom)
+ P[3, 2] = z_sign
+ P[2, 2] = z_sign * zfar / (zfar - znear)
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
+ return P
+
+def fov2focal(fov, pixels):
+ return pixels / (2 * math.tan(fov / 2))
+
+def focal2fov(focal, pixels):
+ return 2*math.atan(pixels/(2*focal))
\ No newline at end of file
diff --git a/utils/gs_utils/image_utils.py b/utils/gs_utils/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdeaa1b6d250e549181ab165070f82ccd31b3eb9
--- /dev/null
+++ b/utils/gs_utils/image_utils.py
@@ -0,0 +1,19 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+
+def mse(img1, img2):
+ return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+
+def psnr(img1, img2):
+ mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
diff --git a/utils/gs_utils/loss_utils.py b/utils/gs_utils/loss_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9defc23a913e5d861aa5adc63270050884923094
--- /dev/null
+++ b/utils/gs_utils/loss_utils.py
@@ -0,0 +1,64 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+from math import exp
+
+def l1_loss(network_output, gt):
+ return torch.abs((network_output - gt)).mean()
+
+def l2_loss(network_output, gt):
+ return ((network_output - gt) ** 2).mean()
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+ return gauss / gauss.sum()
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+ return window
+
+def ssim(img1, img2, window_size=11, size_average=True):
+ channel = img1.size(-3)
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return _ssim(img1, img2, window, window_size, channel, size_average)
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+ C1 = 0.01 ** 2
+ C2 = 0.03 ** 2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
diff --git a/utils/gs_utils/sh_utils.py b/utils/gs_utils/sh_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785
--- /dev/null
+++ b/utils/gs_utils/sh_utils.py
@@ -0,0 +1,118 @@
+# Copyright 2021 The PlenOctree Authors.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+
+C0 = 0.28209479177387814
+C1 = 0.4886025119029199
+C2 = [
+ 1.0925484305920792,
+ -1.0925484305920792,
+ 0.31539156525252005,
+ -1.0925484305920792,
+ 0.5462742152960396
+]
+C3 = [
+ -0.5900435899266435,
+ 2.890611442640554,
+ -0.4570457994644658,
+ 0.3731763325901154,
+ -0.4570457994644658,
+ 1.445305721320277,
+ -0.5900435899266435
+]
+C4 = [
+ 2.5033429417967046,
+ -1.7701307697799304,
+ 0.9461746957575601,
+ -0.6690465435572892,
+ 0.10578554691520431,
+ -0.6690465435572892,
+ 0.47308734787878004,
+ -1.7701307697799304,
+ 0.6258357354491761,
+]
+
+
+def eval_sh(deg, sh, dirs):
+ """
+ Evaluate spherical harmonics at unit directions
+ using hardcoded SH polynomials.
+ Works with torch/np/jnp.
+ ... Can be 0 or more batch dimensions.
+ Args:
+ deg: int SH deg. Currently, 0-3 supported
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+ dirs: jnp.ndarray unit directions [..., 3]
+ Returns:
+ [..., C]
+ """
+ assert deg <= 4 and deg >= 0
+ coeff = (deg + 1) ** 2
+ assert sh.shape[-1] >= coeff
+
+ result = C0 * sh[..., 0]
+ if deg > 0:
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+ result = (result -
+ C1 * y * sh[..., 1] +
+ C1 * z * sh[..., 2] -
+ C1 * x * sh[..., 3])
+
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result = (result +
+ C2[0] * xy * sh[..., 4] +
+ C2[1] * yz * sh[..., 5] +
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+ C2[3] * xz * sh[..., 7] +
+ C2[4] * (xx - yy) * sh[..., 8])
+
+ if deg > 2:
+ result = (result +
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+ C3[1] * xy * z * sh[..., 10] +
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+ C3[5] * z * (xx - yy) * sh[..., 14] +
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+ if deg > 3:
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+ return result
+
+def RGB2SH(rgb):
+ return (rgb - 0.5) / C0
+
+def SH2RGB(sh):
+ return sh * C0 + 0.5
\ No newline at end of file
diff --git a/utils/gs_utils/system_utils.py b/utils/gs_utils/system_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ca6d7f77610c967affe313398777cd86920e8e
--- /dev/null
+++ b/utils/gs_utils/system_utils.py
@@ -0,0 +1,28 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+from errno import EEXIST
+from os import makedirs, path
+import os
+
+def mkdir_p(folder_path):
+ # Creates a directory. equivalent to using mkdir -p on the command line
+ try:
+ makedirs(folder_path)
+ except OSError as exc: # Python >2.5
+ if exc.errno == EEXIST and path.isdir(folder_path):
+ pass
+ else:
+ raise
+
+def searchForMaxIteration(folder):
+ saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
+ return max(saved_iters)
diff --git a/utils/torch_utils/__init__.py b/utils/torch_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/utils/torch_utils/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/utils/torch_utils/__pycache__/__init__.cpython-39.pyc b/utils/torch_utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..555ef3872259559aac2f38e1b86b6a065a3ca7c9
Binary files /dev/null and b/utils/torch_utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/torch_utils/__pycache__/components.cpython-39.pyc b/utils/torch_utils/__pycache__/components.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..783678ae4a95999983897e1fb398abaaf5fd0609
Binary files /dev/null and b/utils/torch_utils/__pycache__/components.cpython-39.pyc differ
diff --git a/utils/torch_utils/__pycache__/custom_ops.cpython-39.pyc b/utils/torch_utils/__pycache__/custom_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a00d62b967c8673dca1535185ee6f6688559e09
Binary files /dev/null and b/utils/torch_utils/__pycache__/custom_ops.cpython-39.pyc differ
diff --git a/utils/torch_utils/__pycache__/legacy.cpython-39.pyc b/utils/torch_utils/__pycache__/legacy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92ddaa1541a6094fc7b93315f09cd9033085d4dd
Binary files /dev/null and b/utils/torch_utils/__pycache__/legacy.cpython-39.pyc differ
diff --git a/utils/torch_utils/__pycache__/misc.cpython-39.pyc b/utils/torch_utils/__pycache__/misc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e8a40f756e1ddc0e2ef284d93c83e0164a4fd24
Binary files /dev/null and b/utils/torch_utils/__pycache__/misc.cpython-39.pyc differ
diff --git a/utils/torch_utils/__pycache__/persistence.cpython-39.pyc b/utils/torch_utils/__pycache__/persistence.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c05464a44adbcb0bd46166211886b66fa01129a
Binary files /dev/null and b/utils/torch_utils/__pycache__/persistence.cpython-39.pyc differ
diff --git a/utils/torch_utils/clip_practice.py b/utils/torch_utils/clip_practice.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ab54071ebc555e585657d9dbd44e14b8c9baeb
--- /dev/null
+++ b/utils/torch_utils/clip_practice.py
@@ -0,0 +1,38 @@
+import torch
+import clip
+from PIL import Image
+
+from pdb import set_trace as st
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = clip.load("ViT-B/16", device=device)
+
+image = preprocess(Image.open("utils.torch_utils/CLIP.png")).unsqueeze(0).to(device)
+text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
+
+# with torch.no_grad():
+# image_features = model.encode_image(image)
+# text_features = model.encode_text(text)
+
+# logits_per_image, logits_per_text = model(image, text)
+# probs = logits_per_image.softmax(dim=-1).cpu().numpy()
+
+with torch.no_grad():
+ x = image.type(model.dtype) # 1 3 224 224
+ self = model.visual
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD , 1, 50, 768
+ st()
+
+ pass
+
+
+print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
\ No newline at end of file
diff --git a/utils/torch_utils/components.py b/utils/torch_utils/components.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7b9aa055759bd0c6972abe790304621d036789e
--- /dev/null
+++ b/utils/torch_utils/components.py
@@ -0,0 +1,445 @@
+# https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+# https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py#L812
+
+import copy
+import math
+from collections import namedtuple
+from contextlib import contextmanager, nullcontext
+from functools import partial, wraps
+from pathlib import Path
+from random import random
+
+from einops import rearrange, repeat, reduce, pack, unpack
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+from torch import einsum, nn
+from beartype.typing import List, Union
+from beartype import beartype
+from tqdm.auto import tqdm
+from pdb import set_trace as st
+
+# helper functions, from:
+# https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+
+
+def exists(val):
+ return val is not None
+
+
+def identity(t, *args, **kwargs):
+ return t
+
+
+def divisible_by(numer, denom):
+ return (numer % denom) == 0
+
+
+def first(arr, d=None):
+ if len(arr) == 0:
+ return d
+ return arr[0]
+
+
+def maybe(fn):
+ @wraps(fn)
+ def inner(x):
+ if not exists(x):
+ return x
+ return fn(x)
+
+ return inner
+
+
+def once(fn):
+ called = False
+
+ @wraps(fn)
+ def inner(x):
+ nonlocal called
+ if called:
+ return
+ called = True
+ return fn(x)
+
+ return inner
+
+
+print_once = once(print)
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+
+def compact(input_dict):
+ return {key: value for key, value in input_dict.items() if exists(value)}
+
+
+def maybe_transform_dict_key(input_dict, key, fn):
+ if key not in input_dict:
+ return input_dict
+
+ copied_dict = input_dict.copy()
+ copied_dict[key] = fn(copied_dict[key])
+ return copied_dict
+
+
+def cast_uint8_images_to_float(images):
+ if not images.dtype == torch.uint8:
+ return images
+ return images / 255
+
+
+def module_device(module):
+ return next(module.parameters()).device
+
+
+def zero_init_(m):
+ nn.init.zeros_(m.weight)
+ if exists(m.bias):
+ nn.init.zeros_(m.bias)
+
+
+def eval_decorator(fn):
+ def inner(model, *args, **kwargs):
+ was_training = model.training
+ model.eval()
+ out = fn(model, *args, **kwargs)
+ model.train(was_training)
+ return out
+
+ return inner
+
+
+def pad_tuple_to_length(t, length, fillvalue=None):
+ remain_length = length - len(t)
+ if remain_length <= 0:
+ return t
+ return (*t, *((fillvalue, ) * remain_length))
+
+
+# helper classes
+
+
+class Identity(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+
+# tensor helpers
+
+
+def log(t, eps: float = 1e-12):
+ return torch.log(t.clamp(min=eps))
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1)
+
+
+def right_pad_dims_to(x, t):
+ padding_dims = x.ndim - t.ndim
+ if padding_dims <= 0:
+ return t
+ return t.view(*t.shape, *((1, ) * padding_dims))
+
+
+def masked_mean(t, *, dim, mask=None):
+ if not exists(mask):
+ return t.mean(dim=dim)
+
+ denom = mask.sum(dim=dim, keepdim=True)
+ mask = rearrange(mask, 'b n -> b n 1')
+ masked_t = t.masked_fill(~mask, 0.)
+
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
+
+
+def resize_image_to(image,
+ target_image_size,
+ clamp_range=None,
+ mode='nearest'):
+ orig_image_size = image.shape[-1]
+
+ if orig_image_size == target_image_size:
+ return image
+
+ out = F.interpolate(image, target_image_size, mode=mode)
+
+ if exists(clamp_range):
+ out = out.clamp(*clamp_range)
+
+ return out
+
+
+def calc_all_frame_dims(downsample_factors: List[int], frames):
+ if not exists(frames):
+ return (tuple(), ) * len(downsample_factors)
+
+ all_frame_dims = []
+
+ for divisor in downsample_factors:
+ assert divisible_by(frames, divisor)
+ all_frame_dims.append((frames // divisor, ))
+
+ return all_frame_dims
+
+
+def safe_get_tuple_index(tup, index, default=None):
+ if len(tup) <= index:
+ return default
+ return tup[index]
+
+
+# image normalization functions
+# ddpms expect images to be in the range of -1 to 1
+
+
+def normalize_neg_one_to_one(img):
+ return img * 2 - 1
+
+
+def unnormalize_zero_to_one(normed_img):
+ return (normed_img + 1) * 0.5
+
+
+# def Upsample(dim, dim_out=None):
+# dim_out = default(dim_out, dim)
+
+# return nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
+# nn.Conv2d(dim, dim_out, 3, padding=1))
+
+
+
+class PixelShuffleUpsample(nn.Module):
+ """
+ code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
+ https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
+ """
+ def __init__(self, dim, dim_out=None):
+ super().__init__()
+ dim_out = default(dim_out, dim)
+ conv = nn.Conv2d(dim, dim_out * 4, 1)
+
+ self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(2))
+
+ self.init_conv_(conv)
+
+ def init_conv_(self, conv):
+ o, i, h, w = conv.weight.shape
+ conv_weight = torch.empty(o // 4, i, h, w)
+ nn.init.kaiming_uniform_(conv_weight)
+ conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
+
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self,
+ dim_in,
+ dim_out,
+ dim_inter=None,
+ use_norm=True,
+ norm_layer=nn.BatchNorm2d,
+ bias=False):
+ super().__init__()
+ if dim_inter is None:
+ dim_inter = dim_out
+
+ if use_norm:
+ self.conv = nn.Sequential(
+ norm_layer(dim_in),
+ nn.ReLU(True),
+ nn.Conv2d(dim_in,
+ dim_inter,
+ 3,
+ 1,
+ 1,
+ bias=bias,
+ padding_mode='reflect'),
+ norm_layer(dim_inter),
+ nn.ReLU(True),
+ nn.Conv2d(dim_inter,
+ dim_out,
+ 3,
+ 1,
+ 1,
+ bias=bias,
+ padding_mode='reflect'),
+ )
+ else:
+ self.conv = nn.Sequential(
+ nn.ReLU(True),
+ nn.Conv2d(dim_in, dim_inter, 3, 1, 1),
+ nn.ReLU(True),
+ nn.Conv2d(dim_inter, dim_out, 3, 1, 1),
+ )
+
+ self.short_cut = None
+ if dim_in != dim_out:
+ self.short_cut = nn.Conv2d(dim_in, dim_out, 1, 1)
+
+ def forward(self, feats):
+ feats_out = self.conv(feats)
+ if self.short_cut is not None:
+ feats_out = self.short_cut(feats) + feats_out
+ else:
+ feats_out = feats_out + feats
+ return feats_out
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. '
+ 'Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+class PixelUnshuffleUpsample(nn.Module):
+ def __init__(self, output_dim, num_feat=128, num_out_ch=3, sr_ratio=4, *args, **kwargs) -> None:
+ super().__init__()
+
+ self.conv_after_body = nn.Conv2d(output_dim, output_dim, 3, 1, 1)
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(output_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(sr_ratio, num_feat) # 4 time SR
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ def forward(self, x, input_skip_connection=True, *args, **kwargs):
+ # x = self.conv_first(x)
+ if input_skip_connection:
+ x = self.conv_after_body(x) + x
+ else:
+ x = self.conv_after_body(x)
+
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ return x
+
+
+class Conv3x3TriplaneTransformation(nn.Module):
+ # used in the final layer before triplane
+ def __init__(self, input_dim, output_dim) -> None:
+ super().__init__()
+
+ self.conv_after_unpachify = nn.Sequential(
+ nn.Conv2d(input_dim, output_dim, 3, 1, 1),
+ nn.LeakyReLU(inplace=True)
+ )
+
+ self.conv_before_rendering = nn.Sequential(
+ nn.Conv2d(output_dim, output_dim, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+
+ def forward(self, unpachified_latent):
+ latent = self.conv_after_unpachify(unpachified_latent) # no residual connections here
+ latent = self.conv_before_rendering(latent) + latent
+ return latent
+
+
+# https://github.com/JingyunLiang/SwinIR/blob/6545850fbf8df298df73d81f3e8cba638787c8bd/models/network_swinir.py#L750
+class NearestConvSR(nn.Module):
+ """
+ code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
+ https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
+ """
+ def __init__(self, output_dim, num_feat=128, num_out_ch=3, sr_ratio=4, *args, **kwargs) -> None:
+ super().__init__()
+
+ self.upscale = sr_ratio
+
+ self.conv_after_body = nn.Conv2d(output_dim, output_dim, 3, 1, 1)
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(output_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ if self.upscale == 4:
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x, *args, **kwargs):
+
+ # x = self.conv_first(x)
+ x = self.conv_after_body(x) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ if self.upscale == 4:
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+
+
+ return x
+
+# https://github.com/yumingj/C2-Matching/blob/fa171ca6707c6f16a5d04194ce866ea70bb21d2b/mmsr/models/archs/ref_restoration_arch.py#L65
+class NearestConvSR_Residual(NearestConvSR):
+ # learn residual + normalize
+
+ def __init__(self, output_dim, num_feat=128, num_out_ch=3, sr_ratio=4, *args, **kwargs) -> None:
+ super().__init__(output_dim, num_feat, num_out_ch, sr_ratio, *args, **kwargs)
+ # self.mean = torch.Tensor((0.485, 0.456, 0.406)).view(1,3,1,1) # imagenet mean
+ self.act = nn.Tanh()
+
+ def forward(self, x, base_x, *args, **kwargs):
+ # base_x: low-resolution 3D rendering, for residual addition
+ # self.mean = self.mean.type_as(x)
+ # x = super().forward(x).clamp(-1,1)
+ x = super().forward(x)
+ x = self.act(x) # residual normalize to [-1,1]
+ scale = x.shape[-1] // base_x.shape[-1] # 2 or 4
+ x = x + F.interpolate(base_x, None, scale, 'bilinear', False) # add residual; [-1,1] range
+
+ # return x + 2 * self.mean
+ return x
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.num_feat * 3 * 9
+ return flops
+
+# class PixelShuffledDirect(nn.Module):
diff --git a/utils/torch_utils/custom_ops.py b/utils/torch_utils/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7dabe00493504ffb30ab8962b8c975b13d5c0d9
--- /dev/null
+++ b/utils/torch_utils/custom_ops.py
@@ -0,0 +1,187 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+
+#----------------------------------------------------------------------------
+
+
+def _get_mangled_gpu_name():
+ name = torch.cuda.get_device_name().lower()
+ out = []
+ for c in name:
+ if re.match('[a-z0-9_-]+', c):
+ out.append(c)
+ else:
+ out.append('-')
+ return ''.join(out)
+
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+
+def get_plugin(module_name,
+ sources,
+ headers=None,
+ source_dir=None,
+ **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+ if headers is None:
+ headers = []
+ if source_dir is not None:
+ sources = [os.path.join(source_dir, fname) for fname in sources]
+ headers = [os.path.join(source_dir, fname) for fname in headers]
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ',
+ end='',
+ flush=True)
+ verbose_build = (verbosity == 'full')
+
+ # Compile and load.
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(
+ f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".'
+ )
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+ # break the build or unnecessarily restrict what's available to nvcc.
+ # Unset it to let nvcc decide based on what's available on the
+ # machine.
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ #
+ # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
+ # around the *.cu dependency bug in ninja config.
+ #
+ all_source_files = sorted(sources + headers)
+ all_source_dirs = set(
+ os.path.dirname(fname) for fname in all_source_files)
+ if len(all_source_dirs
+ ) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
+
+ # Compute combined hash digest for all source files.
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+
+ # Select cached build directory name.
+ source_digest = hash_md5.hexdigest()
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(
+ module_name, verbose=verbose_build) # pylint: disable=protected-access
+ cached_build_dir = os.path.join(
+ build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
+
+ if not os.path.isdir(cached_build_dir):
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
+ os.makedirs(tmpdir)
+ for src in all_source_files:
+ shutil.copyfile(
+ src, os.path.join(tmpdir, os.path.basename(src)))
+ try:
+ os.replace(tmpdir, cached_build_dir) # atomic
+ except OSError:
+ # source directory already exists, delete tmpdir and its contents.
+ shutil.rmtree(tmpdir)
+ if not os.path.isdir(cached_build_dir): raise
+
+ # Compile.
+ cached_sources = [
+ os.path.join(cached_build_dir, os.path.basename(fname))
+ for fname in sources
+ ]
+ torch.utils.cpp_extension.load(name=module_name,
+ build_directory=cached_build_dir,
+ verbose=verbose_build,
+ sources=cached_sources,
+ **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name,
+ verbose=verbose_build,
+ sources=sources,
+ **build_kwargs)
+
+ # Load.
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache dict.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/ddp_practice.py b/utils/torch_utils/ddp_practice.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc96d0d39aebd68e66f4cf2292f29589108c34ea
--- /dev/null
+++ b/utils/torch_utils/ddp_practice.py
@@ -0,0 +1,285 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from argparse import ArgumentParser
+import sys
+import os
+
+sys.path.append('..')
+sys.path.append('.')
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.distributed import DistributedSampler
+
+from vit.vision_transformer import VisionTransformer as ViT
+from vit.vit_triplane import ViTTriplane
+from guided_diffusion import dist_util, logger
+
+import click
+import dnnlib
+
+SEED = 42
+BATCH_SIZE = 8
+NUM_EPOCHS = 1
+
+
+class YourDataset(Dataset):
+ def __init__(self):
+ pass
+
+
+@click.command()
+@click.option('--cfg', help='Base configuration', type=str, default='ffhq')
+@click.option('--sr-module',
+ help='Superresolution module override',
+ metavar='STR',
+ required=False,
+ default=None)
+@click.option('--density_reg',
+ help='Density regularization strength.',
+ metavar='FLOAT',
+ type=click.FloatRange(min=0),
+ default=0.25,
+ required=False,
+ show_default=True)
+@click.option('--density_reg_every',
+ help='lazy density reg',
+ metavar='int',
+ type=click.FloatRange(min=1),
+ default=4,
+ required=False,
+ show_default=True)
+@click.option('--density_reg_p_dist',
+ help='density regularization strength.',
+ metavar='FLOAT',
+ type=click.FloatRange(min=0),
+ default=0.004,
+ required=False,
+ show_default=True)
+@click.option('--reg_type',
+ help='Type of regularization',
+ metavar='STR',
+ type=click.Choice([
+ 'l1', 'l1-alt', 'monotonic-detach', 'monotonic-fixed',
+ 'total-variation'
+ ]),
+ required=False,
+ default='l1')
+@click.option('--decoder_lr_mul',
+ help='decoder learning rate multiplier.',
+ metavar='FLOAT',
+ type=click.FloatRange(min=0),
+ default=1,
+ required=False,
+ show_default=True)
+@click.option('--c_scale',
+ help='Scale factor for generator pose conditioning.',
+ metavar='FLOAT',
+ type=click.FloatRange(min=0),
+ required=False,
+ default=1)
+def main(**kwargs):
+ # parser = ArgumentParser('DDP usage example')
+ # parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.') # you need this argument in your scripts for DDP to work
+ # args = parser.parse_args()
+
+ opts = dnnlib.EasyDict(kwargs) # Command line arguments.
+ c = dnnlib.EasyDict() # Main config dict.
+
+ rendering_options = {
+ # 'image_resolution': c.training_set_kwargs.resolution,
+ 'image_resolution': 256,
+ 'disparity_space_sampling': False,
+ 'clamp_mode': 'softplus',
+ # 'superresolution_module': sr_module,
+ # 'c_gen_conditioning_zero': not opts.
+ # gen_pose_cond, # if true, fill generator pose conditioning label with dummy zero vector
+ # 'gpc_reg_prob': opts.gpc_reg_prob if opts.gen_pose_cond else None,
+ 'c_scale':
+ opts.c_scale, # mutliplier for generator pose conditioning label
+ # 'superresolution_noise_mode': opts.
+ # sr_noise_mode, # [random or none], whether to inject pixel noise into super-resolution layers
+ 'density_reg': opts.density_reg, # strength of density regularization
+ 'density_reg_p_dist': opts.
+ density_reg_p_dist, # distance at which to sample perturbed points for density regularization
+ 'reg_type': opts.
+ reg_type, # for experimenting with variations on density regularization
+ 'decoder_lr_mul':
+ opts.decoder_lr_mul, # learning rate multiplier for decoder
+ 'sr_antialias': True,
+ 'return_triplane_features': True, # for DDF supervision
+ 'return_sampling_details_flag': True,
+ }
+
+ if opts.cfg == 'ffhq':
+ rendering_options.update({
+ 'focal': 2985.29 / 700,
+ 'depth_resolution':
+ # 48, # number of uniform samples to take per ray.
+ 36, # number of uniform samples to take per ray.
+ 'depth_resolution_importance':
+ # 48, # number of importance samples to take per ray.
+ 36, # number of importance samples to take per ray.
+ 'ray_start':
+ 2.25, # near point along each ray to start taking samples.
+ 'ray_end':
+ 3.3, # far point along each ray to stop taking samples.
+ 'box_warp':
+ 1, # the side-length of the bounding box spanned by the tri-planes; box_warp=1 means [-0.5, -0.5, -0.5] -> [0.5, 0.5, 0.5].
+ 'avg_camera_radius':
+ 2.7, # used only in the visualizer to specify camera orbit radius.
+ 'avg_camera_pivot': [
+ 0, 0, 0.2
+ ], # used only in the visualizer to control center of camera rotation.
+ })
+ elif opts.cfg == 'afhq':
+ rendering_options.update({
+ 'focal': 4.2647,
+ 'depth_resolution': 48,
+ 'depth_resolution_importance': 48,
+ 'ray_start': 2.25,
+ 'ray_end': 3.3,
+ 'box_warp': 1,
+ 'avg_camera_radius': 2.7,
+ 'avg_camera_pivot': [0, 0, -0.06],
+ })
+ elif opts.cfg == 'shapenet':
+ rendering_options.update({
+ 'depth_resolution': 64,
+ 'depth_resolution_importance': 64,
+ # 'ray_start': 0.1,
+ # 'ray_end': 2.6,
+ 'ray_start': 0.1,
+ 'ray_end': 3.3,
+ 'box_warp': 1.6,
+ 'white_back': True,
+ 'avg_camera_radius': 1.7,
+ 'avg_camera_pivot': [0, 0, 0],
+ })
+ else:
+ assert False, "Need to specify config"
+
+ c.rendering_kwargs = rendering_options
+
+ args = opts
+
+ # keep track of whether the current process is the `master` process (totally optional, but I find it useful for data laoding, logging, etc.)
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+ args.is_master = args.local_rank == 0
+
+ # set the device
+ # device = torch.cuda.device(args.local_rank)
+ device = torch.device(f"cuda:{args.local_rank}")
+
+ # initialize PyTorch distributed using environment variables (you could also do this more explicitly by specifying `rank` and `world_size`, but I find using environment variables makes it so that you can easily use the same script on different machines)
+ dist.init_process_group(backend='nccl',
+ init_method='env://',
+ rank=args.local_rank,
+ world_size=torch.cuda.device_count())
+ print(f"{args.local_rank=} init complete")
+ torch.cuda.set_device(args.local_rank)
+
+ # set the seed for all GPUs (also make sure to set the seed for random, numpy, etc.)
+ torch.cuda.manual_seed_all(SEED)
+
+ # initialize your model (BERT in this example)
+ # model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+
+ # model = ViT(
+ # image_size = 256,
+ # patch_size = 32,
+ # num_classes = 1000,
+ # dim = 1024,
+ # depth = 6,
+ # heads = 16,
+ # mlp_dim = 2048,
+ # dropout = 0.1,
+ # emb_dropout = 0.1
+ # )
+
+ # TODO, check pre-trained ViT encoder cfgs
+ model = ViTTriplane(
+ img_size=[224],
+ patch_size=16,
+ in_chans=384,
+ num_classes=0,
+ embed_dim=384, # Check ViT encoder dim
+ depth=2,
+ num_heads=16,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.1,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=nn.LayerNorm,
+ out_chans=96,
+ c_dim=25, # Conditioning label (C) dimensionality.
+ img_resolution=128, # Output resolution.
+ img_channels=3, # Number of output color channels.
+ cls_token=False,
+ # TODO, replace with c
+ rendering_kwargs=c.rendering_kwargs,
+ )
+ # noise = torch.randn(1, 8, 8, 1024)
+
+ # send your model to GPU
+ model = model.to(device)
+
+ # initialize distributed data parallel (DDP)
+ model = DDP(model,
+ device_ids=[args.local_rank],
+ output_device=args.local_rank)
+
+ dist_util.sync_params(model.named_parameters())
+
+ # # initialize your dataset
+ # dataset = YourDataset()
+
+ # # initialize the DistributedSampler
+ # sampler = DistributedSampler(dataset)
+
+ # # initialize the dataloader
+ # dataloader = DataLoader(
+ # dataset=dataset,
+ # sampler=sampler,
+ # batch_size=BATCH_SIZE
+ # )
+
+ # start your training!
+ for epoch in range(NUM_EPOCHS):
+ # put model in train mode
+ model.train()
+
+ # let all processes sync up before starting with a new epoch of training
+ dist.barrier()
+
+ noise = torch.randn(1, 14 * 14, 384).to(device) # B, L, C
+ img = model(noise, torch.zeros(1, 25).to(device))
+ print(img['image'].shape)
+ # st()
+
+ # img = torch.randn(1, 3, 256, 256).to(device)
+
+ # preds = model(img)
+ # print(preds.shape)
+ # assert preds.shape == (1, 1000), 'correct logits outputted'
+
+ # for step, batch in enumerate(dataloader):
+ # # send batch to device
+ # batch = tuple(t.to(args.device) for t in batch)
+
+ # # forward pass
+ # outputs = model(*batch)
+
+ # # compute loss
+ # loss = outputs[0]
+
+ # # etc.
+
+
+if __name__ == '__main__':
+ main()
diff --git a/utils/torch_utils/dist_practice.py b/utils/torch_utils/dist_practice.py
new file mode 100644
index 0000000000000000000000000000000000000000..e103870819b989b1ea1ed5f1f9deeecb436cf9e3
--- /dev/null
+++ b/utils/torch_utils/dist_practice.py
@@ -0,0 +1,43 @@
+import torch
+import torch.multiprocessing as mp
+import torch.distributed as dist
+import os
+
+
+def find_free_port():
+ """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """
+ import socket
+ from contextlib import closing
+
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.bind(('', 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return str(s.getsockname()[1])
+
+
+def setup_process(rank, master_addr, master_port, world_size, backend='nccl'):
+ print(f'setting up {rank=} {world_size=} {backend=}')
+
+ # set up the master's ip address so this child process can coordinate
+ os.environ['MASTER_ADDR'] = master_addr
+ os.environ['MASTER_PORT'] = master_port
+ print(f"{master_addr=} {master_port=}")
+
+ # Initializes the default distributed process group, and this will also initialize the distributed package.
+ dist.init_process_group(backend, rank=rank, world_size=world_size)
+ print(f"{rank=} init complete")
+ dist.destroy_process_group()
+ print(f"{rank=} destroy complete")
+
+
+if __name__ == '__main__':
+ world_size = 2
+ master_addr = '127.0.0.1'
+ master_port = find_free_port()
+ mp.spawn(setup_process,
+ args=(
+ master_addr,
+ master_port,
+ world_size,
+ ),
+ nprocs=world_size)
diff --git a/utils/torch_utils/distributions/__init__.py b/utils/torch_utils/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/torch_utils/distributions/__pycache__/__init__.cpython-39.pyc b/utils/torch_utils/distributions/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecd268b82cd088155a509adfa1765ff511c37322
Binary files /dev/null and b/utils/torch_utils/distributions/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/torch_utils/distributions/__pycache__/distributions.cpython-39.pyc b/utils/torch_utils/distributions/__pycache__/distributions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..777c05116fff69c43bdddee42d927dacf61808fe
Binary files /dev/null and b/utils/torch_utils/distributions/__pycache__/distributions.cpython-39.pyc differ
diff --git a/utils/torch_utils/distributions/distributions.py b/utils/torch_utils/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..d01355749a0f55bad75b01083c2e086e8d203578
--- /dev/null
+++ b/utils/torch_utils/distributions/distributions.py
@@ -0,0 +1,138 @@
+# https://raw.githubusercontent.com/CompVis/latent-diffusion/e66308c7f2e64cb581c6d27ab6fbeb846828253b/ldm/modules/distributions/distributions.py
+
+import torch
+import numpy as np
+from pdb import set_trace as st
+
+
+class AbstractDistribution:
+
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+@torch.jit.script
+def soft_clamp20(x: torch.Tensor):
+ # return x.div(5.).tanh_().mul(5.) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]
+ # return x.div(5.).tanh().mul(5.) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]
+ # return x.div(15.).tanh().mul(15.) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]
+ return x.div(20.).tanh().mul(
+ 20.
+ ) # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]
+
+
+# @torch.jit.script
+# def soft_clamp(x: torch.Tensor, a: torch.Tensor):
+# return x.div(a).tanh_().mul(a)
+
+
+class DiagonalGaussianDistribution(object):
+
+ def __init__(self, parameters, deterministic=False, soft_clamp=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+
+ if soft_clamp:
+ # self.mean, self.logvar = soft_clamp5(self.mean), soft_clamp5(self.logvar) # as in LSGM, bound the range. needs re-training?
+ self.logvar = soft_clamp20(
+ self.logvar) # as in LSGM, bound the range. [-20, 20]
+ else:
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(
+ self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ # https://github.dev/NVlabs/LSGM/util/distributions.py
+ def log_p(self, samples):
+ # for calculating the negative encoder entropy term
+ normalized_samples = (samples - self.mean) / self.var
+ log_p = -0.5 * normalized_samples * normalized_samples - 0.5 * np.log(
+ 2 * np.pi) - self.logvar #
+
+ return log_p # ! TODO
+
+ def normal_entropy(self):
+ # for calculating normal entropy. Motivation: supervise logvar directly.
+ # normalized_samples = (samples - self.mean) / self.var
+ # log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - self.logvar #
+ # entropy = torch.sum(self.logvar + 0.5 * (np.log(2 * np.pi) + 1),
+ # dim=[1, 2, 3]).mean(0)
+ # entropy = torch.mean(self.logvar + 0.5 * (np.log(2 * np.pi) + 1)) # follow eps loss tradition here, average overall dims.
+ entropy = self.logvar + 0.5 * (np.log(2 * np.pi) + 1) # follow eps loss tradition here, average overall dims.
+
+
+ return entropy # ! TODO
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var +
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar +
+ torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
+ ((mean1 - mean2)**2) * torch.exp(-logvar2))
diff --git a/utils/torch_utils/inference_matt.py b/utils/torch_utils/inference_matt.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b80917995a2fe280a08029fcfe8f88e5973fcf9
--- /dev/null
+++ b/utils/torch_utils/inference_matt.py
@@ -0,0 +1,139 @@
+# https://github.com/xinntao/facexlib/blob/master/inference/inference_matting.py
+
+from tqdm import tqdm, trange
+import argparse
+from pathlib import Path
+import cv2
+import numpy as np
+import torch.nn.functional as F
+from torchvision.transforms.functional import normalize
+
+from facexlib.matting import init_matting_model
+from facexlib.utils import img2tensor
+
+
+def matt_single(args):
+ modnet = init_matting_model()
+
+ # read image
+ img = cv2.imread(args.img_path) / 255.
+ # unify image channels to 3
+ if len(img.shape) == 2:
+ img = img[:, :, None]
+ if img.shape[2] == 1:
+ img = np.repeat(img, 3, axis=2)
+ elif img.shape[2] == 4:
+ img = img[:, :, 0:3]
+
+ img_t = img2tensor(img, bgr2rgb=True, float32=True)
+ normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ img_t = img_t.unsqueeze(0).cuda()
+
+ # resize image for input
+ _, _, im_h, im_w = img_t.shape
+ ref_size = 512
+ if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
+ if im_w >= im_h:
+ im_rh = ref_size
+ im_rw = int(im_w / im_h * ref_size)
+ elif im_w < im_h:
+ im_rw = ref_size
+ im_rh = int(im_h / im_w * ref_size)
+ else:
+ im_rh = im_h
+ im_rw = im_w
+ im_rw = im_rw - im_rw % 32
+ im_rh = im_rh - im_rh % 32
+ img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')
+
+ # inference
+ _, _, matte = modnet(img_t, True)
+
+ # resize and save matte
+ matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
+ matte = matte[0][0].data.cpu().numpy()
+ cv2.imwrite(args.save_path, (matte * 255).astype('uint8'))
+
+ # get foreground
+ matte = matte[:, :, None]
+ foreground = img * matte + np.full(img.shape, 1) * (1 - matte)
+ cv2.imwrite(args.save_path.replace('.png', '_fg.png'), foreground * 255)
+
+def matt_directory(args): # for extracting ffhq imgs foreground
+ modnet = init_matting_model()
+
+ all_imgs = list(Path(args.img_dir_path).rglob('*.png'))
+ print('all imgs: ', len(all_imgs))
+
+ tgt_dir_path = '/mnt/lustre/share/yslan/ffhq/unzipped_ffhq_matte/'
+ # tgt_img_path = '/mnt/lustre/share/yslan/ffhq/unzipped_ffhq_matting/'
+
+ for img_path in tqdm(all_imgs):
+
+ # read image
+ # img = cv2.imread(args.img_path) / 255.
+ img = cv2.imread(str(img_path)) / 255.
+
+ relative_img_path = Path(img_path).relative_to('/mnt/lustre/share/yslan/ffhq/unzipped_ffhq_512/')
+ tgt_save_path = tgt_dir_path / relative_img_path
+
+ (tgt_save_path.parent).mkdir(parents=True, exist_ok=True)
+
+ # unify image channels to 3
+ if len(img.shape) == 2:
+ img = img[:, :, None]
+ if img.shape[2] == 1:
+ img = np.repeat(img, 3, axis=2)
+ elif img.shape[2] == 4:
+ img = img[:, :, 0:3]
+
+ img_t = img2tensor(img, bgr2rgb=True, float32=True)
+ normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ img_t = img_t.unsqueeze(0).cuda()
+
+ # resize image for input
+ _, _, im_h, im_w = img_t.shape
+ ref_size = 512
+ if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
+ if im_w >= im_h:
+ im_rh = ref_size
+ im_rw = int(im_w / im_h * ref_size)
+ elif im_w < im_h:
+ im_rw = ref_size
+ im_rh = int(im_h / im_w * ref_size)
+ else:
+ im_rh = im_h
+ im_rw = im_w
+ im_rw = im_rw - im_rw % 32
+ im_rh = im_rh - im_rh % 32
+ img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')
+
+ # inference
+ _, _, matte = modnet(img_t, True)
+
+ # resize and save matte
+ matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
+ matte = matte[0][0].data.cpu().numpy()
+ # cv2.imwrite(args.save_path, (matte * 255).astype('uint8'))
+ cv2.imwrite(str(tgt_save_path), (matte * 255).astype('uint8'))
+
+ assert tgt_save_path.exists()
+
+ # get foreground
+ # matte = matte[:, :, None]
+ # foreground = img * matte + np.full(img.shape, 1) * (1 - matte)
+ # cv2.imwrite(args.save_path.replace('.png', '_fg.png'), foreground * 255)
+
+ pass
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--img_path', type=str, default='assets/test.jpg')
+ parser.add_argument('--save_path', type=str, default='test_matting.png')
+
+ parser.add_argument('--img_dir_path', type=str, default='assets', required=False)
+ args = parser.parse_args()
+
+ # matt_single(args)
+ matt_directory(args)
\ No newline at end of file
diff --git a/utils/torch_utils/legacy.py b/utils/torch_utils/legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab45e1d7e7c751817570b7cb455f9c555deab5d9
--- /dev/null
+++ b/utils/torch_utils/legacy.py
@@ -0,0 +1,368 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Converting legacy network pickle into the new format."""
+
+from pdb import set_trace as st
+import click
+import pickle
+import re
+import copy
+import numpy as np
+import torch
+import dnnlib
+from utils.torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+def load_network_pkl(f, device, force_fp16=False):
+ data = _LegacyUnpickler(f).load()
+
+ # Legacy TensorFlow pickle => convert.
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+ tf_G, tf_D, tf_Gs = data
+ G = convert_tf_generator(tf_G)
+ D = convert_tf_discriminator(tf_D)
+ G_ema = convert_tf_generator(tf_Gs)
+ data = dict(G=G, D=D, G_ema=G_ema)
+
+ # for k, module in data.items():
+ # for key in ['G', 'D', 'G_ema']:
+ # data[key].to(device)
+
+ # Add missing fields.
+ if 'training_set_kwargs' not in data:
+ data['training_set_kwargs'] = None
+ if 'augment_pipe' not in data:
+ data['augment_pipe'] = None
+
+ # Validate contents.
+ assert isinstance(data['G'], torch.nn.Module)
+ assert isinstance(data['D'], torch.nn.Module)
+ assert isinstance(data['G_ema'], torch.nn.Module)
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+ # Force FP16.
+ if force_fp16:
+ for key in ['G', 'D', 'G_ema']:
+ old = data[key]
+ kwargs = copy.deepcopy(old.init_kwargs)
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
+ fp16_kwargs.num_fp16_res = 4
+ fp16_kwargs.conv_clamp = 256
+ if kwargs != old.init_kwargs:
+ new = type(old)(**kwargs).eval().requires_grad_(False)
+ misc.copy_params_and_buffers(old, new, require_all=True)
+ data[key] = new
+ return data
+
+def load_network_pkl_E(f, force_fp16=False):
+ data = _LegacyUnpickler(f).load()
+
+ # Legacy TensorFlow pickle => convert.
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+ tf_E = data
+ E = convert_tf_generator(tf_E)
+ # D = convert_tf_discriminator(tf_D)
+ # G_ema = convert_tf_generator(tf_Gs)
+ data = dict(G=E)
+
+ # Add missing fields.
+ if 'training_set_kwargs' not in data:
+ data['training_set_kwargs'] = None
+ if 'augment_pipe' not in data:
+ data['augment_pipe'] = None
+
+ # Validate contents.
+ assert isinstance(data['E'], torch.nn.Module)
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+ # Force FP16.
+ if force_fp16:
+ for key in ['E']:
+ old = data[key]
+ kwargs = copy.deepcopy(old.init_kwargs)
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
+ fp16_kwargs.num_fp16_res = 4
+ fp16_kwargs.conv_clamp = 256
+ if kwargs != old.init_kwargs:
+ new = type(old)(**kwargs).eval().requires_grad_(False)
+ misc.copy_params_and_buffers(old, new, require_all=True)
+ data[key] = new
+ return data
+#----------------------------------------------------------------------------
+
+class _TFNetworkStub(dnnlib.EasyDict):
+ pass
+
+class _LegacyUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if module == 'dnnlib.tflib.network' and name == 'Network':
+ return _TFNetworkStub
+ if 'training' in module:
+ module = module.replace('training', 'nsr') # map module position from eg3d repo
+
+ return super().find_class(module, name)
+
+#----------------------------------------------------------------------------
+
+def _collect_tf_params(tf_net):
+ # pylint: disable=protected-access
+ tf_params = dict()
+ def recurse(prefix, tf_net):
+ for name, value in tf_net.variables:
+ tf_params[prefix + name] = value
+ for name, comp in tf_net.components.items():
+ recurse(prefix + name + '/', comp)
+ recurse('', tf_net)
+ return tf_params
+
+#----------------------------------------------------------------------------
+
+def _populate_module_params(module, *patterns):
+ for name, tensor in misc.named_params_and_buffers(module):
+ found = False
+ value = None
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
+ match = re.fullmatch(pattern, name)
+ if match:
+ found = True
+ if value_fn is not None:
+ value = value_fn(*match.groups())
+ break
+ try:
+ assert found
+ if value is not None:
+ tensor.copy_(torch.from_numpy(np.array(value)))
+ except:
+ print(name, list(tensor.shape))
+ raise
+
+#----------------------------------------------------------------------------
+
+def convert_tf_generator(tf_G):
+ if tf_G.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_G.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None, none=None):
+ known_kwargs.add(tf_name)
+ val = tf_kwargs.get(tf_name, default)
+ return val if val is not None else none
+
+ # Convert kwargs.
+ from training import networks_stylegan2
+ network_class = networks_stylegan2.Generator
+ kwargs = dnnlib.EasyDict(
+ z_dim = kwarg('latent_size', 512),
+ c_dim = kwarg('label_size', 0),
+ w_dim = kwarg('dlatent_size', 512),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ architecture = kwarg('architecture', 'skip'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ use_noise = kwarg('use_noise', True),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 8),
+ embed_features = kwarg('label_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('truncation_psi')
+ kwarg('truncation_cutoff')
+ kwarg('style_mixing_prob')
+ kwarg('structure')
+ kwarg('conditioning')
+ kwarg('fused_modconv')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_G)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+ kwargs.synthesis.kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ G = network_class(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ # pylint: disable=f-string-without-interpolation
+ _populate_module_params(G,
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'.*\.resample_filter', None,
+ r'.*\.act_filter', None,
+ )
+ return G
+
+#----------------------------------------------------------------------------
+
+def convert_tf_discriminator(tf_D):
+ if tf_D.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_D.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None):
+ known_kwargs.add(tf_name)
+ return tf_kwargs.get(tf_name, default)
+
+ # Convert kwargs.
+ kwargs = dnnlib.EasyDict(
+ c_dim = kwarg('label_size', 0),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ architecture = kwarg('architecture', 'resnet'),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ cmap_dim = kwarg('mapping_fmaps', None),
+ block_kwargs = dnnlib.EasyDict(
+ activation = kwarg('nonlinearity', 'lrelu'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ freeze_layers = kwarg('freeze_layers', 0),
+ ),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 0),
+ embed_features = kwarg('mapping_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
+ ),
+ epilogue_kwargs = dnnlib.EasyDict(
+ mbstd_group_size = kwarg('mbstd_group_size', None),
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('structure')
+ kwarg('conditioning')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_D)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
+ kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ from training import networks_stylegan2
+ D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ # pylint: disable=f-string-without-interpolation
+ _populate_module_params(D,
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
+ r'.*\.resample_filter', None,
+ )
+ return D
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--source', help='Input pickle', required=True, metavar='PATH')
+@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
+@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
+def convert_network_pickle(source, dest, force_fp16):
+ """Convert legacy network pickle into the native PyTorch format.
+
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
+
+ Example:
+
+ \b
+ python legacy.py \\
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
+ --dest=stylegan2-cat-config-f.pkl
+ """
+ print(f'Loading "{source}"...')
+ with dnnlib.util.open_url(source) as f:
+ data = load_network_pkl(f, force_fp16=force_fp16)
+ print(f'Saving "{dest}"...')
+ with open(dest, 'wb') as f:
+ pickle.dump(data, f)
+ print('Done.')
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/misc.py b/utils/torch_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ecc786f81740321cd6917ea33f19a3deeb5b64
--- /dev/null
+++ b/utils/torch_utils/misc.py
@@ -0,0 +1,344 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+import dnnlib
+
+from guided_diffusion import dist_util, logger
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device,
+ memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0),
+ min=neginf,
+ max=posinf,
+ out=out)
+
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to temporarily suppress known warnings in torch.jit.trace().
+# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+
+@contextlib.contextmanager
+def suppress_tracer_warnings():
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+ warnings.filters.insert(0, flt)
+ yield
+ warnings.filters.remove(flt)
+
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(
+ f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}'
+ )
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(
+ ): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size),
+ f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(
+ ): # as_tensor results are registered as constants
+ symbolic_assert(
+ torch.equal(size, torch.as_tensor(ref_size)),
+ f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(
+ f'Wrong size for dimension {idx}: got {size}, expected {ref_size}'
+ )
+
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+
+ decorator.__name__ = fn.__name__
+ return decorator
+
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self,
+ dataset,
+ rank=0,
+ num_replicas=1,
+ shuffle=True,
+ seed=0,
+ window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False, load_except=(), model_name=''):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = dict(named_params_and_buffers(src_module))
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ try:
+ if name in load_except:
+ logger.log('ignore load_except module: ', name)
+ else:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(
+ tensor.requires_grad)
+ except:
+ print(name)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module,
+ torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ # print(fullname)
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ if tensor.is_floating_point():
+ tensor = nan_to_num(tensor)
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (tensor == other).all(), fullname
+
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs,
+ (tuple,
+ list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+
+ hooks = [
+ mod.register_forward_pre_hook(pre_hook) for mod in module.modules()
+ ]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [
+ t for t in e.mod.parameters() if id(t) not in tensors_seen
+ ]
+ e.unique_buffers = [
+ t for t in e.mod.buffers() if id(t) not in tensors_seen
+ ]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {
+ id(t)
+ for t in e.unique_params + e.unique_buffers + e.unique_outputs
+ }
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [
+ e for e in entries if len(e.unique_params) or len(e.unique_buffers)
+ or len(e.unique_outputs)
+ ]
+
+ # Construct table.
+ rows = [[
+ type(module).__name__, 'Parameters', 'Buffers', 'Output shape',
+ 'Datatype'
+ ]]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[
+ name + f':{idx}', '-', '-', output_shapes[idx],
+ output_dtypes[idx]
+ ]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell))
+ for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/__init__.py b/utils/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/utils/torch_utils/ops/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/utils/torch_utils/ops/__pycache__/__init__.cpython-39.pyc b/utils/torch_utils/ops/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ad9e361e36a958b1f257731c81cd87453c54820
Binary files /dev/null and b/utils/torch_utils/ops/__pycache__/__init__.cpython-39.pyc differ
diff --git a/utils/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc b/utils/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a508bc227747bb9a8a0fbb3dd2b087eb37bc4ec0
Binary files /dev/null and b/utils/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc differ
diff --git a/utils/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc b/utils/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a9fcee7951404021b29048d57c36ae95087ae42
Binary files /dev/null and b/utils/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc differ
diff --git a/utils/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc b/utils/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ebceaede18354b24b359d6800a99fe97eaea289
Binary files /dev/null and b/utils/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc differ
diff --git a/utils/torch_utils/ops/__pycache__/filtered_lrelu.cpython-39.pyc b/utils/torch_utils/ops/__pycache__/filtered_lrelu.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4b7c78f08f130772541840cebf8a7bcd04311b5
Binary files /dev/null and b/utils/torch_utils/ops/__pycache__/filtered_lrelu.cpython-39.pyc differ
diff --git a/utils/torch_utils/ops/__pycache__/fma.cpython-39.pyc b/utils/torch_utils/ops/__pycache__/fma.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4de1b735a44066224d0ac4a1cd213f59662cf4e3
Binary files /dev/null and b/utils/torch_utils/ops/__pycache__/fma.cpython-39.pyc differ
diff --git a/utils/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc b/utils/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..810623cf9260e799f7c9a0a1522b3bfdfea69351
Binary files /dev/null and b/utils/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc differ
diff --git a/utils/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc b/utils/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c045e5990707d529c09adcd0203106af1f4011dd
Binary files /dev/null and b/utils/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc differ
diff --git a/utils/torch_utils/ops/bias_act.cpp b/utils/torch_utils/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ee6f6d0caaf4f84b94851d223e384344e1109cdc
--- /dev/null
+++ b/utils/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,103 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/bias_act.cu b/utils/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..71ca3900deda41e62d80044f0e409875f4c794b5
--- /dev/null
+++ b/utils/torch_utils/ops/bias_act.cu
@@ -0,0 +1,177 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel;
+ if (p.act == 2) return (void*)bias_act_kernel;
+ if (p.act == 3) return (void*)bias_act_kernel;
+ if (p.act == 4) return (void*)bias_act_kernel;
+ if (p.act == 5) return (void*)bias_act_kernel;
+ if (p.act == 6) return (void*)bias_act_kernel;
+ if (p.act == 7) return (void*)bias_act_kernel;
+ if (p.act == 8) return (void*)bias_act_kernel;
+ if (p.act == 9) return (void*)bias_act_kernel;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/bias_act.h b/utils/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..8994bfb4e9cae790865348e08de5f685152d3344
--- /dev/null
+++ b/utils/torch_utils/ops/bias_act.h
@@ -0,0 +1,42 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/bias_act.py b/utils/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..16a8ec0527471d0774873eba378665014a9b30b8
--- /dev/null
+++ b/utils/torch_utils/ops/bias_act.py
@@ -0,0 +1,308 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import numpy as np
+import torch
+import dnnlib
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear':
+ dnnlib.EasyDict(func=lambda x, **_: x,
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=1,
+ ref='',
+ has_2nd_grad=False),
+ 'relu':
+ dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x),
+ def_alpha=0,
+ def_gain=np.sqrt(2),
+ cuda_idx=2,
+ ref='y',
+ has_2nd_grad=False),
+ 'lrelu':
+ dnnlib.EasyDict(
+ func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
+ def_alpha=0.2,
+ def_gain=np.sqrt(2),
+ cuda_idx=3,
+ ref='y',
+ has_2nd_grad=False),
+ 'tanh':
+ dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=4,
+ ref='y',
+ has_2nd_grad=True),
+ 'sigmoid':
+ dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=5,
+ ref='y',
+ has_2nd_grad=True),
+ 'elu':
+ dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=6,
+ ref='y',
+ has_2nd_grad=True),
+ 'selu':
+ dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=7,
+ ref='y',
+ has_2nd_grad=True),
+ 'softplus':
+ dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x),
+ def_alpha=0,
+ def_gain=1,
+ cuda_idx=8,
+ ref='y',
+ has_2nd_grad=True),
+ 'swish':
+ dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x,
+ def_alpha=0,
+ def_gain=np.sqrt(2),
+ cuda_idx=9,
+ ref='x',
+ has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+_null_tensor = torch.empty([0])
+
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='bias_act_plugin',
+ sources=['bias_act.cpp', 'bias_act.cu'],
+ headers=['bias_act.h'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+
+#----------------------------------------------------------------------------
+
+
+# @torch.autocast(device_type='cuda')
+def bias_act(x,
+ b=None,
+ dim=1,
+ act='linear',
+ alpha=None,
+ gain=None,
+ clamp=None,
+ impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim,
+ act=act,
+ alpha=alpha,
+ gain=gain,
+ clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x,
+ b=b,
+ dim=dim,
+ act=act,
+ alpha=alpha,
+ gain=gain,
+ clamp=clamp)
+
+
+#----------------------------------------------------------------------------
+
+
+@misc.profiled_function
+def _bias_act_ref(x,
+ b=None,
+ dim=1,
+ act='linear',
+ alpha=None,
+ gain=None,
+ clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+
+# @torch.autocast(device_type='cuda')
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ # @torch.cuda.amp.custom_fwd
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(
+ 1) == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor,
+ _null_tensor, 0, dim, spec.cuda_idx,
+ alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ # @torch.cuda.amp.custom_bwd
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ # @torch.cuda.amp.custom_fwd
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(
+ 1) == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim,
+ spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(dy if spec.has_2nd_grad else _null_tensor, x,
+ b, y)
+ return dx
+
+ @staticmethod
+ # @torch.cuda.amp.custom_bwd
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1]
+ or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim,
+ spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/conv2d_gradfix.py b/utils/torch_utils/ops/conv2d_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..c539b82a2a580ac301af741a65a745f9a72ef3d7
--- /dev/null
+++ b/utils/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,302 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import contextlib
+import torch
+from pdb import set_trace as st
+import traceback
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
+
+@contextlib.contextmanager
+def no_weight_gradients(disable=True):
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ if disable:
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+
+#----------------------------------------------------------------------------
+
+
+def conv2d(input,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False,
+ weight_shape=weight.shape,
+ stride=stride,
+ padding=padding,
+ output_padding=0,
+ dilation=dilation,
+ groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input,
+ weight=weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups)
+
+
+def conv_transpose2d(input,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ groups=1,
+ dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True,
+ weight_shape=weight.shape,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input,
+ weight=weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation)
+
+
+#----------------------------------------------------------------------------
+
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ return True
+
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+_null_tensor = torch.empty([0])
+
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding,
+ dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation,
+ groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i])
+ for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups)
+
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] -
+ (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ ctx.save_for_backward(
+ input if weight.requires_grad else _null_tensor,
+ weight if input.requires_grad else _null_tensor,
+ )
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
+ if weight_shape[2:] == stride == dilation == (
+ 1, 1) and padding == (
+ 0, 0) and torch.cuda.get_device_capability(
+ input.device) < (8, 0):
+ a = weight.reshape(groups, weight_shape[0] // groups,
+ weight_shape[1])
+ b = input.reshape(input.shape[0], groups,
+ input.shape[1] // groups, -1)
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(
+ 1, 2, 0, 3).flatten(2)
+ c = c.reshape(-1, input.shape[0],
+ *input.shape[2:]).transpose(0, 1)
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(
+ 2).unsqueeze(3)
+ return c.contiguous(
+ memory_format=(torch.channels_last if input.stride(1) ==
+ 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ if transpose:
+ return torch.nn.functional.conv_transpose2d(
+ input=input,
+ weight=weight,
+ bias=bias,
+ output_padding=output_padding,
+ **common_kwargs)
+ return torch.nn.functional.conv2d(input=input,
+ weight=weight,
+ bias=bias,
+ **common_kwargs)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ input_shape = ctx.input_shape
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input_shape,
+ output_shape=grad_output.shape)
+ op = _conv2d_gradfix(transpose=(not transpose),
+ weight_shape=weight_shape,
+ output_padding=p,
+ **common_kwargs)
+ grad_input = op.apply(grad_output, weight, None)
+ assert grad_input.shape == input_shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input,
+ weight)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, weight):
+ ctx.save_for_backward(
+ grad_output if input.requires_grad else _null_tensor,
+ input if grad_output.requires_grad else _null_tensor,
+ )
+ ctx.grad_output_shape = grad_output.shape
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
+ if weight_shape[2:] == stride == dilation == (
+ 1, 1) and padding == (0, 0):
+ a = grad_output.reshape(grad_output.shape[0], groups,
+ grad_output.shape[1] // groups,
+ -1).permute(1, 2, 0, 3).flatten(2)
+ b = input.reshape(input.shape[0], groups,
+ input.shape[1] // groups,
+ -1).permute(1, 2, 0, 3).flatten(2)
+ c = (b @ a.transpose(1, 2) if transpose else
+ a @ b.transpose(1, 2)).reshape(weight_shape)
+ return c.contiguous(
+ memory_format=(torch.channels_last if input.stride(1) ==
+ 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ # print(input.device, weight.device, flush=True)
+ # for line in traceback.format_stack():
+ # print(line.strip(), flush=True)
+ return torch.ops.aten.convolution_backward(
+ grad_output=grad_output,
+ input=input,
+ weight=weight,
+ bias_sizes=None,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ transposed=transpose,
+ output_padding=output_padding,
+ groups=groups,
+ output_mask=[False, True, False])[1]
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad_output_shape = ctx.grad_output_shape
+ input_shape = ctx.input_shape
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight,
+ None)
+ assert grad2_grad_output.shape == grad_output_shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input_shape,
+ output_shape=grad_output_shape)
+ op = _conv2d_gradfix(transpose=(not transpose),
+ weight_shape=weight_shape,
+ output_padding=p,
+ **common_kwargs)
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input_shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/conv2d_resample.py b/utils/torch_utils/ops/conv2d_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..df7bea8fd7d0b921227cec546bebbd0e836d9da8
--- /dev/null
+++ b/utils/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,208 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(
+ ): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+
+#----------------------------------------------------------------------------
+
+
+def _conv2d_wrapper(x,
+ w,
+ stride=1,
+ padding=0,
+ groups=1,
+ transpose=False,
+ flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ if not flip_weight and (kw > 1 or kh > 1):
+ w = w.flip([2, 3])
+
+ # Execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+
+#----------------------------------------------------------------------------
+
+
+@misc.profiled_function
+def conv2d_resample(x,
+ w,
+ f=None,
+ up=1,
+ down=1,
+ padding=0,
+ groups=1,
+ flip_weight=True,
+ flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype
+ == x.dtype)
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x,
+ f=f,
+ down=down,
+ padding=[px0, px1, py0, py1],
+ flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x,
+ f=f,
+ up=up,
+ padding=[px0, px1, py0, py1],
+ gain=up**2,
+ flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x,
+ f=f,
+ padding=[px0, px1, py0, py1],
+ flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x,
+ w=w,
+ stride=down,
+ groups=groups,
+ flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups,
+ in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group,
+ out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x,
+ w=w,
+ stride=up,
+ padding=[pyt, pxt],
+ groups=groups,
+ transpose=True,
+ flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(
+ x=x,
+ f=f,
+ padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
+ gain=up**2,
+ flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x,
+ f=f,
+ down=down,
+ flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x,
+ w=w,
+ padding=[py0, px0],
+ groups=groups,
+ flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x,
+ f=(f if up > 1 else None),
+ up=up,
+ padding=[px0, px1, py0, py1],
+ gain=up**2,
+ flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/filtered_lrelu.cpp b/utils/torch_utils/ops/filtered_lrelu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4f55466235a020b0f5e150350bfdcd8b2a1e579d
--- /dev/null
+++ b/utils/torch_utils/ops/filtered_lrelu.cpp
@@ -0,0 +1,304 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include
+#include
+#include "filtered_lrelu.h"
+
+//------------------------------------------------------------------------
+
+static std::tuple filtered_lrelu(
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
+
+ // Figure out how much shared memory is available on the device.
+ int maxSharedBytes = 0;
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
+ int sharedKB = maxSharedBytes >> 10;
+
+ // Populate enough launch parameters to check if a CUDA kernel exists.
+ filtered_lrelu_kernel_params p;
+ p.up = up;
+ p.down = down;
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ if (!test_spec.exec)
+ {
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
+ }
+
+ // Input/output element size.
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
+
+ // Input sizes.
+ int64_t xw = (int)x.size(3);
+ int64_t xh = (int)x.size(2);
+ int64_t fut_w = (int)fu.size(-1) - 1;
+ int64_t fut_h = (int)fu.size(0) - 1;
+ int64_t fdt_w = (int)fd.size(-1) - 1;
+ int64_t fdt_h = (int)fd.size(0) - 1;
+
+ // Logical size of upsampled buffer.
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
+
+ // Compute output size and allocate.
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
+
+ // Allocate sign tensor.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ int64_t sw_active = 0; // Active width of sign tensor.
+ if (writeSigns)
+ {
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+ else if (readSigns)
+ sw_active = s.size(3) << 2;
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
+ }
+
+ // Populate rest of CUDA kernel parameters.
+ p.x = x.data_ptr();
+ p.y = y.data_ptr();
+ p.b = b.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.fu = fu.data_ptr();
+ p.fd = fd.data_ptr();
+ p.pad0 = make_int2(px0, py0);
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.flip = (flip_filters) ? 1 : 0;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
+
+ // x, y, b strides are in bytes.
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
+ p.bStride = sz * b.stride(0);
+
+ // fu, fd strides are in elements.
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
+
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
+ bool index64b = false;
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
+ if (s.numel() > INT_MAX) index64b = true;
+
+ // Choose CUDA kernel.
+ filtered_lrelu_kernel_spec spec = { 0 };
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
+ {
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
+ {
+ // Choose kernel based on index type, datatype and sign read/write modes.
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ }
+ });
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = spec.numWarps * 32;
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
+ int gz = p.yShape.z * p.yShape.w;
+
+ // Repeat multiple horizontal tiles in a CTA?
+ if (spec.xrep)
+ {
+ p.tilesXrep = spec.xrep;
+ p.tilesXdim = gx;
+
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
+ std::swap(gx, gy);
+ }
+ else
+ {
+ p.tilesXrep = 0;
+ p.tilesXdim = 0;
+ }
+
+ // Launch filter setup kernel.
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
+
+ // Copy kernels to constant memory.
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+
+ // Set cache and shared memory configurations for main kernel.
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
+
+ // Launch main kernel.
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
+ {
+ p.blockZofs = zofs;
+ int subGz = std::min(maxSubGz, gz - zofs);
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
+ }
+
+ // Done.
+ return std::make_tuple(y, so, 0);
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
+
+ // Output signs if we don't have sign input.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ if (writeSigns)
+ {
+ int64_t sw = x.size(3);
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
+ }
+
+ // Initialize CUDA kernel parameters.
+ filtered_lrelu_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+
+ // Choose CUDA kernel.
+ void* func = 0;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
+ {
+ if (writeSigns)
+ func = choose_filtered_lrelu_act_kernel();
+ else if (readSigns)
+ func = choose_filtered_lrelu_act_kernel();
+ else
+ func = choose_filtered_lrelu_act_kernel();
+ });
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = 128; // 4 warps per block.
+
+ // Logical size of launch = writeSigns ? p.s : p.x
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
+ gx = (gx - 1) / bx + 1;
+
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
+ const uint32_t gmax = 65535;
+ gy = std::min(gy, gmax);
+ gz = std::min(gz, gmax);
+
+ // Launch.
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
+ return so;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
+}
+
+//------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/filtered_lrelu.cu b/utils/torch_utils/ops/filtered_lrelu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..aaac95408365f023ffaa4cb89348d499d3b948f0
--- /dev/null
+++ b/utils/torch_utils/ops/filtered_lrelu.cu
@@ -0,0 +1,1288 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include "filtered_lrelu.h"
+#include
+
+//------------------------------------------------------------------------
+// Helpers.
+
+enum // Filter modes.
+{
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
+};
+
+template struct InternalType;
+template <> struct InternalType
+{
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
+};
+template <> struct InternalType
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+template <> struct InternalType
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
+
+// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
+template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
+{
+ if ((N & (N-1)) && N <= 256)
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
+ else
+ y = i/N;
+
+ x = i - y*N;
+}
+
+// Type cast stride before reading it.
+template __device__ __forceinline__ T get_stride(const int64_t& x)
+{
+ return *reinterpret_cast(&x);
+}
+
+//------------------------------------------------------------------------
+// Filters, setup kernel, copying function.
+
+#define MAX_FILTER_SIZE 32
+
+// Combined up/down filter buffers so that transfer can be done with one copy.
+__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
+__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
+
+// Accessors to combined buffers to index up/down filters individually.
+#define c_fu (c_fbuf)
+#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+#define g_fu (g_fbuf)
+#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+
+// Set up filters into global memory buffer.
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
+{
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
+ {
+ int x, y;
+ fast_div_mod(x, y, idx);
+
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
+ if (p.fuShape.y > 0)
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
+ else
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
+
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
+ if (p.fdShape.y > 0)
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
+ else
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
+ }
+}
+
+// Host function to copy filters written by setup kernel into constant buffer for main kernel.
+template static cudaError_t copy_filters(cudaStream_t stream)
+{
+ void* src = 0;
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
+ if (err) return err;
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
+}
+
+//------------------------------------------------------------------------
+// Coordinate spaces:
+// - Relative to input tensor: inX, inY, tileInX, tileInY
+// - Relative to input tile: relInX, relInY, tileInW, tileInH
+// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
+// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
+// - Relative to output tensor: outX, outY, tileOutX, tileOutY
+//
+// Relationships between coordinate spaces:
+// - inX = tileInX + relInX
+// - inY = tileInY + relInY
+// - relUpX = relInX * up + phaseInX
+// - relUpY = relInY * up + phaseInY
+// - relUpX = relOutX * down
+// - relUpY = relOutY * down
+// - outX = tileOutX + relOutX
+// - outY = tileOutY + relOutY
+
+extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
+
+template
+static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
+{
+ // Check that we don't try to support non-existing filter modes.
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
+
+ // Static definitions.
+ typedef typename InternalType::scalar_t scalar_t;
+ typedef typename InternalType::vec2_t vec2_t;
+ typedef typename InternalType::vec4_t vec4_t;
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
+
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
+
+ // Sizes of logical buffers.
+ const int szIn = tileInH_up * tileInW;
+ const int szUpX = tileInH_up * tileUpW;
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
+ const int szDownX = tileUpH * tileOutW;
+
+ // Sizes for shared memory arrays.
+ const int s_buf0_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUFD) ? szIn :
+ -1;
+ const int s_buf1_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
+ (filterMode == MODE_FUSD) ? szUpXY :
+ (filterMode == MODE_SUFD) ? szUpX :
+ (filterMode == MODE_FUFD) ? szUpXY :
+ -1;
+
+ // Ensure U128 alignment.
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
+
+ // Check at compile time that we don't use too much shared memory.
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
+
+ // Declare shared memory arrays.
+ scalar_t* s_buf0;
+ scalar_t* s_buf1;
+ if (sharedKB <= 48)
+ {
+ // Allocate shared memory arrays here.
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
+ s_buf0 = s_buf0_st;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+ else
+ {
+ // Use the dynamically allocated shared memory array.
+ s_buf0 = (scalar_t*)s_buf_raw;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+
+ // Pointers to the buffers.
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
+ if (filterMode == MODE_SUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ s_tileDownX = s_buf1;
+ }
+ else if (filterMode == MODE_FUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ s_tileDownX = s_buf0;
+ }
+ else if (filterMode == MODE_SUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ }
+ else if (filterMode == MODE_FUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ }
+
+ // Allow large grids in z direction via per-launch offset.
+ int channelIdx = blockIdx.z + p.blockZofs;
+ int batchIdx = channelIdx / p.yShape.z;
+ channelIdx -= batchIdx * p.yShape.z;
+
+ // Offset to output feature map. In bytes.
+ index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w);
+
+ // Sign shift amount.
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
+
+ // Inner tile loop.
+ #pragma unroll 1
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
+ {
+ // Locate output tile.
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
+ int tileOutX = tileX * tileOutW;
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
+
+ // Locate input tile.
+ int tmpX = tileOutX * down - p.pad0.x;
+ int tmpY = tileOutY * down - p.pad0.y;
+ int tileInX = CEIL_DIV(tmpX, up);
+ int tileInY = CEIL_DIV(tmpY, up);
+ const int phaseInX = tileInX * up - tmpX;
+ const int phaseInY = tileInY * up - tmpY;
+
+ // Extra sync if input and output buffers are the same and we are not on first tile.
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
+ __syncthreads();
+
+ // Load input tile & apply bias. Unrolled.
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride)));
+ index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w);
+ int idx = threadIdx.x;
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
+ #pragma unroll
+ for (int loop = 0; loop < loopCountIN; loop++)
+ {
+ int relInX, relInY;
+ fast_div_mod(relInX, relInY, idx);
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b;
+
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
+ if (!skip)
+ s_tileIn[idx] = v;
+
+ idx += threadsPerBlock;
+ }
+
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
+ {
+ // Horizontal upsampling.
+ __syncthreads();
+ if (up == 4)
+ {
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ scalar_t a = s_tileIn[src0];
+ if (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInX == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInX == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ s_tileUpX[dst+2] = v.z;
+ s_tileUpX[dst+3] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ bool p0 = (phaseInX == 0);
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ scalar_t a = s_tileIn[src0];
+ if (p0) // (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ }
+ }
+
+ // Vertical upsampling & nonlinearity.
+
+ __syncthreads();
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
+ if (up == 4)
+ {
+ minY -= 3; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec4_t v = InternalType::zero_vec4();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInY == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInY == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+ index_t si2 = si0 + p.sShape.x * 2;
+ index_t si3 = si0 + p.sShape.x * 3;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ int ss = (signX & 3) << 1;
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ minY -= 1; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec2_t v = InternalType::zero_vec2();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ }
+
+ if (!downInline)
+ {
+ // Write into temporary buffer.
+ s_tileUpXY[dst] = v.x;
+ if (relUpY0 < tileUpH - 1)
+ s_tileUpXY[dst + tileUpW] = v.y;
+ }
+ else
+ {
+ // Write directly into output buffer.
+ if ((uint32_t)x < p.yShape.x)
+ {
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
+ index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut;
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+ }
+ }
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
+ {
+ // Full upsampling filter.
+
+ if (up == 2)
+ {
+ // 2 x 2-wide.
+ __syncthreads();
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
+ int src0 = relInX0 + tileInW * relInY0;
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
+
+ #define X_LOOP(TAPY, PX) \
+ for (int sx = 0; sx < fuSize / up; sx++) \
+ { \
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ }
+
+ vec4_t v = InternalType::zero_vec4();
+ if (tap0y == 0 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 0) }
+ if (tap0y == 0 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 1) }
+ if (tap0y == 1 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 0) }
+ if (tap0y == 1 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 1) }
+
+ #undef X_LOOP
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read sign and apply.
+ {
+ if ((uint32_t)signY < p.sShape.y)
+ {
+ int s = 0;
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
+ s >>= (signX & 3) << 1;
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[idx + 0] = v.x;
+ s_tileUpXY[idx + 1] = v.y;
+ s_tileUpXY[idx + 2] = v.z;
+ s_tileUpXY[idx + 3] = v.w;
+ }
+ }
+ else if (up == 1)
+ {
+ __syncthreads();
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ v *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write sign.
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ }
+ else
+ {
+ // Determine and write sign.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ else
+ {
+ // Just compute the value.
+ if (v < 0.f) v *= p.slope;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ }
+ }
+ else if (signRead)
+ {
+ // Read sign and apply if within sign tensor bounds.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
+ {
+ int s = p.s[si];
+ s >>= signXo;
+ if (s & 1) v *= p.slope;
+ if (s & 2) v = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v < 0.f) v *= p.slope;
+ v = InternalType::clamp(v, p.clamp);
+ }
+
+ if (!downInline) // Write into temporary buffer.
+ s_tileUpXY[idx] = v;
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
+ *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+
+ // Downsampling.
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
+ {
+ // Horizontal downsampling.
+ __syncthreads();
+ if (down == 4 && tileOutW % 4 == 0)
+ {
+ // Calculate 4 pixels at a time.
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ s_tileDownX[idx+2] = v.z;
+ s_tileDownX[idx+3] = v.w;
+ }
+ }
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
+ {
+ // Calculate 2 pixels at a time.
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ }
+ }
+ else
+ {
+ // Calculate 1 pixel at a time.
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src = relUpY * tileUpW + relUpX0;
+ scalar_t v = 0.f;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
+ s_tileDownX[idx] = v;
+ }
+ }
+
+ // Vertical downsampling & store output tile.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX, relOutY0;
+ fast_div_mod(relOutX, relOutY0, idx);
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileOutW + relOutX;
+ scalar_t v = 0;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
+
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY0;
+
+ if (outX < p.yShape.x & outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
+ {
+ // Full downsampling filter.
+ if (down == 2)
+ {
+ // 2-wide.
+ __syncthreads();
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ int relUpX0 = relOutX0 * down;
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ #pragma unroll
+ for (int sy = 0; sy < fdSize; sy++)
+ #pragma unroll
+ for (int sx = 0; sx < fdSize; sx++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ }
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outY < p.yShape.y)
+ {
+ index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut;
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y;
+ }
+ }
+ }
+ else if (down == 1 && !downInline)
+ {
+ // Thread per pixel.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ }
+
+ if (!enableXrep)
+ break;
+ }
+}
+
+//------------------------------------------------------------------------
+// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
+// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
+
+template
+static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Indexing.
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
+
+ // Loop to accommodate oversized tensors.
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
+ {
+ // Extract z and w (channel, minibatch index).
+ int32_t w = q / p.xShape.z;
+ int32_t z = q - w * p.xShape.z;
+
+ // Choose behavior based on sign read/write mode.
+ if (signWrite)
+ {
+ // Process value if in p.x.
+ uint32_t s = 0;
+ if (x < p.xShape.x && y < p.xShape.y)
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+
+ // Gain, LReLU, clamp.
+ v *= p.gain;
+ if (v < 0.f)
+ {
+ v *= p.slope;
+ s = 1; // Sign.
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ v = InternalType::clamp(v, p.clamp);
+ s = 2; // Clamp.
+ }
+
+ *pv = (T)v; // Write value.
+ }
+
+ // Coalesce into threads 0 and 16 of warp.
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
+ s |= __shfl_xor_sync(m, s, 2);
+ s |= __shfl_xor_sync(m, s, 4);
+ s |= __shfl_xor_sync(m, s, 8);
+
+ // Write signs if leader and in p.s.
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
+ {
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
+ ((uint32_t*)p.s)[is >> 4] = s;
+ }
+ }
+ else if (signRead)
+ {
+ // Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+
+ // Apply sign buffer offset.
+ uint32_t sx = x + p.sOfs.x;
+ uint32_t sy = y + p.sOfs.y;
+
+ // Read and apply signs if we land inside valid region of sign buffer.
+ if (sx < p.sShape.x && sy < p.sShape.y)
+ {
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
+ unsigned char s = p.s[is];
+ s >>= (sx & 3) << 1; // Shift into place.
+ if (s & 1) // Sign?
+ v *= p.slope;
+ if (s & 2) // Clamp?
+ v = 0.f;
+ }
+
+ *pv = (T)v; // Write value.
+ }
+ }
+ else
+ {
+ // Forward pass with no sign write. Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+ if (v < 0.f)
+ v *= p.slope;
+ if (fabsf(v) > p.clamp)
+ v = InternalType::clamp(v, p.clamp);
+ *pv = (T)v; // Write value.
+ }
+ }
+ }
+}
+
+template void* choose_filtered_lrelu_act_kernel(void)
+{
+ return (void*)filtered_lrelu_act_kernel;
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
+{
+ filtered_lrelu_kernel_spec s = { 0 };
+
+ // Return the first matching kernel.
+#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
+ if (sharedKB >= SH) \
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
+ { \
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
+ s.setup = (void*)setup_filters_kernel; \
+ s.exec = (void*)filtered_lrelu_kernel; \
+ s.tileOut = make_int2(TW, TH); \
+ s.numWarps = W; \
+ s.xrep = XR; \
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
+ return s; \
+ }
+
+ // Launch parameters for various kernel specializations.
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
+
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
+
+ #undef CASE
+ return s; // No kernel found.
+}
+
+//------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/filtered_lrelu.h b/utils/torch_utils/ops/filtered_lrelu.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2bfd1dd537909de9cd3b14765a482056391683b
--- /dev/null
+++ b/utils/torch_utils/ops/filtered_lrelu.h
@@ -0,0 +1,94 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct filtered_lrelu_kernel_params
+{
+ // These parameters decide which kernel to use.
+ int up; // upsampling ratio (1, 2, 4)
+ int down; // downsampling ratio (1, 2, 4)
+ int2 fuShape; // [size, 1] | [size, size]
+ int2 fdShape; // [size, 1] | [size, size]
+
+ int _dummy; // Alignment.
+
+ // Rest of the parameters.
+ const void* x; // Input tensor.
+ void* y; // Output tensor.
+ const void* b; // Bias tensor.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+ const float* fu; // Upsampling filter.
+ const float* fd; // Downsampling filter.
+
+ int2 pad0; // Left/top padding.
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+ int flip; // Filter kernel flip for gradient computation.
+
+ int tilesXdim; // Original number of horizontal output tiles.
+ int tilesXrep; // Number of horizontal tiles per CTA.
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
+
+ int4 xShape; // [width, height, channel, batch]
+ int4 yShape; // [width, height, channel, batch]
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+ int swLimit; // Active width of sign tensor in bytes.
+
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
+ longlong4 yStride; //
+ int64_t bStride; //
+ longlong3 fuStride; //
+ longlong3 fdStride; //
+};
+
+struct filtered_lrelu_act_kernel_params
+{
+ void* x; // Input/output, modified in-place.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+
+ int4 xShape; // [width, height, channel, batch]
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct filtered_lrelu_kernel_spec
+{
+ void* setup; // Function for filter kernel setup.
+ void* exec; // Function for main operation.
+ int2 tileOut; // Width/height of launch tile.
+ int numWarps; // Number of warps per thread block, determines launch block size.
+ int xrep; // For processing multiple horizontal tiles per thread block.
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template void* choose_filtered_lrelu_act_kernel(void);
+template cudaError_t copy_filters(cudaStream_t stream);
+
+//------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/filtered_lrelu.py b/utils/torch_utils/ops/filtered_lrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6d22247caae82ad806a16e967ef50c96be5d78a
--- /dev/null
+++ b/utils/torch_utils/ops/filtered_lrelu.py
@@ -0,0 +1,377 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+import numpy as np
+import torch
+import warnings
+
+from .. import custom_ops
+from .. import misc
+from . import upfirdn2d
+from . import bias_act
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='filtered_lrelu_plugin',
+ sources=[
+ 'filtered_lrelu.cpp', 'filtered_lrelu_wr.cu',
+ 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'
+ ],
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor)
+ assert 1 <= f.ndim <= 2
+ return f.shape[-1], f.shape[0] # width, height
+
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
+ padding = [int(x) for x in padding]
+ if len(padding) == 2:
+ px, py = padding
+ padding = [px, px, py, py]
+ px0, px1, py0, py1 = padding
+ return px0, px1, py0, py1
+
+
+#----------------------------------------------------------------------------
+
+
+def filtered_lrelu(x,
+ fu=None,
+ fd=None,
+ b=None,
+ up=1,
+ down=1,
+ padding=0,
+ gain=np.sqrt(2),
+ slope=0.2,
+ clamp=None,
+ flip_filter=False,
+ impl='cuda'):
+ r"""Filtered leaky ReLU for a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Add channel-specific bias if provided (`b`).
+
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 5. Multiply each value by the provided gain factor (`gain`).
+
+ 6. Apply leaky ReLU activation function to each value.
+
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
+
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
+ it so that the footprint of all output pixels lies within the input image.
+
+ 9. Downsample the image by keeping every Nth pixel (`down`).
+
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float16/float64 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ fu: Float32 upsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ fd: Float32 downsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The length of vector must must match the channel dimension of `x`.
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor. (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
+ flip_filter: False = convolution, True = correlation (default: False).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _filtered_lrelu_cuda(up=up,
+ down=down,
+ padding=padding,
+ gain=gain,
+ slope=slope,
+ clamp=clamp,
+ flip_filter=flip_filter).apply(
+ x, fu, fd, b, None, 0, 0)
+ return _filtered_lrelu_ref(x,
+ fu=fu,
+ fd=fd,
+ b=b,
+ up=up,
+ down=down,
+ padding=padding,
+ gain=gain,
+ slope=slope,
+ clamp=clamp,
+ flip_filter=flip_filter)
+
+
+#----------------------------------------------------------------------------
+
+
+@misc.profiled_function
+def _filtered_lrelu_ref(x,
+ fu=None,
+ fd=None,
+ b=None,
+ up=1,
+ down=1,
+ padding=0,
+ gain=np.sqrt(2),
+ slope=0.2,
+ clamp=None,
+ flip_filter=False):
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
+ existing `upfirdn2n()` and `bias_act()` ops.
+ """
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ fu_w, fu_h = _get_filter_size(fu)
+ fd_w, fd_h = _get_filter_size(fd)
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
+ misc.assert_shape(b, [x.shape[1]])
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ assert slope == float(slope) and slope >= 0
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+
+ # Calculate output size.
+ batch_size, channels, in_h, in_w = x.shape
+ in_dtype = x.dtype
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) +
+ (down - 1)) // down
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) +
+ (down - 1)) // down
+
+ # Compute using existing ops.
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
+ x = upfirdn2d.upfirdn2d(x=x,
+ f=fu,
+ up=up,
+ padding=[px0, px1, py0, py1],
+ gain=up**2,
+ flip_filter=flip_filter) # Upsample.
+ x = bias_act.bias_act(x=x,
+ act='lrelu',
+ alpha=slope,
+ gain=gain,
+ clamp=clamp) # Bias, leaky ReLU, clamp.
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down,
+ flip_filter=flip_filter) # Downsample.
+
+ # Check output shape & dtype.
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
+ assert x.dtype == in_dtype
+ return x
+
+
+#----------------------------------------------------------------------------
+
+_filtered_lrelu_cuda_cache = dict()
+
+
+def _filtered_lrelu_cuda(up=1,
+ down=1,
+ padding=0,
+ gain=np.sqrt(2),
+ slope=0.2,
+ clamp=None,
+ flip_filter=False):
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
+ """
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ gain = float(gain)
+ assert slope == float(slope) and slope >= 0
+ slope = float(slope)
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+ clamp = float(clamp if clamp is not None else 'inf')
+
+ # Lookup from cache.
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
+ if key in _filtered_lrelu_cuda_cache:
+ return _filtered_lrelu_cuda_cache[key]
+
+ # Forward op.
+ class FilteredLReluCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
+ if fu is None:
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ if fd is None:
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert 1 <= fu.ndim <= 2
+ assert 1 <= fd.ndim <= 2
+
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
+ fu = fu.square()[None]
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
+ fd = fd.square()[None]
+
+ # Missing sign input tensor.
+ if si is None:
+ si = torch.empty([0])
+
+ # Missing bias tensor.
+ if b is None:
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
+
+ # Construct internal sign tensor only if gradients are needed.
+ write_signs = (si.numel() == 0) and (x.requires_grad
+ or b.requires_grad)
+
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
+ warnings.warn(
+ "low-performance memory layout detected in filtered_lrelu input",
+ RuntimeWarning)
+
+ # Call C++/Cuda plugin if datatype is supported.
+ if x.dtype in [torch.float16, torch.float32]:
+ if torch.cuda.current_stream(
+ x.device) != torch.cuda.default_stream(x.device):
+ warnings.warn(
+ "filtered_lrelu called with non-default cuda stream but concurrent execution is not supported",
+ RuntimeWarning)
+ y, so, return_code = _plugin.filtered_lrelu(
+ x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy,
+ gain, slope, clamp, flip_filter, write_signs)
+ else:
+ return_code = -1
+
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
+ # only the bit-packed sign tensor is retained for gradient computation.
+ if return_code < 0:
+ warnings.warn(
+ "filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback",
+ RuntimeWarning)
+
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
+ y = upfirdn2d.upfirdn2d(x=y,
+ f=fu,
+ up=up,
+ padding=[px0, px1, py0, py1],
+ gain=up**2,
+ flip_filter=flip_filter) # Upsample.
+ so = _plugin.filtered_lrelu_act_(
+ y, si, sx, sy, gain, slope, clamp, write_signs
+ ) # Activation function and sign handling. Modifies y in-place.
+ y = upfirdn2d.upfirdn2d(x=y,
+ f=fd,
+ down=down,
+ flip_filter=flip_filter) # Downsample.
+
+ # Prepare for gradient computation.
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
+ ctx.x_shape = x.shape
+ ctx.y_shape = y.shape
+ ctx.s_ofs = sx, sy
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ fu, fd, si = ctx.saved_tensors
+ _, _, xh, xw = ctx.x_shape
+ _, _, yh, yw = ctx.y_shape
+ sx, sy = ctx.s_ofs
+ dx = None # 0
+ dfu = None
+ assert not ctx.needs_input_grad[1]
+ dfd = None
+ assert not ctx.needs_input_grad[2]
+ db = None # 3
+ dsi = None
+ assert not ctx.needs_input_grad[4]
+ dsx = None
+ assert not ctx.needs_input_grad[5]
+ dsy = None
+ assert not ctx.needs_input_grad[6]
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
+ pp = [
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
+ xw * up - yw * down + px0 - (up - 1),
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
+ xh * up - yh * down + py0 - (up - 1),
+ ]
+ gg = gain * (up**2) / (down**2)
+ ff = (not flip_filter)
+ sx = sx - (fu.shape[-1] - 1) + px0
+ sy = sy - (fu.shape[0] - 1) + py0
+ dx = _filtered_lrelu_cuda(up=down,
+ down=up,
+ padding=pp,
+ gain=gg,
+ slope=slope,
+ clamp=None,
+ flip_filter=ff).apply(
+ dy, fd, fu, None, si, sx, sy)
+
+ if ctx.needs_input_grad[3]:
+ db = dx.sum([0, 2, 3])
+
+ return dx, dfu, dfd, db, dsi, dsx, dsy
+
+ # Add to cache.
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
+ return FilteredLReluCuda
+
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/filtered_lrelu_ns.cu b/utils/torch_utils/ops/filtered_lrelu_ns.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8a3eae46215c3babea2c54e3ae255b05f4d777af
--- /dev/null
+++ b/utils/torch_utils/ops/filtered_lrelu_ns.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for no signs mode (no gradients required).
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/utils/torch_utils/ops/filtered_lrelu_rd.cu b/utils/torch_utils/ops/filtered_lrelu_rd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..3cd43ec0648d3db05e5808299fc0ee318e5ceaa6
--- /dev/null
+++ b/utils/torch_utils/ops/filtered_lrelu_rd.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign read mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/utils/torch_utils/ops/filtered_lrelu_wr.cu b/utils/torch_utils/ops/filtered_lrelu_wr.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bc2fa06912eb703dd77ca64533208428bdf373ac
--- /dev/null
+++ b/utils/torch_utils/ops/filtered_lrelu_wr.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign write mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/utils/torch_utils/ops/fma.py b/utils/torch_utils/ops/fma.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5feae8693cc3cd36d57f77f0f5d16c4dc6b990d
--- /dev/null
+++ b/utils/torch_utils/ops/fma.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
+
+import torch
+
+#----------------------------------------------------------------------------
+
+
+def fma(a, b, c): # => a * b + c
+ return _FusedMultiplyAdd.apply(a, b, c)
+
+
+#----------------------------------------------------------------------------
+
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+ @staticmethod
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ out = torch.addcmul(c, a, b)
+ ctx.save_for_backward(a, b)
+ ctx.c_shape = c.shape
+ return out
+
+ @staticmethod
+ def backward(ctx, dout): # pylint: disable=arguments-differ
+ a, b = ctx.saved_tensors
+ c_shape = ctx.c_shape
+ da = None
+ db = None
+ dc = None
+
+ if ctx.needs_input_grad[0]:
+ da = _unbroadcast(dout * b, a.shape)
+
+ if ctx.needs_input_grad[1]:
+ db = _unbroadcast(dout * a, b.shape)
+
+ if ctx.needs_input_grad[2]:
+ dc = _unbroadcast(dout, c_shape)
+
+ return da, db, dc
+
+
+#----------------------------------------------------------------------------
+
+
+def _unbroadcast(x, shape):
+ extra_dims = x.ndim - len(shape)
+ assert extra_dims >= 0
+ dim = [
+ i for i in range(x.ndim)
+ if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)
+ ]
+ if len(dim):
+ x = x.sum(dim=dim, keepdim=True)
+ if extra_dims:
+ x = x.reshape(-1, *x.shape[extra_dims + 1:])
+ assert x.shape == shape
+ return x
+
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/grid_sample_gradfix.py b/utils/torch_utils/ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..14598b5ddcc76353ac4d2c0792caea5e03eb25d9
--- /dev/null
+++ b/utils/torch_utils/ops/grid_sample_gradfix.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+
+#----------------------------------------------------------------------------
+
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input,
+ grid=grid,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=False)
+
+
+#----------------------------------------------------------------------------
+
+
+def _should_use_custom_op():
+ return enabled
+
+
+#----------------------------------------------------------------------------
+
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input,
+ grid=grid,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(
+ grad_output, input, grid)
+ return grad_input, grad_grid
+
+
+#----------------------------------------------------------------------------
+
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+ grad2_grid = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(
+ grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+
+#----------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/upfirdn2d.cpp b/utils/torch_utils/ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c1769c3cbe4dd04f76f9ccef726680720e6f39c8
--- /dev/null
+++ b/utils/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,111 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
+
+ // Initialize CUDA kernel parameters.
+ upfirdn2d_kernel_params p;
+ p.x = x.data_ptr();
+ p.f = f.data_ptr();
+ p.y = y.data_ptr();
+ p.up = make_int2(upx, upy);
+ p.down = make_int2(downx, downy);
+ p.pad0 = make_int2(padx0, pady0);
+ p.flip = (flip) ? 1 : 0;
+ p.gain = gain;
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+ // Choose CUDA kernel.
+ upfirdn2d_kernel_spec spec;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ spec = choose_upfirdn2d_kernel(p);
+ });
+
+ // Set looping options.
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+ p.loopMinor = spec.loopMinor;
+ p.loopX = spec.loopX;
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+ // Compute grid size.
+ dim3 blockSize, gridSize;
+ if (spec.tileOutW < 0) // large
+ {
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+ p.launchMajor);
+ }
+ else // small
+ {
+ blockSize = dim3(256, 1, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+ p.launchMajor);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/utils/torch_utils/ops/upfirdn2d.cu b/utils/torch_utils/ops/upfirdn2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7d182d7b86a9058d0c007b13716d6e7f08207f42
--- /dev/null
+++ b/utils/torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,388 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Calculate thread index.
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorBase / p.launchMinor;
+ minorBase -= outY * p.launchMinor;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+ if (p.flip)
+ filterY = p.filterSize.y - 1 - filterY;
+
+ // Loop over major, minor, and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+ {
+ int nc = major * p.sizeMinor + minor;
+ int n = nc / p.inSize.z;
+ int c = nc - n * p.inSize.z;
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+ if (p.flip)
+ filterX = p.filterSize.x - 1 - filterX;
+
+ // Initialize pointers.
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+ // Inner loop.
+ scalar_t v = 0;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
+ xp += p.inStride.x;
+ fp += filterStepX;
+ }
+ xp += p.inStride.y - w * p.inStride.x;
+ fp += filterStepY - w * filterStepX;
+ }
+
+ // Store result.
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
+template
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+ __shared__ volatile scalar_t sf[filterH][filterW];
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+ // Calculate tile index.
+ int minorBase = blockIdx.x;
+ int tileOutY = minorBase / p.launchMinor;
+ minorBase -= tileOutY * p.launchMinor;
+ minorBase *= loopMinor;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Load filter (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+ {
+ int fy = tapIdx / filterW;
+ int fx = tapIdx - fy * filterW;
+ scalar_t v = 0;
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
+ {
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+ }
+ sf[fy][fx] = v;
+ }
+
+ // Loop over major and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ {
+ int baseNC = major * p.sizeMinor + minorBase;
+ int n = baseNC / p.inSize.z;
+ int baseC = baseNC - n * p.inSize.z;
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+ int tileInX = floor_div(tileMidX, upx);
+ int tileInY = floor_div(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last
+
+ // No up/downsampling.
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+
+ // 2x upsampling.
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ }
+ if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ }
+ if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+
+ // 2x downsampling.
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small