Cr4yfish committed
Commit c165cd8
1 Parent(s): 656bb68

copy files from SuLvXiangXin

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .gitignore +15 -0
  2. LICENSE +202 -0
  3. README.md +252 -3
  4. configs/360.gin +15 -0
  5. configs/360_glo.gin +15 -0
  6. configs/blender.gin +15 -0
  7. configs/blender_refnerf.gin +41 -0
  8. configs/llff_256.gin +19 -0
  9. configs/llff_512.gin +19 -0
  10. configs/llff_raw.gin +73 -0
  11. configs/multi360.gin +5 -0
  12. eval.py +307 -0
  13. extract.py +638 -0
  14. gridencoder/__init__.py +1 -0
  15. gridencoder/backend.py +40 -0
  16. gridencoder/grid.py +198 -0
  17. gridencoder/setup.py +50 -0
  18. gridencoder/src/bindings.cpp +9 -0
  19. gridencoder/src/gridencoder.cu +645 -0
  20. gridencoder/src/gridencoder.h +17 -0
  21. internal/camera_utils.py +673 -0
  22. internal/checkpoints.py +38 -0
  23. internal/configs.py +177 -0
  24. internal/coord.py +225 -0
  25. internal/datasets.py +1016 -0
  26. internal/geopoly.py +108 -0
  27. internal/image.py +126 -0
  28. internal/math.py +133 -0
  29. internal/models.py +740 -0
  30. internal/pycolmap/.gitignore +2 -0
  31. internal/pycolmap/LICENSE.txt +21 -0
  32. internal/pycolmap/README.md +4 -0
  33. internal/pycolmap/pycolmap/__init__.py +5 -0
  34. internal/pycolmap/pycolmap/camera.py +259 -0
  35. internal/pycolmap/pycolmap/database.py +340 -0
  36. internal/pycolmap/pycolmap/image.py +35 -0
  37. internal/pycolmap/pycolmap/rotation.py +324 -0
  38. internal/pycolmap/pycolmap/scene_manager.py +670 -0
  39. internal/pycolmap/tools/colmap_to_nvm.py +69 -0
  40. internal/pycolmap/tools/delete_images.py +36 -0
  41. internal/pycolmap/tools/impute_missing_cameras.py +180 -0
  42. internal/pycolmap/tools/save_cameras_as_ply.py +92 -0
  43. internal/pycolmap/tools/transform_model.py +48 -0
  44. internal/pycolmap/tools/write_camera_track_to_bundler.py +60 -0
  45. internal/pycolmap/tools/write_depthmap_to_ply.py +139 -0
  46. internal/raw_utils.py +360 -0
  47. internal/ref_utils.py +174 -0
  48. internal/render.py +242 -0
  49. internal/stepfun.py +403 -0
  50. internal/train_utils.py +263 -0
.gitignore ADDED
@@ -0,0 +1,15 @@
+ __pycache__/
+ interal/__pycache__/
+ tests/__pycache__/
+ .DS_Store
+ .vscode/
+ .idea/
+ __MACOSX/
+ exp/
+ data/
+ assets/
+ test.py
+ test2.py
+ *.mp4
+ *.ply
+ scripts/train_360_debug.sh
LICENSE ADDED
@@ -0,0 +1,202 @@
+
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright [yyyy] [name of copyright owner]
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
README.md CHANGED
@@ -1,3 +1,252 @@
- ---
- license: apache-2.0
- ---
+ # ZipNeRF
+
+ An unofficial PyTorch implementation of
+ "Zip-NeRF: Anti-Aliased Grid-Based Neural Radiance Fields"
+ [https://arxiv.org/abs/2304.06706](https://arxiv.org/abs/2304.06706).
+ This work is based on [multinerf](https://github.com/google-research/multinerf), so the RefNeRF, RawNeRF, and MipNeRF 360 features are also available.
+
+ ## News
+ - (6.22) Added TSDF-based mesh extraction; added [gradient scaling](https://gradient-scaling.github.io/) to suppress near-plane floaters (a sketch follows this list).
+ - (5.26) Implemented the latest version of ZipNeRF [https://arxiv.org/abs/2304.06706](https://arxiv.org/abs/2304.06706).
+ - (5.22) Added mesh extraction; added a logging and checkpointing system.
+
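The gradient-scaling trick linked above ("Floaters No More") damps the gradients coming from samples close to the camera, since those samples otherwise receive disproportionately strong updates and get carved into high-density floaters near the lens. A minimal sketch of the technique as published on that page (how it is wired into this repo's training loop may differ):

```python
import torch

class GradientScaler(torch.autograd.Function):
    """Identity in the forward pass; in the backward pass, scales the
    gradients of near-camera samples toward zero."""

    @staticmethod
    def forward(ctx, colors, sigmas, ray_dist):
        ctx.save_for_backward(ray_dist)
        return colors, sigmas, ray_dist

    @staticmethod
    def backward(ctx, grad_colors, grad_sigmas, grad_ray_dist):
        (ray_dist,) = ctx.saved_tensors
        # Quadratic ramp: full gradient beyond unit distance, damped below it.
        scaling = torch.square(ray_dist).clamp(max=1.0)
        return grad_colors * scaling[..., None], grad_sigmas * scaling, grad_ray_dist

# Usage inside a rendering step:
# colors, sigmas, ray_dist = GradientScaler.apply(colors, sigmas, ray_dist)
```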
+ ## Results
+ New results (5.27):
+
+ 360_v2:
+
+ https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/2b276e48-2dc4-4508-8441-e90ec963f7d9
+
+ 360_v2_glo (fewer floaters, but worse metrics):
+
+ https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/bddb5610-2a4f-4981-8e17-71326a24d291
+
+ Mesh results (5.27):
+
+ ![mesh](https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/35866fa7-fe6a-44fe-9590-05d594bdb8cd)
+
+ MipNeRF 360 (PSNR):
+
+ |           | bicycle | garden | stump | room  | counter | kitchen | bonsai |
+ |:---------:|:-------:|:------:|:-----:|:-----:|:-------:|:-------:|:------:|
+ | Paper     | 25.80   | 28.20  | 27.55 | 32.65 | 29.38   | 32.50   | 34.46  |
+ | This repo | 25.44   | 27.98  | 26.75 | 32.13 | 29.10   | 32.63   | 34.20  |
+
+ Blender (PSNR):
+
+ |           | chair | drums | ficus | hotdog | lego  | materials | mic   | ship  |
+ |:---------:|:-----:|:-----:|:-----:|:------:|:-----:|:---------:|:-----:|:-----:|
+ | Paper     | 34.84 | 25.84 | 33.90 | 37.14  | 34.84 | 31.66     | 35.15 | 31.38 |
+ | This repo | 35.26 | 25.51 | 32.66 | 36.56  | 35.04 | 29.43     | 34.93 | 31.38 |
+
+ For the MipNeRF 360 dataset, the model is trained with a downsample factor of 4 for outdoor scenes and 2 for indoor scenes (same as in the paper).
+ Training is about 1.5x slower than reported in the paper (1.5 hours on 8 A6000 GPUs).
+
+ The hash decay loss seems to have little effect(?), as many floaters can be found in the final results of both experiments (especially on Blender).
+
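For context, the hash decay loss mentioned above is Zip-NeRF's normalized weight decay on the hash-grid features, scaled by `Config.hash_decay_mults` in the configs. A minimal sketch of what such a term looks like (the per-level embedding list is a hypothetical stand-in for this repo's gridencoder parameters):

```python
import torch

def hash_decay_loss(grid_embeddings_per_level, mult):
    """Normalized weight decay over hash-grid levels: the sum over levels of
    the mean squared feature value, so coarse and fine levels are weighted
    equally regardless of their table sizes."""
    return mult * sum(torch.mean(e ** 2) for e in grid_embeddings_per_level)
```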
+ ## Install
+
+ ```
+ # Clone the repo.
+ git clone https://github.com/SuLvXiangXin/zipnerf-pytorch.git
+ cd zipnerf-pytorch
+
+ # Make a conda environment.
+ conda create --name zipnerf python=3.9
+ conda activate zipnerf
+
+ # Install requirements.
+ pip install -r requirements.txt
+
+ # Install other extensions.
+ pip install ./gridencoder
+
+ # Install nvdiffrast (optional, for textured meshes).
+ git clone https://github.com/NVlabs/nvdiffrast
+ pip install ./nvdiffrast
+
+ # Install torch-scatter built for your CUDA version;
+ # see https://github.com/rusty1s/pytorch_scatter for details.
+ CUDA=cu117
+ pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
+ ```
+
+ ## Dataset
+ [mipnerf360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip)
+
+ [refnerf](https://storage.googleapis.com/gresearch/refraw360/ref.zip)
+
+ [nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
+
+ [nerf_llff_data](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
+
+ ```
+ mkdir data
+ cd data
+
+ # e.g. mipnerf360 data
+ wget http://storage.googleapis.com/gresearch/refraw360/360_v2.zip
+ unzip 360_v2.zip
+ ```
+
+ ## Train
+ ```
+ # Configure your training (DDP? fp16? ...);
+ # see https://huggingface.co/docs/accelerate/index for details.
+ accelerate config
+
+ # Where your data is.
+ DATA_DIR=data/360_v2/bicycle
+ EXP_NAME=360_v2/bicycle
+
+ # The experiment will be conducted under the "exp/${EXP_NAME}" folder.
+ # "--gin_configs=configs/360.gin" can be seen as a default config,
+ # and you can add specific configs using --gin_bindings="...".
+ accelerate launch train.py \
+     --gin_configs=configs/360.gin \
+     --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+     --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+     --gin_bindings="Config.factor = 4"
+
+ # Or you can run without accelerate (without DDP):
+ CUDA_VISIBLE_DEVICES=0 python train.py \
+     --gin_configs=configs/360.gin \
+     --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+     --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+     --gin_bindings="Config.factor = 4"
+
+ # Alternatively, you can use an example training script:
+ bash scripts/train_360.sh
+
+ # Blender dataset:
+ bash scripts/train_blender.sh
+
+ # Metrics, rendered images, etc. can be viewed through tensorboard:
+ tensorboard --logdir "exp/${EXP_NAME}"
+ ```
+
+ ### Render
+ Rendering results can be found in the directory `exp/${EXP_NAME}/render`.
+ ```
+ accelerate launch render.py \
+     --gin_configs=configs/360.gin \
+     --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+     --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+     --gin_bindings="Config.render_path = True" \
+     --gin_bindings="Config.render_path_frames = 480" \
+     --gin_bindings="Config.render_video_fps = 60" \
+     --gin_bindings="Config.factor = 4"
+
+ # Alternatively, you can use an example rendering script:
+ bash scripts/render_360.sh
+ ```
+ ## Evaluate
+ Evaluation results can be found in the directory `exp/${EXP_NAME}/test_preds`.
+ ```
+ # Use the same exp_name as in training.
+ accelerate launch eval.py \
+     --gin_configs=configs/360.gin \
+     --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+     --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+     --gin_bindings="Config.factor = 4"
+
+ # Alternatively, you can use an example evaluation script:
+ bash scripts/eval_360.sh
+ ```
+
+ ## Extract mesh
+ Mesh results can be found in the directory `exp/${EXP_NAME}/mesh`.
+ ```
+ # More configuration options can be found in internal/configs.py.
+ accelerate launch extract.py \
+     --gin_configs=configs/360.gin \
+     --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+     --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+     --gin_bindings="Config.factor = 4"
+ # --gin_bindings="Config.mesh_radius = 1"  # (optional) smaller for more detail, e.g. 0.2 in the bicycle scene
+ # --gin_bindings="Config.isosurface_threshold = 20"  # (optional) empirical value
+ # --gin_bindings="Config.mesh_voxels = 134217728"  # (optional) number of voxels used to extract the mesh; 134217728 equals 512**3. Smaller values may avoid an OutOfMemoryError.
+ # --gin_bindings="Config.vertex_color = True"  # (optional) save the mesh with vertex colors instead of a texture atlas; the atlas is much slower to compute but captures more detail.
+ # --gin_bindings="Config.vertex_projection = True"  # (optional) use projection for the vertex colors
+
+ # Or extract the mesh using the TSDF method:
+ accelerate launch extract.py \
+     --gin_configs=configs/360.gin \
+     --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+     --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+     --gin_bindings="Config.factor = 4"
+
+ # Alternatively, you can use an example script:
+ bash scripts/extract_360.sh
+ ```
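Under the hood, the density-based path queries the trained NeRF's density on a voxel grid (restricted to a visibility mask and processed in crops in the real extract.py) and runs marching cubes at `isosurface_threshold`. A minimal sketch of that core step, assuming a hypothetical `query_density(points) -> (N,)` callable:

```python
import numpy as np
from skimage import measure
import trimesh

def density_grid_to_mesh(query_density, radius, resolution, threshold):
    # Sample densities on a cube [-radius, radius]^3.
    t = np.linspace(-radius, radius, resolution, dtype=np.float32)
    pts = np.stack(np.meshgrid(t, t, t, indexing='ij'), axis=-1).reshape(-1, 3)
    density = query_density(pts).reshape(resolution, resolution, resolution)
    # Marching cubes on the density field at the chosen isosurface.
    spacing = (2 * radius / (resolution - 1),) * 3
    verts, faces, _, _ = measure.marching_cubes(density, level=threshold, spacing=spacing)
    verts -= radius  # Recenter from grid coordinates to world coordinates.
    return trimesh.Trimesh(verts, faces)
```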
+
+ ## OutOfMemory
+ If you run out of memory, you can decrease the total batch size by
+ adding e.g. `--gin_bindings="Config.batch_size = 8192"`,
+ decrease the test chunk size by adding e.g. `--gin_bindings="Config.render_chunk_size = 8192"`,
+ or use more GPUs by re-running `accelerate config`.
+
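These settings help because rays are evaluated in fixed-size chunks, so peak activation memory scales with the chunk size rather than with the full image. A minimal sketch of the idea (the `render_fn` here is hypothetical; the repo's actual chunking lives in `models.render_image`):

```python
import torch

def render_in_chunks(render_fn, rays, chunk_size):
    """Evaluate rays in fixed-size chunks so peak memory is proportional
    to chunk_size instead of the total number of rays."""
    outs = []
    for i in range(0, rays.shape[0], chunk_size):
        outs.append(render_fn(rays[i:i + chunk_size]))
    return torch.cat(outs, dim=0)
```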
+
+ ## Preparing custom data
+ More details can be found at https://github.com/google-research/multinerf.
+ ```
+ DATA_DIR=my_dataset_dir
+ bash scripts/local_colmap_and_resize.sh ${DATA_DIR}
+ ```
+
+ ## TODO
+ - [x] Add multiscale training and testing
+
+ ## Citation
+ ```
+ @misc{barron2023zipnerf,
+   title={Zip-NeRF: Anti-Aliased Grid-Based Neural Radiance Fields},
+   author={Jonathan T. Barron and Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman},
+   year={2023},
+   eprint={2304.06706},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV}
+ }
+
+ @misc{multinerf2022,
+   title={{MultiNeRF}: {A} {Code} {Release} for {Mip-NeRF} 360, {Ref-NeRF}, and {RawNeRF}},
+   author={Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman and Ricardo Martin-Brualla and Jonathan T. Barron},
+   year={2022},
+   url={https://github.com/google-research/multinerf},
+ }
+
+ @misc{accelerate,
+   title={Accelerate: Training and inference at scale made simple, efficient and adaptable.},
+   author={Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar},
+   howpublished={\url{https://github.com/huggingface/accelerate}},
+   year={2022}
+ }
+
+ @misc{torch-ngp,
+   author={Jiaxiang Tang},
+   year={2022},
+   note={https://github.com/ashawkey/torch-ngp},
+   title={Torch-ngp: a PyTorch implementation of instant-ngp}
+ }
+ ```
+
+ ## Acknowledgements
+ This work is based on another of my repos, https://github.com/SuLvXiangXin/multinerf-pytorch,
+ which is essentially a PyTorch translation of [multinerf](https://github.com/google-research/multinerf).
+
+ - Thanks to [multinerf](https://github.com/google-research/multinerf) for the amazing multinerf (MipNeRF 360, RefNeRF, RawNeRF) implementation
+ - Thanks to [accelerate](https://github.com/huggingface/accelerate) for distributed training
+ - Thanks to [torch-ngp](https://github.com/ashawkey/torch-ngp) for the super-useful hash encoder
+ - Thanks to [Yurui Chen](https://github.com/519401113) for discussing the details of the paper
configs/360.gin ADDED
@@ -0,0 +1,15 @@
+ Config.exp_name = 'test'
+ Config.dataset_loader = 'llff'
+ Config.near = 0.2
+ Config.far = 1e6
+ Config.factor = 4
+
+ Model.raydist_fn = 'power_transformation'
+ Model.opaque_background = True
+
+ PropMLP.disable_density_normals = True
+ PropMLP.disable_rgb = True
+ PropMLP.grid_level_dim = 1
+
+ NerfMLP.disable_density_normals = True
+
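`Model.raydist_fn = 'power_transformation'` selects Zip-NeRF's power ray-distance warp, which maps unbounded metric distance to a bounded coordinate before sampling. Below is a sketch of the paper's power transformation as we read it (lambda = -1.5 is the value the paper advocates); the repo's `internal/coord.py` is the authoritative implementation:

```python
import torch

def power_transformation(x, lam=-1.5):
    """Power transformation P(x, lambda) from the Zip-NeRF paper: unit slope
    at the origin, and bounded as x -> infinity when lambda < 0."""
    a = abs(lam - 1.0)
    return (a / lam) * (torch.pow(x / a + 1.0, lam) - 1.0)
```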
configs/360_glo.gin ADDED
@@ -0,0 +1,15 @@
+ Config.dataset_loader = 'llff'
+ Config.near = 0.2
+ Config.far = 1e6
+ Config.factor = 4
+
+ Model.raydist_fn = 'power_transformation'
+ Model.num_glo_features = 128
+ Model.opaque_background = True
+
+ PropMLP.disable_density_normals = True
+ PropMLP.disable_rgb = True
+ PropMLP.grid_level_dim = 1
+
+
+ NerfMLP.disable_density_normals = True
configs/blender.gin ADDED
@@ -0,0 +1,15 @@
+ Config.exp_name = 'test'
+ Config.dataset_loader = 'blender'
+ Config.near = 2
+ Config.far = 6
+ Config.factor = 0
+ Config.hash_decay_mults = 10
+
+ Model.raydist_fn = None
+
+ PropMLP.disable_density_normals = True
+ PropMLP.disable_rgb = True
+ PropMLP.grid_level_dim = 1
+
+ NerfMLP.disable_density_normals = True
+
configs/blender_refnerf.gin ADDED
@@ -0,0 +1,41 @@
+ Config.dataset_loader = 'blender'
+ Config.batching = 'single_image'
+ Config.near = 2
+ Config.far = 6
+
+ Config.eval_render_interval = 5
+ Config.compute_normal_metrics = True
+ Config.data_loss_type = 'mse'
+ Config.distortion_loss_mult = 0.0
+ Config.orientation_loss_mult = 0.1
+ Config.orientation_loss_target = 'normals_pred'
+ Config.predicted_normal_loss_mult = 3e-4
+ Config.orientation_coarse_loss_mult = 0.01
+ Config.predicted_normal_coarse_loss_mult = 3e-5
+ Config.interlevel_loss_mult = 0.0
+ Config.data_coarse_loss_mult = 0.1
+ Config.adam_eps = 1e-8
+
+ Model.num_levels = 2
+ Model.single_mlp = True
+ Model.num_prop_samples = 128  # This needs to be set despite single_mlp = True.
+ Model.num_nerf_samples = 128
+ Model.anneal_slope = 0.
+ Model.dilation_multiplier = 0.
+ Model.dilation_bias = 0.
+ Model.single_jitter = False
+ Model.resample_padding = 0.01
+ Model.distinct_prop = False
+
+ NerfMLP.disable_density_normals = False
+ NerfMLP.enable_pred_normals = True
+ NerfMLP.use_directional_enc = True
+ NerfMLP.use_reflections = True
+ NerfMLP.deg_view = 5
+ NerfMLP.enable_pred_roughness = True
+ NerfMLP.use_diffuse_color = True
+ NerfMLP.use_specular_tint = True
+ NerfMLP.use_n_dot_v = True
+ NerfMLP.bottleneck_width = 128
+ NerfMLP.density_bias = 0.5
+ NerfMLP.max_deg_point = 16
configs/llff_256.gin ADDED
@@ -0,0 +1,19 @@
+ Config.dataset_loader = 'llff'
+ Config.near = 0.
+ Config.far = 1.
+ Config.factor = 4
+ Config.forward_facing = True
+ Config.adam_eps = 1e-8
+
+ Model.opaque_background = True
+ Model.num_levels = 2
+ Model.num_prop_samples = 128
+ Model.num_nerf_samples = 32
+
+ PropMLP.disable_density_normals = True
+ PropMLP.disable_rgb = True
+
+ NerfMLP.disable_density_normals = True
+
+ NerfMLP.max_deg_point = 16
+ PropMLP.max_deg_point = 16
configs/llff_512.gin ADDED
@@ -0,0 +1,19 @@
+ Config.dataset_loader = 'llff'
+ Config.near = 0.
+ Config.far = 1.
+ Config.factor = 4
+ Config.forward_facing = True
+ Config.adam_eps = 1e-8
+
+ Model.opaque_background = True
+ Model.num_levels = 2
+ Model.num_prop_samples = 128
+ Model.num_nerf_samples = 32
+
+ PropMLP.disable_density_normals = True
+ PropMLP.disable_rgb = True
+
+ NerfMLP.disable_density_normals = True
+
+ NerfMLP.max_deg_point = 16
+ PropMLP.max_deg_point = 16
configs/llff_raw.gin ADDED
@@ -0,0 +1,73 @@
+ # General LLFF settings
+
+ Config.dataset_loader = 'llff'
+ Config.near = 0.
+ Config.far = 1.
+ Config.factor = 4
+ Config.forward_facing = True
+
+ PropMLP.disable_density_normals = True  # Turn this off if using orientation loss.
+ PropMLP.disable_rgb = True
+
+ NerfMLP.disable_density_normals = True  # Turn this off if using orientation loss.
+
+ NerfMLP.max_deg_point = 16
+ PropMLP.max_deg_point = 16
+
+ Config.train_render_every = 5000
+
+
+ ########################## RawNeRF specific settings ##########################
+
+ Config.rawnerf_mode = True
+ Config.data_loss_type = 'rawnerf'
+ Config.apply_bayer_mask = True
+ Model.learned_exposure_scaling = True
+
+ Model.num_levels = 2
+ Model.num_prop_samples = 128  # Using extra samples for now because of noise instability.
+ Model.num_nerf_samples = 128
+ Model.opaque_background = True
+ Model.distinct_prop = False
+
+ # RGB activation we use for linear color outputs is exp(x - 5).
+ NerfMLP.rgb_padding = 0.
+ NerfMLP.rgb_activation = @math.safe_exp
+ NerfMLP.rgb_bias = -5.
+ PropMLP.rgb_padding = 0.
+ PropMLP.rgb_activation = @math.safe_exp
+ PropMLP.rgb_bias = -5.
+
+ ## Experimenting with the various regularizers and losses:
+ Config.interlevel_loss_mult = .0  # Turning off interlevel for now (default = 1.).
+ Config.distortion_loss_mult = .01  # Distortion loss helps with floaters (default = .01).
+ Config.orientation_loss_mult = 0.  # Orientation loss also not great (try .01).
+ Config.data_coarse_loss_mult = 0.1  # Setting this to match old MipNeRF.
+
+ ## Density noise used in original NeRF:
+ NerfMLP.density_noise = 1.
+ PropMLP.density_noise = 1.
+
+ ## Use a single MLP for all rounds of sampling:
+ Model.single_mlp = True
+
+ ## Some algorithmic settings to match the paper:
+ Model.anneal_slope = 0.
+ Model.dilation_multiplier = 0.
+ Model.dilation_bias = 0.
+ Model.single_jitter = False
+ NerfMLP.weight_init = 'glorot_uniform'
+ PropMLP.weight_init = 'glorot_uniform'
+
+ ## Training hyperparameters used in the paper:
+ Config.batch_size = 16384
+ Config.render_chunk_size = 16384
+ Config.lr_init = 1e-3
+ Config.lr_final = 1e-5
+ Config.max_steps = 500000
+ Config.checkpoint_every = 25000
+ Config.lr_delay_steps = 2500
+ Config.lr_delay_mult = 0.01
+ Config.grad_max_norm = 0.1
+ Config.grad_max_val = 0.1
+ Config.adam_eps = 1e-8
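The `rgb_activation = @math.safe_exp` and `rgb_bias = -5.` pair above means linear RGB is predicted as exp(x - 5), which keeps outputs positive and lets them span the high dynamic range of raw captures. A sketch of what such a numerically safe exponential typically looks like (the repo's `internal/math.py` is authoritative; the clamp value here is our assumption):

```python
import torch

def safe_exp(x):
    """exp(x) with the input clamped so float32 cannot overflow
    (exp(88) is just below the float32 maximum)."""
    return torch.exp(x.clamp(max=88.0))
```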
configs/multi360.gin ADDED
@@ -0,0 +1,5 @@
+ include 'configs/360.gin'
+ Config.multiscale = True
+ Config.multiscale_levels = 4
+
+
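`Config.multiscale = True` with `multiscale_levels = 4` trains and evaluates on a pyramid of progressively downsampled images, in the spirit of mip-NeRF's multiscale benchmark. A rough sketch of building such a pyramid (a hypothetical helper; the repo's `internal/datasets.py` is the authoritative implementation):

```python
import torch.nn.functional as F

def build_multiscale_pyramid(image_chw, num_levels=4):
    """Return [full, 1/2, 1/4, 1/8] resolution copies via 2x area averaging."""
    pyramid = [image_chw]
    for _ in range(num_levels - 1):
        pyramid.append(F.avg_pool2d(pyramid[-1][None], 2)[0])
    return pyramid
```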
eval.py ADDED
@@ -0,0 +1,307 @@
+ import logging
+ import os
+ import sys
+ import time
+ import accelerate
+ from absl import app
+ import gin
+ from internal import configs
+ from internal import datasets
+ from internal import image
+ from internal import models
+ from internal import raw_utils
+ from internal import ref_utils
+ from internal import train_utils
+ from internal import checkpoints
+ from internal import utils
+ from internal import vis
+ import numpy as np
+ import torch
+ import tensorboardX
+ from torch.utils._pytree import tree_map
+
+ configs.define_common_flags()
+
+
+ def summarize_results(folder, scene_names, num_buckets):
+     metric_names = ['psnrs', 'ssims', 'lpips']
+     num_iters = 1000000
+     precisions = [3, 4, 4, 4]
+
+     results = []
+     for scene_name in scene_names:
+         test_preds_folder = os.path.join(folder, scene_name, 'test_preds')
+         values = []
+         for metric_name in metric_names:
+             filename = os.path.join(folder, scene_name, 'test_preds', f'{metric_name}_{num_iters}.txt')
+             with utils.open_file(filename) as f:
+                 v = np.array([float(s) for s in f.readline().split(' ')])
+                 values.append(np.mean(np.reshape(v, [-1, num_buckets]), 0))
+         results.append(np.concatenate(values))
+     avg_results = np.mean(np.array(results), 0)
+
+     psnr, ssim, lpips = np.mean(np.reshape(avg_results, [-1, num_buckets]), 1)
+
+     mse = np.exp(-0.1 * np.log(10.) * psnr)
+     dssim = np.sqrt(1 - ssim)
+     avg_avg = np.exp(np.mean(np.log(np.array([mse, dssim, lpips]))))
+
+     s = []
+     for i, v in enumerate(np.reshape(avg_results, [-1, num_buckets])):
+         s.append(' '.join([f'{s:0.{precisions[i]}f}' for s in v]))
+     s.append(f'{avg_avg:0.{precisions[-1]}f}')
+     return ' | '.join(s)
+
+
+ def main(unused_argv):
+     config = configs.load_config()
+     config.exp_path = os.path.join('exp', config.exp_name)
+     config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints')
+     config.render_dir = os.path.join(config.exp_path, 'render')
+
+     accelerator = accelerate.Accelerator()
+
+     # setup logger
+     logging.basicConfig(
+         format="%(asctime)s: %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+         force=True,
+         handlers=[logging.StreamHandler(sys.stdout),
+                   logging.FileHandler(os.path.join(config.exp_path, 'log_eval.txt'))],
+         level=logging.INFO,
+     )
+     sys.excepthook = utils.handle_exception
+     logger = accelerate.logging.get_logger(__name__)
+     logger.info(config)
+     logger.info(accelerator.state, main_process_only=False)
+
+     config.world_size = accelerator.num_processes
+     config.global_rank = accelerator.process_index
+     accelerate.utils.set_seed(config.seed, device_specific=True)
+     model = models.Model(config=config)
+     model.eval()
+     model.to(accelerator.device)
+
+     dataset = datasets.load_dataset('test', config.data_dir, config)
+     dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)),
+                                              shuffle=False,
+                                              batch_size=1,
+                                              collate_fn=dataset.collate_fn,
+                                              )
+     tb_process_fn = lambda x: x.transpose(2, 0, 1) if len(x.shape) == 3 else x[None]
+     if config.rawnerf_mode:
+         postprocess_fn = dataset.metadata['postprocess_fn']
+     else:
+         postprocess_fn = lambda z: z
+
+     if config.eval_raw_affine_cc:
+         cc_fun = raw_utils.match_images_affine
+     else:
+         cc_fun = image.color_correct
+
+     model = accelerator.prepare(model)
+
+     metric_harness = image.MetricHarness()
+
+     last_step = 0
+     out_dir = os.path.join(config.exp_path,
+                            'path_renders' if config.render_path else 'test_preds')
+     path_fn = lambda x: os.path.join(out_dir, x)
+
+     if not config.eval_only_once:
+         summary_writer = tensorboardX.SummaryWriter(
+             os.path.join(config.exp_path, 'eval'))
+     while True:
+         step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger)
+         if step <= last_step:
+             logger.info(f'Checkpoint step {step} <= last step {last_step}, sleeping.')
+             time.sleep(10)
+             continue
+         logger.info(f'Evaluating checkpoint at step {step}.')
+         if config.eval_save_output and (not utils.isdir(out_dir)):
+             utils.makedirs(out_dir)
+
+         num_eval = min(dataset.size, config.eval_dataset_limit)
+         perm = np.random.permutation(num_eval)
+         showcase_indices = np.sort(perm[:config.num_showcase_images])
+         metrics = []
+         metrics_cc = []
+         showcases = []
+         render_times = []
+         for idx, batch in enumerate(dataloader):
+             batch = accelerate.utils.send_to_device(batch, accelerator.device)
+             eval_start_time = time.time()
+             if idx >= num_eval:
+                 logger.info(f'Skipping image {idx + 1}/{dataset.size}')
+                 continue
+             logger.info(f'Evaluating image {idx + 1}/{dataset.size}')
+             rendering = models.render_image(model, accelerator,
+                                             batch, False, 1, config)
+
+             if not accelerator.is_main_process:  # Only record via host 0.
+                 continue
+
+             render_times.append((time.time() - eval_start_time))
+             logger.info(f'Rendered in {render_times[-1]:0.3f}s')
+
+             cc_start_time = time.time()
+             rendering['rgb_cc'] = cc_fun(rendering['rgb'], batch['rgb'])
+
+             rendering = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, rendering)
+             batch = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, batch)
+
+             gt_rgb = batch['rgb']
+             logger.info(f'Color corrected in {(time.time() - cc_start_time):0.3f}s')
+
+             if not config.eval_only_once and idx in showcase_indices:
+                 showcase_idx = idx if config.deterministic_showcase else len(showcases)
+                 showcases.append((showcase_idx, rendering, batch))
+             if not config.render_path:
+                 rgb = postprocess_fn(rendering['rgb'])
+                 rgb_cc = postprocess_fn(rendering['rgb_cc'])
+                 rgb_gt = postprocess_fn(gt_rgb)
+
+                 if config.eval_quantize_metrics:
+                     # Ensures that the images written to disk reproduce the metrics.
+                     rgb = np.round(rgb * 255) / 255
+                     rgb_cc = np.round(rgb_cc * 255) / 255
+
+                 if config.eval_crop_borders > 0:
+                     crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c]
+                     rgb = crop_fn(rgb)
+                     rgb_cc = crop_fn(rgb_cc)
+                     rgb_gt = crop_fn(rgb_gt)
+
+                 metric = metric_harness(rgb, rgb_gt)
+                 metric_cc = metric_harness(rgb_cc, rgb_gt)
+
+                 if config.compute_disp_metrics:
+                     for tag in ['mean', 'median']:
+                         key = f'distance_{tag}'
+                         if key in rendering:
+                             disparity = 1 / (1 + rendering[key])
+                             metric[f'disparity_{tag}_mse'] = float(
+                                 ((disparity - batch['disps']) ** 2).mean())
+
+                 if config.compute_normal_metrics:
+                     weights = rendering['acc'] * batch['alphas']
+                     normalized_normals_gt = ref_utils.l2_normalize_np(batch['normals'])
+                     for key, val in rendering.items():
+                         if key.startswith('normals') and val is not None:
+                             normalized_normals = ref_utils.l2_normalize_np(val)
+                             metric[key + '_mae'] = ref_utils.compute_weighted_mae_np(
+                                 weights, normalized_normals, normalized_normals_gt)
+
+                 for m, v in metric.items():
+                     logger.info(f'{m:30s} = {v:.4f}')
+
+                 metrics.append(metric)
+                 metrics_cc.append(metric_cc)
+
+             if config.eval_save_output and (config.eval_render_interval > 0):
+                 if (idx % config.eval_render_interval) == 0:
+                     utils.save_img_u8(postprocess_fn(rendering['rgb']),
+                                       path_fn(f'color_{idx:03d}.png'))
+                     utils.save_img_u8(postprocess_fn(rendering['rgb_cc']),
+                                       path_fn(f'color_cc_{idx:03d}.png'))
+
+                     for key in ['distance_mean', 'distance_median']:
+                         if key in rendering:
+                             utils.save_img_f32(rendering[key],
+                                                path_fn(f'{key}_{idx:03d}.tiff'))
+
+                     for key in ['normals']:
+                         if key in rendering:
+                             utils.save_img_u8(rendering[key] / 2. + 0.5,
+                                               path_fn(f'{key}_{idx:03d}.png'))
+
+                     utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx:03d}.tiff'))
+
+         if (not config.eval_only_once) and accelerator.is_main_process:
+             summary_writer.add_scalar('eval_median_render_time', np.median(render_times),
+                                       step)
+             for name in metrics[0]:
+                 scores = [m[name] for m in metrics]
+                 summary_writer.add_scalar('eval_metrics/' + name, np.mean(scores), step)
+                 summary_writer.add_histogram('eval_metrics/' + 'perimage_' + name, scores,
+                                              step)
+             for name in metrics_cc[0]:
+                 scores = [m[name] for m in metrics_cc]
+                 summary_writer.add_scalar('eval_metrics_cc/' + name, np.mean(scores), step)
+                 summary_writer.add_histogram('eval_metrics_cc/' + 'perimage_' + name,
+                                              scores, step)
+
+             for i, r, b in showcases:
+                 if config.vis_decimate > 1:
+                     d = config.vis_decimate
+                     decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]
+                 else:
+                     decimate_fn = lambda x: x
+                 r = tree_map(decimate_fn, r)
+                 b = tree_map(decimate_fn, b)
+                 visualizations = vis.visualize_suite(r, b)
+                 for k, v in visualizations.items():
+                     if k == 'color':
+                         v = postprocess_fn(v)
+                     summary_writer.add_image(f'output_{k}_{i}', tb_process_fn(v), step)
+                 if not config.render_path:
+                     target = postprocess_fn(b['rgb'])
+                     summary_writer.add_image(f'true_color_{i}', tb_process_fn(target), step)
+                     pred = postprocess_fn(visualizations['color'])
+                     residual = np.clip(pred - target + 0.5, 0, 1)
+                     summary_writer.add_image(f'true_residual_{i}', tb_process_fn(residual), step)
+                 if config.compute_normal_metrics:
+                     summary_writer.add_image(f'true_normals_{i}', tb_process_fn(b['normals']) / 2. + 0.5,
+                                              step)
+
+         if (config.eval_save_output and (not config.render_path) and
+                 accelerator.is_main_process):
+             with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f:
+                 f.write(' '.join([str(r) for r in render_times]))
+             logger.info(f'metrics:')
+             results = {}
+             num_buckets = config.multiscale_levels if config.multiscale else 1
+             for name in metrics[0]:
+                 with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
+                     ms = [m[name] for m in metrics]
+                     f.write(' '.join([str(m) for m in ms]))
+                     results[name] = ' | '.join(
+                         list(map(str, np.mean(np.array(ms).reshape([-1, num_buckets]), 0).tolist())))
+             with utils.open_file(path_fn(f'metric_avg_{step}.txt'), 'w') as f:
+                 for name in metrics[0]:
+                     f.write(f'{name}: {results[name]}\n')
+                     logger.info(f'{name}: {results[name]}')
+             logger.info(f'metrics_cc:')
+             results_cc = {}
+             for name in metrics_cc[0]:
+                 with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'), 'w') as f:
+                     ms = [m[name] for m in metrics_cc]
+                     f.write(' '.join([str(m) for m in ms]))
+                     results_cc[name] = ' | '.join(
+                         list(map(str, np.mean(np.array(ms).reshape([-1, num_buckets]), 0).tolist())))
+             with utils.open_file(path_fn(f'metric_cc_avg_{step}.txt'), 'w') as f:
+                 for name in metrics[0]:
+                     f.write(f'{name}: {results_cc[name]}\n')
+                     logger.info(f'{name}: {results_cc[name]}')
+             if config.eval_save_ray_data:
+                 for i, r, b in showcases:
+                     rays = {k: v for k, v in r.items() if 'ray_' in k}
+                     np.set_printoptions(threshold=sys.maxsize)
+                     with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f:
+                         f.write(repr(rays))
+
+         if config.eval_only_once:
+             break
+         if config.early_exit_steps is not None:
+             num_steps = config.early_exit_steps
+         else:
+             num_steps = config.max_steps
+         if int(step) >= num_steps:
+             break
+         last_step = step
+     logger.info('Finish evaluation.')
+
+
+ if __name__ == '__main__':
+     with gin.config_scope('eval'):
+         app.run(main)
extract.py ADDED
@@ -0,0 +1,638 @@
+ import logging
+ import os
+ import sys
+
+ import cv2
+ import numpy as np
+ from absl import app
+ import gin
+ from internal import configs
+ from internal import datasets
+ from internal import models
+ from internal import utils
+ from internal import coord
+ from internal import checkpoints
+ import torch
+ import accelerate
+ from tqdm import tqdm
+ from torch.utils._pytree import tree_map
+ import torch.nn.functional as F
+ from skimage import measure
+ import trimesh
+ import pymeshlab as pml
+
+ configs.define_common_flags()
+
+
+ @torch.no_grad()
+ def evaluate_density(model, accelerator: accelerate.Accelerator,
+                      points, config: configs.Config, std_value=0.0):
+     """
+     Evaluate the model's density for a batch of points.
+
+     Args:
+         model: the NeRF model whose nerf_mlp predicts raw density.
+         points: a torch tensor of 3D points, size (N, 3).
+
+     Returns:
+         A torch tensor of size (N,) with the density evaluated at the given points.
+     """
+     z = []
+     for _, pnts in enumerate(tqdm(torch.split(points, config.render_chunk_size, dim=0),
+                                   desc="Evaluating density", leave=False,
+                                   disable=not accelerator.is_main_process)):
+         rays_remaining = pnts.shape[0] % accelerator.num_processes
+         if rays_remaining != 0:
+             padding = accelerator.num_processes - rays_remaining
+             pnts = torch.cat([pnts, torch.zeros_like(pnts[-padding:])], dim=0)
+         else:
+             padding = 0
+         rays_per_host = pnts.shape[0] // accelerator.num_processes
+         start, stop = accelerator.process_index * rays_per_host, \
+                       (accelerator.process_index + 1) * rays_per_host
+         chunk_means = pnts[start:stop]
+         chunk_stds = torch.full_like(chunk_means[..., 0], std_value)
+         raw_density = model.nerf_mlp.predict_density(chunk_means[:, None], chunk_stds[:, None], no_warp=True)[0]
+         density = F.softplus(raw_density + model.nerf_mlp.density_bias)
+         density = accelerator.gather(density)
+         if padding > 0:
+             density = density[: -padding]
+         z.append(density)
+     z = torch.cat(z, dim=0)
+     return z
+
+
+ @torch.no_grad()
+ def evaluate_color(model, accelerator: accelerate.Accelerator,
+                    points, config: configs.Config, std_value=0.0):
+     """
+     Evaluate the model's color for a batch of points.
+
+     Args:
+         model: the NeRF model whose nerf_mlp predicts RGB.
+         points: a torch tensor of 3D points, size (N, 3).
+
+     Returns:
+         A torch tensor of size (N, 3) with the RGB evaluated at the given points.
+     """
+     z = []
+     for _, pnts in enumerate(tqdm(torch.split(points, config.render_chunk_size, dim=0),
+                                   desc="Evaluating color",
+                                   disable=not accelerator.is_main_process)):
+         rays_remaining = pnts.shape[0] % accelerator.num_processes
+         if rays_remaining != 0:
+             padding = accelerator.num_processes - rays_remaining
+             pnts = torch.cat([pnts, torch.zeros_like(pnts[-padding:])], dim=0)
+         else:
+             padding = 0
+         rays_per_host = pnts.shape[0] // accelerator.num_processes
+         start, stop = accelerator.process_index * rays_per_host, \
+                       (accelerator.process_index + 1) * rays_per_host
+         chunk_means = pnts[start:stop]
+         chunk_stds = torch.full_like(chunk_means[..., 0], std_value)
+         chunk_viewdirs = torch.zeros_like(chunk_means)
+         ray_results = model.nerf_mlp(False, chunk_means[:, None, None], chunk_stds[:, None, None],
+                                      chunk_viewdirs)
+         rgb = ray_results['rgb'][:, 0]
+         rgb = accelerator.gather(rgb)
+         if padding > 0:
+             rgb = rgb[: -padding]
+         z.append(rgb)
+     z = torch.cat(z, dim=0)
+     return z
+
+
+ @torch.no_grad()
+ def evaluate_color_projection(model, accelerator: accelerate.Accelerator, vertices, faces, config: configs.Config):
+     normals = auto_normals(vertices, faces.long())
+     viewdirs = -normals
+     origins = vertices - 0.005 * viewdirs
+     vc = []
+     chunk = config.render_chunk_size
+     model.num_levels = 1
+     model.opaque_background = True
+     for i in tqdm(range(0, origins.shape[0], chunk),
+                   desc="Evaluating color projection",
+                   disable=not accelerator.is_main_process):
+         cur_chunk = min(chunk, origins.shape[0] - i)
+         rays_remaining = cur_chunk % accelerator.num_processes
+         rays_per_host = cur_chunk // accelerator.num_processes
+         if rays_remaining != 0:
+             padding = accelerator.num_processes - rays_remaining
+             rays_per_host += 1
+         else:
+             padding = 0
+         start = i + accelerator.process_index * rays_per_host
+         stop = start + rays_per_host
+
+         batch = {
+             'origins': origins[start:stop],
+             'directions': viewdirs[start:stop],
+             'viewdirs': viewdirs[start:stop],
+             'cam_dirs': viewdirs[start:stop],
+             'radii': torch.full_like(origins[start:stop, ..., :1], 0.000723),
+             'near': torch.full_like(origins[start:stop, ..., :1], 0),
+             'far': torch.full_like(origins[start:stop, ..., :1], 0.01),
+         }
+         batch = accelerator.pad_across_processes(batch)
+         with accelerator.autocast():
+             renderings, ray_history = model(
+                 False,
+                 batch,
+                 compute_extras=False,
+                 train_frac=1)
+         rgb = renderings[-1]['rgb']
+         acc = renderings[-1]['acc']
+
+         rgb /= acc.clamp_min(1e-5)[..., None]
+         rgb = rgb.clamp(0, 1)
+
+         rgb = accelerator.gather(rgb)
+         rgb[torch.isnan(rgb) | torch.isinf(rgb)] = 1
+         if padding > 0:
+             rgb = rgb[: -padding]
+         vc.append(rgb)
+     vc = torch.cat(vc, dim=0)
+     return vc
+
+
+ def auto_normals(verts, faces):
+     i0 = faces[:, 0]
+     i1 = faces[:, 1]
+     i2 = faces[:, 2]
+
+     v0 = verts[i0, :]
+     v1 = verts[i1, :]
+     v2 = verts[i2, :]
+
+     face_normals = torch.cross(v1 - v0, v2 - v0)
+
+     # Splat face normals to vertices
+     v_nrm = torch.zeros_like(verts)
+     v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
+     v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
+     v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
+
+     # Normalize, replace zero (degenerated) normals with some default value
+     v_nrm = torch.where((v_nrm ** 2).sum(dim=-1, keepdims=True) > 1e-20, v_nrm,
+                         torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=verts.device))
+     v_nrm = F.normalize(v_nrm, dim=-1)
+     return v_nrm
+
+
+ def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01, logger=None, main_process=True):
+     # verts: [N, 3]
+     # faces: [N, 3]
+     tbar = tqdm(total=9, desc='Clean mesh', leave=False, disable=not main_process)
+     _ori_vert_shape = verts.shape
+     _ori_face_shape = faces.shape
+
+     m = pml.Mesh(verts, faces)
+     ms = pml.MeshSet()
+     ms.add_mesh(m, 'mesh')  # will copy!
+
+     # filters
+     tbar.set_description('Remove unreferenced vertices')
+     ms.meshing_remove_unreferenced_vertices()  # verts not refed by any faces
+     tbar.update()
+
+     if v_pct > 0:
+         tbar.set_description('Merge close vertices')
+         ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct))  # 1/10000 of bounding box diagonal
+         tbar.update()
+
+     tbar.set_description('Remove duplicate faces')
+     ms.meshing_remove_duplicate_faces()  # faces defined by the same verts
+     tbar.update()
+
+     tbar.set_description('Remove null faces')
+     ms.meshing_remove_null_faces()  # faces with area == 0
+     tbar.update()
+
+     if min_d > 0:
+         tbar.set_description('Remove connected component by diameter')
+         ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))
+         tbar.update()
+
+     if min_f > 0:
+         tbar.set_description('Remove connected component by face number')
+         ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
+         tbar.update()
+
+     if repair:
+         # tbar.set_description('Remove t vertices')
+         # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
+         tbar.set_description('Repair non manifold edges')
+         ms.meshing_repair_non_manifold_edges(method=0)
+         tbar.update()
+         tbar.set_description('Repair non manifold vertices')
+         ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
+         tbar.update()
+     else:
+         tbar.update(2)
+     if remesh:
+         # tbar.set_description('Coord taubin smoothing')
+         # ms.apply_coord_taubin_smoothing()
+         tbar.set_description('Isotropic explicit remeshing')
+         ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size))
+         tbar.update()
+
+     # extract mesh
+     m = ms.current_mesh()
+     verts = m.vertex_matrix()
+     faces = m.face_matrix()
+
+     if logger is not None:
+         logger.info(f'Mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')
+
+     return verts, faces
+
+
+ def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True, logger=None):
+     # optimalplacement: default is True, but for flat meshes it must be turned off to prevent spike artifacts.
+
+     _ori_vert_shape = verts.shape
+     _ori_face_shape = faces.shape
+
+     if backend == 'pyfqmr':
+         import pyfqmr
+         solver = pyfqmr.Simplify()
+         solver.setMesh(verts, faces)
+         solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
+         verts, faces, normals = solver.getMesh()
+     else:
+
+         m = pml.Mesh(verts, faces)
+         ms = pml.MeshSet()
+         ms.add_mesh(m, 'mesh')  # will copy!
+
+         # filters
+         # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
+         ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)
+
+         if remesh:
+             # ms.apply_coord_taubin_smoothing()
+             ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))
+
+         # extract mesh
+         m = ms.current_mesh()
+         verts = m.vertex_matrix()
+         faces = m.face_matrix()
+
+     if logger is not None:
+         logger.info(f'Mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')
+
+     return verts, faces
+
+
+ def main(unused_argv):
+     config = configs.load_config()
+     config.compute_visibility = True
+
+     config.exp_path = os.path.join("exp", config.exp_name)
+     config.mesh_path = os.path.join("exp", config.exp_name, "mesh")
+     config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints')
+     os.makedirs(config.mesh_path, exist_ok=True)
+
+     # accelerator for DDP
+     accelerator = accelerate.Accelerator()
+     device = accelerator.device
+
+     # setup logger
+     logging.basicConfig(
+         format="%(asctime)s: %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+         force=True,
+         handlers=[logging.StreamHandler(sys.stdout),
+                   logging.FileHandler(os.path.join(config.exp_path, 'log_extract.txt'))],
+         level=logging.INFO,
+     )
+     sys.excepthook = utils.handle_exception
+     logger = accelerate.logging.get_logger(__name__)
+     logger.info(config)
+     logger.info(accelerator.state, main_process_only=False)
+
+     config.world_size = accelerator.num_processes
+     config.global_rank = accelerator.process_index
+     accelerate.utils.set_seed(config.seed, device_specific=True)
+
+     # setup model and optimizer
+     model = models.Model(config=config)
+     model = accelerator.prepare(model)
+     step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger)
+     model.eval()
+     module = accelerator.unwrap_model(model)
+
+     visibility_path = os.path.join(config.mesh_path, 'visibility_mask_{:.1f}.pt'.format(config.mesh_radius))
+     visibility_resolution = config.visibility_resolution
+     if not os.path.exists(visibility_path):
+         logger.info('Generate visibility mask...')
+         # load dataset
+         dataset = datasets.load_dataset('train', config.data_dir, config)
+         dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)),
+                                                  num_workers=4,
+                                                  shuffle=True,
+                                                  batch_size=1,
+                                                  collate_fn=dataset.collate_fn,
+                                                  persistent_workers=True,
+                                                  )
+
+         visibility_mask = torch.ones(
+             (1, 1, visibility_resolution, visibility_resolution, visibility_resolution), requires_grad=True
+         ).to(device)
+         visibility_mask.retain_grad()
+         tbar = tqdm(dataloader, desc='Generating visibility grid', disable=not accelerator.is_main_process)
+         for index, batch in enumerate(tbar):
+             batch = accelerate.utils.send_to_device(batch, accelerator.device)
+
+             rendering = models.render_image(model, accelerator,
+                                             batch, False, 1, config,
+                                             verbose=False, return_weights=True)
+
+             coords = rendering['coord'].reshape(-1, 3)
+             weights = rendering['weights'].reshape(-1)
+
+             valid_points = coords[weights > config.valid_weight_thresh]
+             valid_points /= config.mesh_radius
+             # update mask based on ray samples
+             with torch.enable_grad():
+                 out = torch.nn.functional.grid_sample(visibility_mask,
+                                                       valid_points[None, None, None],
+                                                       align_corners=True)
+                 out.sum().backward()
+             tbar.set_postfix({"visibility_mask": (visibility_mask.grad > 0.0001).float().mean().item()})
+             # if index == 10:
+             #     break
+         visibility_mask = (visibility_mask.grad > 0.0001).float()
+         if accelerator.is_main_process:
+             torch.save(visibility_mask.detach().cpu(), visibility_path)
+     else:
+         logger.info('Load visibility mask from {}'.format(visibility_path))
+         visibility_mask = torch.load(visibility_path, map_location=device)
+
+     space = config.mesh_radius * 2 / (config.visibility_resolution - 1)
+
+     logger.info("Extract mesh from visibility mask...")
+     visibility_mask_np = visibility_mask[0, 0].permute(2, 1, 0).detach().cpu().numpy()
+     verts, faces, normals, values = measure.marching_cubes(
+         volume=-visibility_mask_np,
+         level=-0.5,
+         spacing=(space, space, space))
+     verts -= config.mesh_radius
+     if config.extract_visibility:
+         meshexport = trimesh.Trimesh(verts, faces)
+         meshexport.export(os.path.join(config.mesh_path, "visibility_mask_{}.ply".format(config.mesh_radius)), "ply")
+         logger.info("Extract visibility mask done.")
+
+     # Initialize variables
+     crop_n = 512
+     grid_min = verts.min(axis=0)
+     grid_max = verts.max(axis=0)
+     space = ((grid_max - grid_min).prod() / config.mesh_voxels) ** (1 / 3)
394
+ world_size = ((grid_max - grid_min) / space).astype(np.int32)
395
+ Nx, Ny, Nz = np.maximum(1, world_size // crop_n)
396
+ crop_n_x, crop_n_y, crop_n_z = world_size // [Nx, Ny, Nz]
397
+ xs = np.linspace(grid_min[0], grid_max[0], Nx + 1)
398
+ ys = np.linspace(grid_min[1], grid_max[1], Ny + 1)
399
+ zs = np.linspace(grid_min[2], grid_max[2], Nz + 1)
400
+ # Initialize meshes list
401
+ meshes = []
402
+
403
+ # Iterate over the grid
404
+ for i in range(Nx):
405
+ for j in range(Ny):
406
+ for k in range(Nz):
407
+ logger.info(f"Process grid cell ({i + 1}/{Nx}, {j + 1}/{Ny}, {k + 1}/{Nz})...")
408
+ # Calculate grid cell boundaries
409
+ x_min, x_max = xs[i], xs[i + 1]
410
+ y_min, y_max = ys[j], ys[j + 1]
411
+ z_min, z_max = zs[k], zs[k + 1]
412
+
413
+ # Create point grid
414
+ x = np.linspace(x_min, x_max, crop_n_x)
415
+ y = np.linspace(y_min, y_max, crop_n_y)
416
+ z = np.linspace(z_min, z_max, crop_n_z)
417
+ xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
418
+ points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T,
419
+ dtype=torch.float,
420
+ device=device)
421
+ # Construct point pyramids
422
+ points_tmp = points.reshape(crop_n_x, crop_n_y, crop_n_z, 3)[None]
423
+ points_tmp /= config.mesh_radius
424
+ current_mask = torch.nn.functional.grid_sample(visibility_mask, points_tmp, align_corners=True)
425
+ current_mask = (current_mask > 0.0).cpu().numpy()[0, 0]
426
+
427
+ pts_density = evaluate_density(module, accelerator, points,
428
+ config, std_value=config.std_value)
429
+
430
+ # bound the vertices
431
+ points_world = coord.inv_contract(2 * points)
432
+ pts_density[points_world.norm(dim=-1) > config.mesh_max_radius] = 0.0
433
+
434
+ z = pts_density.detach().cpu().numpy()
435
+
436
+ # Skip if no surface found
437
+ valid_z = z.reshape(crop_n_x, crop_n_y, crop_n_z)[current_mask]
438
+ if valid_z.shape[0] <= 0 or (
439
+ np.min(valid_z) > config.isosurface_threshold or np.max(
440
+ valid_z) < config.isosurface_threshold
441
+ ):
442
+ continue
443
+
444
+ if not (np.min(z) > config.isosurface_threshold or np.max(z) < config.isosurface_threshold):
445
+ # Extract mesh
446
+ logger.info('Extract mesh...')
447
+ z = z.astype(np.float32)
448
+ verts, faces, _, _ = measure.marching_cubes(
449
+ volume=-z.reshape(crop_n_x, crop_n_y, crop_n_z),
450
+ level=-config.isosurface_threshold,
451
+ spacing=(
452
+ (x_max - x_min) / (crop_n_x - 1),
453
+ (y_max - y_min) / (crop_n_y - 1),
454
+ (z_max - z_min) / (crop_n_z - 1),
455
+ ),
456
+ mask=current_mask,
457
+ )
458
+ verts = verts + np.array([x_min, y_min, z_min])
459
+
460
+ meshcrop = trimesh.Trimesh(verts, faces)
461
+ logger.info('Extract vertices: {}, faces: {}'.format(meshcrop.vertices.shape[0],
462
+ meshcrop.faces.shape[0]))
463
+ meshes.append(meshcrop)
464
+ # Save mesh
465
+ logger.info('Concatenate mesh...')
466
+ combined_mesh = trimesh.util.concatenate(meshes)
467
+
468
+ # from https://github.com/ashawkey/stable-dreamfusion/blob/main/nerf/renderer.py
469
+ # clean
470
+ logger.info('Clean mesh...')
471
+ vertices = combined_mesh.vertices.astype(np.float32)
472
+ faces = combined_mesh.faces.astype(np.int32)
473
+
474
+ vertices, faces = clean_mesh(vertices, faces,
475
+ remesh=False, remesh_size=0.01,
476
+ logger=logger, main_process=accelerator.is_main_process)
477
+
478
+ v = torch.from_numpy(vertices).contiguous().float().to(device)
479
+ v = coord.inv_contract(2 * v)
480
+ vertices = v.detach().cpu().numpy()
481
+ f = torch.from_numpy(faces).contiguous().int().to(device)
482
+
483
+ # decimation
484
+ if config.decimate_target > 0 and faces.shape[0] > config.decimate_target:
485
+ logger.info('Decimate mesh...')
486
+ vertices, faces = decimate_mesh(vertices, faces, config.decimate_target, logger=logger)
+ # refresh the torch tensors so later color evaluation matches the decimated mesh
+ v = torch.from_numpy(vertices).contiguous().float().to(device)
+ f = torch.from_numpy(faces).contiguous().int().to(device)
487
+ # import ipdb; ipdb.set_trace()
488
+ if config.vertex_color:
489
+ # batched inference to avoid OOM
490
+ logger.info('Evaluate mesh vertex color...')
491
+ if config.vertex_projection:
492
+ rgbs = evaluate_color_projection(module, accelerator, v, f, config)
493
+ else:
494
+ rgbs = evaluate_color(module, accelerator, v,
495
+ config, std_value=config.std_value)
496
+ rgbs = (rgbs * 255).detach().cpu().numpy().astype(np.uint8)
497
+ if accelerator.is_main_process:
498
+ logger.info('Export mesh (vertex color)...')
499
+ mesh = trimesh.Trimesh(vertices, faces,
500
+ vertex_colors=rgbs,
501
+ process=False) # important, process=True leads to seg fault...
502
+ mesh.export(os.path.join(config.mesh_path, 'mesh_{}.ply'.format(config.mesh_radius)))
503
+ logger.info('Finish extracting mesh.')
504
+ return
505
+
506
+ def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
507
+ logger.info('Export mesh (atlas)...')
508
+ # v, f: torch Tensor
509
+ device = v.device
510
+ v_np = v.cpu().numpy() # [N, 3]
511
+ f_np = f.cpu().numpy() # [M, 3]
512
+
513
+ # unwrap uvs
514
+ import xatlas
515
+ import nvdiffrast.torch as dr
516
+ from sklearn.neighbors import NearestNeighbors
517
+ from scipy.ndimage import binary_dilation, binary_erosion
518
+
519
+ logger.info(f'Running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
520
+ atlas = xatlas.Atlas()
521
+ atlas.add_mesh(v_np, f_np)
522
+ chart_options = xatlas.ChartOptions()
523
+ chart_options.max_iterations = 4 # for faster unwrap...
524
+ atlas.generate(chart_options=chart_options)
525
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
526
+
527
+ # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]
528
+
529
+ vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
530
+ ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
531
+
532
+ # render uv maps
533
+ uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]
534
+ uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
535
+
536
+ if ssaa > 1:
537
+ h = int(h0 * ssaa)
538
+ w = int(w0 * ssaa)
539
+ else:
540
+ h, w = h0, w0
541
+
542
+ if h <= 2048 and w <= 2048:
543
+ glctx = dr.RasterizeCudaContext()
544
+ else:
545
+ glctx = dr.RasterizeGLContext()
546
+
547
+ rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
548
+ xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
549
+ mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
550
+
551
+ # masked query
552
+ xyzs = xyzs.view(-1, 3)
553
+ mask = (mask > 0).view(-1)
554
+
555
+ feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)
556
+
557
+ if mask.any():
558
+ xyzs = xyzs[mask] # [M, 3]
559
+
560
+ # batched inference to avoid OOM
561
+ all_feats = evaluate_color(module, accelerator, xyzs,
562
+ config, std_value=config.std_value)
563
+ feats[mask] = all_feats
564
+
565
+ feats = feats.view(h, w, -1)
566
+ mask = mask.view(h, w)
567
+
568
+ # quantize [0.0, 1.0] to [0, 255]
569
+ feats = feats.cpu().numpy()
570
+ feats = (feats * 255).astype(np.uint8)
571
+
572
+ ### NN search as an antialiasing ...
573
+ mask = mask.cpu().numpy()
574
+
575
+ inpaint_region = binary_dilation(mask, iterations=3)
576
+ inpaint_region[mask] = 0
577
+
578
+ search_region = mask.copy()
579
+ not_search_region = binary_erosion(search_region, iterations=2)
580
+ search_region[not_search_region] = 0
581
+
582
+ search_coords = np.stack(np.nonzero(search_region), axis=-1)
583
+ inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
584
+
585
+ knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
586
+ _, indices = knn.kneighbors(inpaint_coords)
587
+
588
+ feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
589
+
590
+ feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)
591
+
592
+ # do ssaa after the NN search, in numpy
593
+ if ssaa > 1:
594
+ feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)
595
+
596
+ cv2.imwrite(os.path.join(config.mesh_path, f'{name}albedo.png'), feats)
597
+
598
+ # save obj (vertices, texture coords, and faces as v/vt index pairs)
599
+ obj_file = os.path.join(config.mesh_path, f'{name}mesh.obj')
600
+ mtl_file = os.path.join(config.mesh_path, f'{name}mesh.mtl')
601
+
602
+ logger.info(f'writing obj mesh to {obj_file}')
603
+ with open(obj_file, "w") as fp:
604
+ fp.write(f'mtllib {name}mesh.mtl \n')
605
+
606
+ logger.info(f'writing vertices {v_np.shape}')
607
+ for v in v_np:
608
+ fp.write(f'v {v[0]} {v[1]} {v[2]} \n')
609
+
610
+ logger.info(f'writing vertices texture coords {vt_np.shape}')
611
+ for v in vt_np:
612
+ fp.write(f'vt {v[0]} {1 - v[1]} \n')
613
+
614
+ logger.info(f'writing faces {f_np.shape}')
615
+ fp.write(f'usemtl mat0 \n')
616
+ for i in range(len(f_np)):
617
+ fp.write(
618
+ f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")
619
+
620
+ with open(mtl_file, "w") as fp:
621
+ fp.write(f'newmtl mat0 \n')
622
+ fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
623
+ fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
624
+ fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
625
+ fp.write(f'Tr 1.000000 \n')
626
+ fp.write(f'illum 1 \n')
627
+ fp.write(f'Ns 0.000000 \n')
628
+ fp.write(f'map_Kd {name}albedo.png \n')
629
+
630
+ # could be extremely slow
631
+ _export(v, f)
632
+
633
+ logger.info('Finish extracting mesh.')
634
+
635
+
636
+ if __name__ == '__main__':
637
+ with gin.config_scope('bake'):
638
+ app.run(main)
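The visibility pass in `main` relies on a small autograd trick: `torch.nn.functional.grid_sample` is differentiable with respect to the sampled grid, so back-propagating from the sampled values marks exactly the voxels that any ray sample interpolated. A minimal self-contained sketch (resolution and threshold are illustrative):

```python
import torch

# Stand-in for the visibility grid; the real one is (1, 1, R, R, R)
# with R = config.visibility_resolution.
res = 64
mask = torch.ones(1, 1, res, res, res, requires_grad=True)

points = torch.rand(1024, 3) * 2 - 1  # contracted sample coords in [-1, 1]
out = torch.nn.functional.grid_sample(mask, points[None, None, None],
                                      align_corners=True)
out.sum().backward()  # gradient flows only into voxels touched by the samples

visibility = (mask.grad > 1e-4).float()  # 1 where at least one sample landed
print(visibility.mean().item())          # fraction of visited voxels
```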
gridencoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .grid import GridEncoder
gridencoder/backend.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ from torch.utils.cpp_extension import load
3
+
4
+ _src_path = os.path.dirname(os.path.abspath(__file__))
5
+
6
+ nvcc_flags = [
7
+ '-O3', '-std=c++14',
8
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9
+ ]
10
+
11
+ if os.name == "posix":
12
+ c_flags = ['-O3', '-std=c++14']
13
+ elif os.name == "nt":
14
+ c_flags = ['/O2', '/std:c++17']
15
+
16
+ # find cl.exe
17
+ def find_cl_path():
18
+ import glob
19
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21
+ if paths:
22
+ return paths[0]
23
+
24
+ # If cl.exe is not on path, try to find it.
25
+ if os.system("where cl.exe >nul 2>nul") != 0:
26
+ cl_path = find_cl_path()
27
+ if cl_path is None:
28
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29
+ os.environ["PATH"] += ";" + cl_path
30
+
31
+ _backend = load(name='_grid_encoder',
32
+ extra_cflags=c_flags,
33
+ extra_cuda_cflags=nvcc_flags,
34
+ sources=[os.path.join(_src_path, 'src', f) for f in [
35
+ 'gridencoder.cu',
36
+ 'bindings.cpp',
37
+ ]],
38
+ )
39
+
40
+ __all__ = ['_backend']
gridencoder/grid.py ADDED
@@ -0,0 +1,198 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.autograd import Function
6
+ from torch.autograd.function import once_differentiable
7
+ from torch.cuda.amp import custom_bwd, custom_fwd
8
+
9
+ try:
10
+ import _gridencoder as _backend
11
+ except ImportError:
12
+ from .backend import _backend
13
+
14
+ _gridtype_to_id = {
15
+ 'hash': 0,
16
+ 'tiled': 1,
17
+ }
18
+
19
+ _interp_to_id = {
20
+ 'linear': 0,
21
+ 'smoothstep': 1,
22
+ }
23
+
24
+ class _grid_encode(Function):
25
+ @staticmethod
26
+ @custom_fwd
27
+ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0):
28
+ # inputs: [B, D], float in [0, 1]
29
+ # embeddings: [sO, C], float
30
+ # offsets: [L + 1], int
31
+ # RETURN: [B, F], float
32
+
33
+ inputs = inputs.contiguous()
34
+
35
+ B, D = inputs.shape # batch size, coord dim
36
+ L = offsets.shape[0] - 1 # level
37
+ C = embeddings.shape[1] # embedding dim for each level
38
+ S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
39
+ H = base_resolution # base resolution
40
+
41
+ # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
42
+ # if C % 2 != 0, force float, since half for atomicAdd is very slow.
43
+ if torch.is_autocast_enabled() and C % 2 == 0:
44
+ embeddings = embeddings.to(torch.half)
45
+
46
+ # L first, optimize cache for cuda kernel, but needs an extra permute later
47
+ outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
48
+
49
+ if calc_grad_inputs:
50
+ dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
51
+ else:
52
+ dy_dx = None
53
+
54
+ _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation)
55
+
56
+ # permute back to [B, L * C]
57
+ outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
58
+
59
+ ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
60
+ ctx.dims = [B, D, C, L, S, H, gridtype, interpolation]
61
+ ctx.align_corners = align_corners
62
+
63
+ return outputs
64
+
65
+ @staticmethod
66
+ #@once_differentiable
67
+ @custom_bwd
68
+ def backward(ctx, grad):
69
+
70
+ inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
71
+ B, D, C, L, S, H, gridtype, interpolation = ctx.dims
72
+ align_corners = ctx.align_corners
73
+
74
+ # grad: [B, L * C] --> [L, B, C]
75
+ grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
76
+
77
+ grad_embeddings = torch.zeros_like(embeddings)
78
+
79
+ if dy_dx is not None:
80
+ grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
81
+ else:
82
+ grad_inputs = None
83
+
84
+ _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)
85
+
86
+ if dy_dx is not None:
87
+ grad_inputs = grad_inputs.to(inputs.dtype)
88
+
89
+ return grad_inputs, grad_embeddings, None, None, None, None, None, None, None
90
+
91
+
92
+
93
+ grid_encode = _grid_encode.apply
94
+
95
+
96
+ class GridEncoder(nn.Module):
97
+ def __init__(self, input_dim=3, num_levels=16, level_dim=2,
98
+ per_level_scale=2, base_resolution=16,
99
+ log2_hashmap_size=19, desired_resolution=None,
100
+ gridtype='hash', align_corners=False,
101
+ interpolation='linear', init_std=1e-4):
102
+ super().__init__()
103
+
104
+ # the finest resolution desired at the last level; if provided, overrides per_level_scale
105
+ if desired_resolution is not None:
106
+ per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
107
+
108
+ self.input_dim = input_dim # coord dims, 2 or 3
109
+ self.num_levels = num_levels # num levels, each level multiply resolution by 2
110
+ self.level_dim = level_dim # encode channels per level
111
+ self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
112
+ self.log2_hashmap_size = log2_hashmap_size
113
+ self.base_resolution = base_resolution
114
+ self.output_dim = num_levels * level_dim
115
+ self.gridtype = gridtype
116
+ self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
117
+ self.interpolation = interpolation
118
+ self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep"
119
+ self.align_corners = align_corners
120
+ self.init_std = init_std
121
+
122
+ # allocate parameters
123
+ resolutions = []
124
+ offsets = []
125
+ offset = 0
126
+ self.max_params = 2 ** log2_hashmap_size
127
+ for i in range(num_levels):
128
+ resolution = int(np.ceil(base_resolution * per_level_scale ** i))
129
+ resolution = (resolution if align_corners else resolution + 1)
130
+ params_in_level = min(self.max_params, resolution ** input_dim) # limit max number
131
+ params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
132
+ resolutions.append(resolution)
133
+ offsets.append(offset)
134
+ offset += params_in_level
135
+ offsets.append(offset)
136
+ offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
137
+ self.register_buffer('offsets', offsets)
138
+ idx = torch.empty(offset, dtype=torch.long)
139
+ for i in range(self.num_levels):
140
+ idx[offsets[i]:offsets[i+1]] = i
141
+ self.register_buffer('idx', idx)
142
+ self.register_buffer('grid_sizes', torch.from_numpy(np.array(resolutions, dtype=np.int32)))
143
+
144
+ self.n_params = offsets[-1] * level_dim
145
+
146
+ # parameters
147
+ self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
148
+
149
+ self.reset_parameters()
150
+
151
+ def reset_parameters(self):
152
+ std = self.init_std
153
+ self.embeddings.data.uniform_(-std, std)
154
+
155
+ def __repr__(self):
156
+ return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"
157
+
158
+ def forward(self, inputs, bound=1):
159
+ # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
160
+ # return: [..., num_levels * level_dim]
161
+
162
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
163
+ # inputs = inputs.clamp(0, 1)
164
+ #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
165
+
166
+ prefix_shape = list(inputs.shape[:-1])
167
+ inputs = inputs.view(-1, self.input_dim)
168
+
169
+ outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id)
170
+ outputs = outputs.view(prefix_shape + [self.output_dim])
171
+
172
+ #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
173
+
174
+ return outputs
175
+
176
+ # always run in float precision!
177
+ @torch.cuda.amp.autocast(enabled=False)
178
+ def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
179
+ # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.
180
+
181
+ D = self.input_dim
182
+ C = self.embeddings.shape[1] # embedding dim for each level
183
+ L = self.offsets.shape[0] - 1 # level
184
+ S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
185
+ H = self.base_resolution # base resolution
186
+
187
+ if inputs is None:
188
+ # randomized in [0, 1]
189
+ inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
190
+ else:
191
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
192
+ inputs = inputs.view(-1, self.input_dim)
193
+ B = inputs.shape[0]
194
+
195
+ if self.embeddings.grad is None:
196
+ raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
197
+
198
+ _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)
gridencoder/setup.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4
+
5
+ _src_path = os.path.dirname(os.path.abspath(__file__))
6
+
7
+ nvcc_flags = [
8
+ '-O3', '-std=c++14',
9
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10
+ ]
11
+
12
+ if os.name == "posix":
13
+ c_flags = ['-O3', '-std=c++14']
14
+ elif os.name == "nt":
15
+ c_flags = ['/O2', '/std:c++17']
16
+
17
+ # find cl.exe
18
+ def find_cl_path():
19
+ import glob
20
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22
+ if paths:
23
+ return paths[0]
24
+
25
+ # If cl.exe is not on path, try to find it.
26
+ if os.system("where cl.exe >nul 2>nul") != 0:
27
+ cl_path = find_cl_path()
28
+ if cl_path is None:
29
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30
+ os.environ["PATH"] += ";" + cl_path
31
+
32
+ setup(
33
+ name='gridencoder', # package name, import this to use python API
34
+ ext_modules=[
35
+ CUDAExtension(
36
+ name='_gridencoder', # extension name, import this to use CUDA API
37
+ sources=[os.path.join(_src_path, 'src', f) for f in [
38
+ 'gridencoder.cu',
39
+ 'bindings.cpp',
40
+ ]],
41
+ extra_compile_args={
42
+ 'cxx': c_flags,
43
+ 'nvcc': nvcc_flags,
44
+ }
45
+ ),
46
+ ],
47
+ cmdclass={
48
+ 'build_ext': BuildExtension,
49
+ }
50
+ )
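Note: with a CUDA toolchain available, running `pip install .` inside `gridencoder/` prebuilds the `_gridencoder` extension, so the `import _gridencoder` path in `grid.py` succeeds and the slower JIT route in `backend.py` is skipped.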
gridencoder/src/bindings.cpp ADDED
@@ -0,0 +1,9 @@
1
+ #include <torch/extension.h>
2
+
3
+ #include "gridencoder.h"
4
+
5
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6
+ m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
7
+ m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
8
+ m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
9
+ }
gridencoder/src/gridencoder.cu ADDED
@@ -0,0 +1,645 @@
1
+ #include <cuda.h>
2
+ #include <cuda_fp16.h>
3
+ #include <cuda_runtime.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <torch/torch.h>
7
+
8
+ #include <algorithm>
9
+ #include <stdexcept>
10
+
11
+ #include <stdint.h>
12
+ #include <cstdio>
13
+
14
+
15
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
16
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
17
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
18
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
19
+
20
+
21
+ // just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
22
+ __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
23
+ // requires CUDA >= 10 and ARCH >= 70
24
+ // this is very slow compared to float or __half2, never use it.
25
+ //return atomicAdd(reinterpret_cast<__half*>(address), val);
26
+ }
27
+
28
+
29
+ template <typename T>
30
+ __host__ __device__ inline T div_round_up(T val, T divisor) {
31
+ return (val + divisor - 1) / divisor;
32
+ }
33
+
34
+ template <typename T, typename T2>
35
+ __host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) {
36
+ return min(max(v, lo), hi);
37
+ }
38
+
39
+ template <typename T>
40
+ __device__ inline T smoothstep(T val) {
41
+ return val*val*(3.0f - 2.0f * val);
42
+ }
43
+
44
+ template <typename T>
45
+ __device__ inline T smoothstep_derivative(T val) {
46
+ return 6*val*(1.0f - val);
47
+ }
48
+
49
+
50
+ template <uint32_t D>
51
+ __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
52
+
53
+ // coherent type of hashing
54
+ constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };
55
+
56
+ uint32_t result = 0;
57
+ #pragma unroll
58
+ for (uint32_t i = 0; i < D; ++i) {
59
+ result ^= pos_grid[i] * primes[i];
60
+ }
61
+
62
+ return result;
63
+ }
64
+
65
+
66
+ template <uint32_t D, uint32_t C>
67
+ __device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
68
+ uint32_t stride = 1;
69
+ uint32_t index = 0;
70
+
71
+ #pragma unroll
72
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
73
+ index += pos_grid[d] * stride;
74
+ stride *= align_corners ? resolution: (resolution + 1);
75
+ }
76
+
77
+ // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
78
+ // gridtype: 0 == hash, 1 == tiled
79
+ if (gridtype == 0 && stride > hashmap_size) {
80
+ index = fast_hash<D>(pos_grid);
81
+ }
82
+
83
+ return (index % hashmap_size) * C + ch;
84
+ }
85
+
86
+
87
+ template <typename scalar_t, uint32_t D, uint32_t C>
88
+ __global__ void kernel_grid(
89
+ const float * __restrict__ inputs,
90
+ const scalar_t * __restrict__ grid,
91
+ const int * __restrict__ offsets,
92
+ scalar_t * __restrict__ outputs,
93
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
94
+ scalar_t * __restrict__ dy_dx,
95
+ const uint32_t gridtype,
96
+ const bool align_corners,
97
+ const uint32_t interp
98
+ ) {
99
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
100
+
101
+ if (b >= B) return;
102
+
103
+ const uint32_t level = blockIdx.y;
104
+
105
+ // locate
106
+ grid += (uint32_t)offsets[level] * C;
107
+ inputs += b * D;
108
+ outputs += level * B * C + b * C;
109
+
110
+ // check input range (should be in [0, 1])
111
+ bool flag_oob = false;
112
+ #pragma unroll
113
+ for (uint32_t d = 0; d < D; d++) {
114
+ if (inputs[d] < 0 || inputs[d] > 1) {
115
+ flag_oob = true;
116
+ }
117
+ }
118
+ // if input out of bound, just set output to 0
119
+ if (flag_oob) {
120
+ #pragma unroll
121
+ for (uint32_t ch = 0; ch < C; ch++) {
122
+ outputs[ch] = 0;
123
+ }
124
+ if (dy_dx) {
125
+ dy_dx += b * D * L * C + level * D * C; // B L D C
126
+ #pragma unroll
127
+ for (uint32_t d = 0; d < D; d++) {
128
+ #pragma unroll
129
+ for (uint32_t ch = 0; ch < C; ch++) {
130
+ dy_dx[d * C + ch] = 0;
131
+ }
132
+ }
133
+ }
134
+ return;
135
+ }
136
+
137
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
138
+ const float scale = exp2f(level * S) * H - 1.0f;
139
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
140
+
141
+ // calculate coordinate (always use float for precision!)
142
+ float pos[D];
143
+ float pos_deriv[D];
144
+ uint32_t pos_grid[D];
145
+
146
+ #pragma unroll
147
+ for (uint32_t d = 0; d < D; d++) {
148
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
149
+ pos_grid[d] = floorf(pos[d]);
150
+ pos[d] -= (float)pos_grid[d];
151
+ // smoothstep instead of linear
152
+ if (interp == 1) {
153
+ pos_deriv[d] = smoothstep_derivative(pos[d]);
154
+ pos[d] = smoothstep(pos[d]);
155
+ } else {
156
+ pos_deriv[d] = 1.0f; // linear deriv is default to 1
157
+ }
158
+
159
+ }
160
+
161
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
162
+
163
+ // interpolate
164
+ scalar_t results[C] = {0}; // temp results in register
165
+
166
+ #pragma unroll
167
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
168
+ float w = 1;
169
+ uint32_t pos_grid_local[D];
170
+
171
+ #pragma unroll
172
+ for (uint32_t d = 0; d < D; d++) {
173
+ if ((idx & (1 << d)) == 0) {
174
+ w *= 1 - pos[d];
175
+ pos_grid_local[d] = pos_grid[d];
176
+ } else {
177
+ w *= pos[d];
178
+ pos_grid_local[d] = pos_grid[d] + 1;
179
+ }
180
+ }
181
+
182
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
183
+
184
+ // writing to register (fast)
185
+ #pragma unroll
186
+ for (uint32_t ch = 0; ch < C; ch++) {
187
+ results[ch] += w * grid[index + ch];
188
+ }
189
+
190
+ //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
191
+ }
192
+
193
+ // writing to global memory (slow)
194
+ #pragma unroll
195
+ for (uint32_t ch = 0; ch < C; ch++) {
196
+ outputs[ch] = results[ch];
197
+ }
198
+
199
+ // prepare dy_dx
200
+ // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
201
+ if (dy_dx) {
202
+
203
+ dy_dx += b * D * L * C + level * D * C; // B L D C
204
+
205
+ #pragma unroll
206
+ for (uint32_t gd = 0; gd < D; gd++) {
207
+
208
+ scalar_t results_grad[C] = {0};
209
+
210
+ #pragma unroll
211
+ for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
212
+ float w = scale;
213
+ uint32_t pos_grid_local[D];
214
+
215
+ #pragma unroll
216
+ for (uint32_t nd = 0; nd < D - 1; nd++) {
217
+ const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
218
+
219
+ if ((idx & (1 << nd)) == 0) {
220
+ w *= 1 - pos[d];
221
+ pos_grid_local[d] = pos_grid[d];
222
+ } else {
223
+ w *= pos[d];
224
+ pos_grid_local[d] = pos_grid[d] + 1;
225
+ }
226
+ }
227
+
228
+ pos_grid_local[gd] = pos_grid[gd];
229
+ uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
230
+ pos_grid_local[gd] = pos_grid[gd] + 1;
231
+ uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
232
+
233
+ #pragma unroll
234
+ for (uint32_t ch = 0; ch < C; ch++) {
235
+ results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
236
+ }
237
+ }
238
+
239
+ #pragma unroll
240
+ for (uint32_t ch = 0; ch < C; ch++) {
241
+ dy_dx[gd * C + ch] = results_grad[ch];
242
+ }
243
+ }
244
+ }
245
+ }
246
+
247
+
248
+ template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
249
+ __global__ void kernel_grid_backward(
250
+ const scalar_t * __restrict__ grad,
251
+ const float * __restrict__ inputs,
252
+ const scalar_t * __restrict__ grid,
253
+ const int * __restrict__ offsets,
254
+ scalar_t * __restrict__ grad_grid,
255
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
256
+ const uint32_t gridtype,
257
+ const bool align_corners,
258
+ const uint32_t interp
259
+ ) {
260
+ const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
261
+ if (b >= B) return;
262
+
263
+ const uint32_t level = blockIdx.y;
264
+ const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
265
+
266
+ // locate
267
+ grad_grid += offsets[level] * C;
268
+ inputs += b * D;
269
+ grad += level * B * C + b * C + ch; // L, B, C
270
+
271
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
272
+ const float scale = exp2f(level * S) * H - 1.0f;
273
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
274
+
275
+ // check input range (should be in [0, 1])
276
+ #pragma unroll
277
+ for (uint32_t d = 0; d < D; d++) {
278
+ if (inputs[d] < 0 || inputs[d] > 1) {
279
+ return; // grad is init as 0, so we simply return.
280
+ }
281
+ }
282
+
283
+ // calculate coordinate
284
+ float pos[D];
285
+ uint32_t pos_grid[D];
286
+
287
+ #pragma unroll
288
+ for (uint32_t d = 0; d < D; d++) {
289
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
290
+ pos_grid[d] = floorf(pos[d]);
291
+ pos[d] -= (float)pos_grid[d];
292
+ // smoothstep instead of linear
293
+ if (interp == 1) {
294
+ pos[d] = smoothstep(pos[d]);
295
+ }
296
+ }
297
+
298
+ scalar_t grad_cur[N_C] = {0}; // fetch to register
299
+ #pragma unroll
300
+ for (uint32_t c = 0; c < N_C; c++) {
301
+ grad_cur[c] = grad[c];
302
+ }
303
+
304
+ // interpolate
305
+ #pragma unroll
306
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
307
+ float w = 1;
308
+ uint32_t pos_grid_local[D];
309
+
310
+ #pragma unroll
311
+ for (uint32_t d = 0; d < D; d++) {
312
+ if ((idx & (1 << d)) == 0) {
313
+ w *= 1 - pos[d];
314
+ pos_grid_local[d] = pos_grid[d];
315
+ } else {
316
+ w *= pos[d];
317
+ pos_grid_local[d] = pos_grid[d] + 1;
318
+ }
319
+ }
320
+
321
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
322
+
323
+ // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
324
+ // TODO: use float which is better than __half, if N_C % 2 != 0
325
+ if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
326
+ #pragma unroll
327
+ for (uint32_t c = 0; c < N_C; c += 2) {
328
+ // process two __half at once (by interpreting as a __half2)
329
+ __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
330
+ atomicAdd((__half2*)&grad_grid[index + c], v);
331
+ }
332
+ // float, or __half when N_C % 2 != 0 (which means C == 1)
333
+ } else {
334
+ #pragma unroll
335
+ for (uint32_t c = 0; c < N_C; c++) {
336
+ atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
337
+ }
338
+ }
339
+ }
340
+ }
341
+
342
+
343
+ template <typename scalar_t, uint32_t D, uint32_t C>
344
+ __global__ void kernel_input_backward(
345
+ const scalar_t * __restrict__ grad,
346
+ const scalar_t * __restrict__ dy_dx,
347
+ scalar_t * __restrict__ grad_inputs,
348
+ uint32_t B, uint32_t L
349
+ ) {
350
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
351
+ if (t >= B * D) return;
352
+
353
+ const uint32_t b = t / D;
354
+ const uint32_t d = t - b * D;
355
+
356
+ dy_dx += b * L * D * C;
357
+
358
+ scalar_t result = 0;
359
+
360
+ # pragma unroll
361
+ for (int l = 0; l < L; l++) {
362
+ # pragma unroll
363
+ for (int ch = 0; ch < C; ch++) {
364
+ result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
365
+ }
366
+ }
367
+
368
+ grad_inputs[t] = result;
369
+ }
370
+
371
+
372
+ template <typename scalar_t, uint32_t D>
373
+ void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
374
+ static constexpr uint32_t N_THREAD = 512;
375
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
376
+ switch (C) {
377
+ case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
378
+ case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
379
+ case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
380
+ case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
381
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
382
+ }
383
+ }
384
+
385
+ // inputs: [B, D], float, in [0, 1]
386
+ // embeddings: [sO, C], float
387
+ // offsets: [L + 1], uint32_t
388
+ // outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
389
+ // H: base resolution
390
+ // dy_dx: [B, L * D * C]
391
+ template <typename scalar_t>
392
+ void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
393
+ switch (D) {
394
+ case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
395
+ case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
396
+ case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
397
+ case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
398
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
399
+ }
400
+ }
401
+
402
+ template <typename scalar_t, uint32_t D>
403
+ void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
404
+ static constexpr uint32_t N_THREAD = 256;
405
+ const uint32_t N_C = std::min(2u, C); // n_features_per_thread
406
+ const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
407
+ switch (C) {
408
+ case 1:
409
+ kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
410
+ if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
411
+ break;
412
+ case 2:
413
+ kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
414
+ if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
415
+ break;
416
+ case 4:
417
+ kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
418
+ if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
419
+ break;
420
+ case 8:
421
+ kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
422
+ if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
423
+ break;
424
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
425
+ }
426
+ }
427
+
428
+
429
+ // grad: [L, B, C], float
430
+ // inputs: [B, D], float, in [0, 1]
431
+ // embeddings: [sO, C], float
432
+ // offsets: [L + 1], uint32_t
433
+ // grad_embeddings: [sO, C]
434
+ // H: base resolution
435
+ template <typename scalar_t>
436
+ void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
437
+ switch (D) {
438
+ case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
439
+ case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
440
+ case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
441
+ case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
442
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
443
+ }
444
+ }
445
+
446
+
447
+
448
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
449
+ CHECK_CUDA(inputs);
450
+ CHECK_CUDA(embeddings);
451
+ CHECK_CUDA(offsets);
452
+ CHECK_CUDA(outputs);
453
+ // CHECK_CUDA(dy_dx);
454
+
455
+ CHECK_CONTIGUOUS(inputs);
456
+ CHECK_CONTIGUOUS(embeddings);
457
+ CHECK_CONTIGUOUS(offsets);
458
+ CHECK_CONTIGUOUS(outputs);
459
+ // CHECK_CONTIGUOUS(dy_dx);
460
+
461
+ CHECK_IS_FLOATING(inputs);
462
+ CHECK_IS_FLOATING(embeddings);
463
+ CHECK_IS_INT(offsets);
464
+ CHECK_IS_FLOATING(outputs);
465
+ // CHECK_IS_FLOATING(dy_dx);
466
+
467
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
468
+ embeddings.scalar_type(), "grid_encode_forward", ([&] {
469
+ grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
470
+ }));
471
+ }
472
+
473
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
474
+ CHECK_CUDA(grad);
475
+ CHECK_CUDA(inputs);
476
+ CHECK_CUDA(embeddings);
477
+ CHECK_CUDA(offsets);
478
+ CHECK_CUDA(grad_embeddings);
479
+ // CHECK_CUDA(dy_dx);
480
+ // CHECK_CUDA(grad_inputs);
481
+
482
+ CHECK_CONTIGUOUS(grad);
483
+ CHECK_CONTIGUOUS(inputs);
484
+ CHECK_CONTIGUOUS(embeddings);
485
+ CHECK_CONTIGUOUS(offsets);
486
+ CHECK_CONTIGUOUS(grad_embeddings);
487
+ // CHECK_CONTIGUOUS(dy_dx);
488
+ // CHECK_CONTIGUOUS(grad_inputs);
489
+
490
+ CHECK_IS_FLOATING(grad);
491
+ CHECK_IS_FLOATING(inputs);
492
+ CHECK_IS_FLOATING(embeddings);
493
+ CHECK_IS_INT(offsets);
494
+ CHECK_IS_FLOATING(grad_embeddings);
495
+ // CHECK_IS_FLOATING(dy_dx);
496
+ // CHECK_IS_FLOATING(grad_inputs);
497
+
498
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
499
+ grad.scalar_type(), "grid_encode_backward", ([&] {
500
+ grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
501
+ }));
502
+
503
+ }
504
+
505
+
506
+ template <typename scalar_t, uint32_t D, uint32_t C>
507
+ __global__ void kernel_grad_tv(
508
+ const scalar_t * __restrict__ inputs,
509
+ const scalar_t * __restrict__ grid,
510
+ scalar_t * __restrict__ grad,
511
+ const int * __restrict__ offsets,
512
+ const float weight,
513
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
514
+ const uint32_t gridtype,
515
+ const bool align_corners
516
+ ) {
517
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
518
+
519
+ if (b >= B) return;
520
+
521
+ const uint32_t level = blockIdx.y;
522
+
523
+ // locate
524
+ inputs += b * D;
525
+ grid += (uint32_t)offsets[level] * C;
526
+ grad += (uint32_t)offsets[level] * C;
527
+
528
+ // check input range (should be in [0, 1])
529
+ bool flag_oob = false;
530
+ #pragma unroll
531
+ for (uint32_t d = 0; d < D; d++) {
532
+ if (inputs[d] < 0 || inputs[d] > 1) {
533
+ flag_oob = true;
534
+ }
535
+ }
536
+
537
+ // if input out of bound, do nothing
538
+ if (flag_oob) return;
539
+
540
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
541
+ const float scale = exp2f(level * S) * H - 1.0f;
542
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
543
+
544
+ // calculate coordinate
545
+ float pos[D];
546
+ uint32_t pos_grid[D]; // [0, resolution]
547
+
548
+ #pragma unroll
549
+ for (uint32_t d = 0; d < D; d++) {
550
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
551
+ pos_grid[d] = floorf(pos[d]);
552
+ // pos[d] -= (float)pos_grid[d]; // not used
553
+ }
554
+
555
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
556
+
557
+ // total variation on pos_grid
558
+ scalar_t results[C] = {0}; // temp results in register
559
+ scalar_t idelta[C] = {0};
560
+
561
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
562
+
563
+ scalar_t w = weight / (2 * D);
564
+
565
+ #pragma unroll
566
+ for (uint32_t d = 0; d < D; d++) {
567
+
568
+ uint32_t cur_d = pos_grid[d];
569
+ scalar_t grad_val;
570
+
571
+ // right side
572
+ if (cur_d < resolution) {
573
+ pos_grid[d] = cur_d + 1;
574
+ uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
575
+
576
+ #pragma unroll
577
+ for (uint32_t ch = 0; ch < C; ch++) {
578
+ // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f);
579
+ grad_val = (grid[index + ch] - grid[index_right + ch]);
580
+ results[ch] += grad_val;
581
+ idelta[ch] += grad_val * grad_val;
582
+ }
583
+ }
584
+
585
+ // left side
586
+ if (cur_d > 0) {
587
+ pos_grid[d] = cur_d - 1;
588
+ uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
589
+
590
+ #pragma unroll
591
+ for (uint32_t ch = 0; ch < C; ch++) {
592
+ // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f);
593
+ grad_val = (grid[index + ch] - grid[index_left + ch]);
594
+ results[ch] += grad_val;
595
+ idelta[ch] += grad_val * grad_val;
596
+ }
597
+ }
598
+
599
+ // reset
600
+ pos_grid[d] = cur_d;
601
+ }
602
+
603
+ // writing to global memory (slow)
604
+ #pragma unroll
605
+ for (uint32_t ch = 0; ch < C; ch++) {
606
+ // index may collide, so use atomic!
607
+ atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
608
+ }
609
+
610
+ }
611
+
612
+
613
+ template <typename scalar_t, uint32_t D>
614
+ void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
615
+ static constexpr uint32_t N_THREAD = 512;
616
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
617
+ switch (C) {
618
+ case 1: kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
619
+ case 2: kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
620
+ case 4: kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
621
+ case 8: kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
622
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
623
+ }
624
+ }
625
+
626
+
627
+ template <typename scalar_t>
628
+ void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
629
+ switch (D) {
630
+ case 2: kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
631
+ case 3: kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
632
+ case 4: kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
633
+ case 5: kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
634
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
635
+ }
636
+ }
637
+
638
+
639
+ void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
640
+
641
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
642
+ embeddings.scalar_type(), "grad_total_variation", ([&] {
643
+ grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);
644
+ }));
645
+ }
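For sanity-checking the indexing logic, a hedged pure-Python re-expression of `fast_hash` plus `get_grid_index` for D=3, C=1 and the hash grid type (uint32 overflow emulated with a mask; not part of the extension itself):

```python
PRIMES = (1, 2654435761, 805459861)  # first three primes from fast_hash

def grid_index(pos_grid, hashmap_size, resolution, align_corners=False):
    """Mirror of get_grid_index<D=3, C=1> with gridtype == 0 (hash)."""
    side = resolution if align_corners else resolution + 1
    stride, index = 1, 0
    for p in pos_grid:                 # dense indexing while it still fits
        if stride > hashmap_size:
            break
        index += p * stride
        stride *= side
    if stride > hashmap_size:          # level too dense to store: spatial hash
        index = 0
        for p, prime in zip(pos_grid, PRIMES):
            index ^= (p * prime) & 0xFFFFFFFF
    return index % hashmap_size
```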
gridencoder/src/gridencoder.h ADDED
@@ -0,0 +1,17 @@
1
+ #ifndef _HASH_ENCODE_H
2
+ #define _HASH_ENCODE_H
3
+
4
+ #include <stdint.h>
5
+ #include <torch/torch.h>
6
+
7
+ // inputs: [B, D], float, in [0, 1]
8
+ // embeddings: [sO, C], float
9
+ // offsets: [L + 1], uint32_t
10
+ // outputs: [B, L * C], float
11
+ // H: base resolution
12
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
13
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
14
+
15
+ void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
16
+
17
+ #endif
internal/camera_utils.py ADDED
@@ -0,0 +1,673 @@
+ import enum
+ from internal import configs
+ from internal import stepfun
+ from internal import utils
+ import numpy as np
+ import scipy
+
+
+ def convert_to_ndc(origins,
+                    directions,
+                    pixtocam,
+                    near: float = 1.):
+     """Converts a set of rays to normalized device coordinates (NDC).
+
+     Args:
+       origins: ndarray(float32), [..., 3], world space ray origins.
+       directions: ndarray(float32), [..., 3], world space ray directions.
+       pixtocam: ndarray(float32), [3, 3], inverse intrinsic matrix.
+       near: float, near plane along the negative z axis.
+
+     Returns:
+       origins_ndc: ndarray(float32), [..., 3].
+       directions_ndc: ndarray(float32), [..., 3].
+
+     This function assumes input rays should be mapped into the NDC space for a
+     perspective projection pinhole camera, with identity extrinsic matrix (pose)
+     and intrinsic parameters defined by inputs focal, width, and height.
+
+     The near value specifies the near plane of the frustum, and the far plane is
+     assumed to be infinity.
+
+     The ray bundle for the identity pose camera will be remapped to parallel rays
+     within the (-1, -1, -1) to (1, 1, 1) cube. Any other ray in the original
+     world space can be remapped as long as it has dz < 0 (ray direction has a
+     negative z-coord); this allows us to share a common NDC space for "forward
+     facing" scenes.
+
+     Note that
+         projection(origins + t * directions)
+     will NOT be equal to
+         origins_ndc + t * directions_ndc
+     and that the directions_ndc are not unit length. Rather, directions_ndc is
+     defined such that the valid near and far planes in NDC will be 0 and 1.
+
+     See Appendix C in https://arxiv.org/abs/2003.08934 for additional details.
+     """
+
+     # Shift ray origins to near plane, such that oz = -near.
+     # This makes the new near bound equal to 0.
+     t = -(near + origins[..., 2]) / directions[..., 2]
+     origins = origins + t[..., None] * directions
+
+     dx, dy, dz = np.moveaxis(directions, -1, 0)
+     ox, oy, oz = np.moveaxis(origins, -1, 0)
+
+     xmult = 1. / pixtocam[0, 2]  # Equal to -2. * focal / cx
+     ymult = 1. / pixtocam[1, 2]  # Equal to -2. * focal / cy
+
+     # Perspective projection into NDC for the t = 0 near points
+     #     origins + 0 * directions
+     origins_ndc = np.stack([xmult * ox / oz, ymult * oy / oz,
+                             -np.ones_like(oz)], axis=-1)
+
+     # Perspective projection into NDC for the t = infinity far points
+     #     origins + infinity * directions
+     infinity_ndc = np.stack([xmult * dx / dz, ymult * dy / dz,
+                              np.ones_like(oz)],
+                             axis=-1)
+
+     # directions_ndc points from origins_ndc to infinity_ndc
+     directions_ndc = infinity_ndc - origins_ndc
+
+     return origins_ndc, directions_ndc
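# Editor's sketch (not part of the commit): a quick numeric check of the NDC
# mapping above for an identity-pose camera. A ray down the optical axis lands
# on the near plane at z = -1 and reaches z = +1 at t = infinity:
#     o = np.array([[0., 0., 0.]])
#     d = np.array([[0., 0., -1.]])
#     pixtocam = np.linalg.inv(intrinsic_matrix(500., 500., 250., 250.))
#     o_ndc, d_ndc = convert_to_ndc(o, d, pixtocam, near=1.)
#     # o_ndc[0] == [0, 0, -1] and d_ndc[0] == [0, 0, 2]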
+
+
+ def pad_poses(p):
+     """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
+     bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
+     return np.concatenate([p[..., :3, :4], bottom], axis=-2)
+
+
+ def unpad_poses(p):
+     """Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
+     return p[..., :3, :4]
+
+
+ def recenter_poses(poses):
+     """Recenter poses around the origin."""
+     cam2world = average_pose(poses)
+     transform = np.linalg.inv(pad_poses(cam2world))
+     poses = transform @ pad_poses(poses)
+     return unpad_poses(poses), transform
+
+
+ def average_pose(poses):
+     """New pose using average position, z-axis, and up vector of input poses."""
+     position = poses[:, :3, 3].mean(0)
+     z_axis = poses[:, :3, 2].mean(0)
+     up = poses[:, :3, 1].mean(0)
+     cam2world = viewmatrix(z_axis, up, position)
+     return cam2world
+
+
+ def viewmatrix(lookdir, up, position):
+     """Construct lookat view matrix."""
+     vec2 = normalize(lookdir)
+     vec0 = normalize(np.cross(up, vec2))
+     vec1 = normalize(np.cross(vec2, vec0))
+     m = np.stack([vec0, vec1, vec2, position], axis=1)
+     return m
+
+
+ def normalize(x):
+     """Normalization helper function."""
+     return x / np.linalg.norm(x)
+
+
+ def focus_point_fn(poses):
+     """Calculate nearest point to all focal axes in poses."""
+     directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
+     m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
+     mt_m = np.transpose(m, [0, 2, 1]) @ m
+     focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
+     return focus_pt
+
+
+ # Constants for generate_spiral_path():
+ NEAR_STRETCH = .9  # Push forward near bound for forward facing render path.
+ FAR_STRETCH = 5.  # Push back far bound for forward facing render path.
+ FOCUS_DISTANCE = .75  # Relative weighting of near, far bounds for render path.
+
+
+ def generate_spiral_path(poses, bounds, n_frames=120, n_rots=2, zrate=.5):
+     """Calculates a forward facing spiral path for rendering."""
+     # Find a reasonable 'focus depth' for this dataset as a weighted average
+     # of conservative near and far bounds in disparity space.
+     near_bound = bounds.min() * NEAR_STRETCH
+     far_bound = bounds.max() * FAR_STRETCH
+     # All cameras will point towards the world space point (0, 0, -focal).
+     focal = 1 / (((1 - FOCUS_DISTANCE) / near_bound + FOCUS_DISTANCE / far_bound))
+
+     # Get radii for spiral path using 90th percentile of camera positions.
+     positions = poses[:, :3, 3]
+     radii = np.percentile(np.abs(positions), 90, 0)
+     radii = np.concatenate([radii, [1.]])
+
+     # Generate poses for spiral path.
+     render_poses = []
+     cam2world = average_pose(poses)
+     up = poses[:, :3, 1].mean(0)
+     for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
+         t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
+         position = cam2world @ t
+         lookat = cam2world @ [0, 0, -focal, 1.]
+         z_axis = position - lookat
+         render_poses.append(viewmatrix(z_axis, up, position))
+     render_poses = np.stack(render_poses, axis=0)
+     return render_poses
+
+
+ def transform_poses_pca(poses):
+     """Transforms poses so principal components lie on XYZ axes.
+
+     Args:
+       poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
+
+     Returns:
+       A tuple (poses, transform), with the transformed poses and the applied
+       camera_to_world transforms.
+     """
+     t = poses[:, :3, 3]
+     t_mean = t.mean(axis=0)
+     t = t - t_mean
+
+     eigval, eigvec = np.linalg.eig(t.T @ t)
+     # Sort eigenvectors in order of largest to smallest eigenvalue.
+     inds = np.argsort(eigval)[::-1]
+     eigvec = eigvec[:, inds]
+     rot = eigvec.T
+     if np.linalg.det(rot) < 0:
+         rot = np.diag(np.array([1, 1, -1])) @ rot
+
+     transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
+     poses_recentered = unpad_poses(transform @ pad_poses(poses))
+     transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
+
+     # Flip coordinate system if z component of y-axis is negative
+     if poses_recentered.mean(axis=0)[2, 1] < 0:
+         poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
+         transform = np.diag(np.array([1, -1, -1, 1])) @ transform
+
+     # Just make sure it's in the [-1, 1]^3 cube.
+     scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
+     poses_recentered[:, :3, 3] *= scale_factor
+     transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
+
+     return poses_recentered, transform
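# Editor's sketch (not part of the commit): typical use of the PCA
# normalization above. `transform` maps original world coordinates into the
# normalized frame, so held-out poses can be carried along:
#     poses_norm, transform = transform_poses_pca(train_poses)
#     test_poses_norm = unpad_poses(transform @ pad_poses(test_poses))
# (`train_poses` / `test_poses` are placeholder names for (N, 3, 4) arrays.)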
+
+
+ def generate_ellipse_path(poses, n_frames=120, const_speed=True, z_variation=0., z_phase=0.):
+     """Generate an elliptical render path based on the given poses."""
+     # Calculate the focal point for the path (cameras point toward this).
+     center = focus_point_fn(poses)
+     # Path height sits at z=0 (in middle of zero-mean capture pattern).
+     offset = np.array([center[0], center[1], 0])
+
+     # Calculate scaling for ellipse axes based on input camera positions.
+     sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
+     # Use ellipse that is symmetric about the focal point in xy.
+     low = -sc + offset
+     high = sc + offset
+     # Optional height variation need not be symmetric.
+     z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
+     z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
+
+     def get_positions(theta):
+         # Interpolate between bounds with trig functions to get ellipse in x-y.
+         # Optionally also interpolate in z to change camera height along path.
+         return np.stack([
+             low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),
+             low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),
+             z_variation * (z_low[2] + (z_high - z_low)[2] *
+                            (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
+         ], -1)
+
+     theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
+     positions = get_positions(theta)
+
+     if const_speed:
+         # Resample theta angles so that the velocity is closer to constant.
+         lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
+         theta = stepfun.sample_np(None, theta, np.log(lengths), n_frames + 1)
+         positions = get_positions(theta)
+
+     # Throw away duplicated last position.
+     positions = positions[:-1]
+
+     # Set path's up vector to axis closest to average of input pose up vectors.
+     avg_up = poses[:, :3, 1].mean(0)
+     avg_up = avg_up / np.linalg.norm(avg_up)
+     ind_up = np.argmax(np.abs(avg_up))
+     up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
+
+     return np.stack([viewmatrix(p - center, up, p) for p in positions])
+
+
+ def generate_interpolated_path(poses, n_interp, spline_degree=5,
+                                smoothness=.03, rot_weight=.1):
+     """Creates a smooth spline path between input keyframe camera poses.
+
+     Spline is calculated with poses in format (position, lookat-point, up-point).
+
+     Args:
+       poses: (n, 3, 4) array of input pose keyframes.
+       n_interp: returned path will have n_interp * (n - 1) total poses.
+       spline_degree: polynomial degree of B-spline.
+       smoothness: parameter for spline smoothing, 0 forces exact interpolation.
+       rot_weight: relative weighting of rotation/translation in spline solve.
+
+     Returns:
+       Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
+     """
+
+     def poses_to_points(poses, dist):
+         """Converts from pose matrices to (position, lookat, up) format."""
+         pos = poses[:, :3, -1]
+         lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
+         up = poses[:, :3, -1] + dist * poses[:, :3, 1]
+         return np.stack([pos, lookat, up], 1)
+
+     def points_to_poses(points):
+         """Converts from (position, lookat, up) format to pose matrices."""
+         return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
+
+     def interp(points, n, k, s):
+         """Runs multidimensional B-spline interpolation on the input points."""
+         sh = points.shape
+         pts = np.reshape(points, (sh[0], -1))
+         k = min(k, sh[0] - 1)
+         tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
+         u = np.linspace(0, 1, n, endpoint=False)
+         new_points = np.array(scipy.interpolate.splev(u, tck))
+         new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
+         return new_points
+
+     points = poses_to_points(poses, dist=rot_weight)
+     new_points = interp(points,
+                         n_interp * (points.shape[0] - 1),
+                         k=spline_degree,
+                         s=smoothness)
+     return points_to_poses(new_points)
+
+
+ def interpolate_1d(x, n_interp, spline_degree, smoothness):
+     """Interpolate 1d signal x (by a factor of n_interp times)."""
+     t = np.linspace(0, 1, len(x), endpoint=True)
+     tck = scipy.interpolate.splrep(t, x, s=smoothness, k=spline_degree)
+     n = n_interp * (len(x) - 1)
+     u = np.linspace(0, 1, n, endpoint=False)
+     return scipy.interpolate.splev(u, tck)
+
+
+ def create_render_spline_path(config, image_names, poses, exposures):
+     """Creates spline interpolation render path from subset of dataset poses.
+
+     Args:
+       config: configs.Config object.
+       image_names: either a directory of images or a text file of image names.
+       poses: [N, 3, 4] array of extrinsic camera pose matrices.
+       exposures: optional list of floating point exposure values.
+
+     Returns:
+       spline_indices: list of indices used to select spline keyframe poses.
+       render_poses: array of interpolated extrinsic camera poses for the path.
+       render_exposures: optional list of interpolated exposures for the path.
+     """
+     if utils.isdir(config.render_spline_keyframes):
+         # If directory, use image filenames.
+         keyframe_names = sorted(utils.listdir(config.render_spline_keyframes))
+     else:
+         # If text file, treat each line as an image filename.
+         with utils.open_file(config.render_spline_keyframes, 'r') as fp:
+             # Decode bytes into string and split into lines.
+             keyframe_names = fp.read().decode('utf-8').splitlines()
+     # Grab poses corresponding to the image filenames.
+     spline_indices = np.array(
+         [i for i, n in enumerate(image_names) if n in keyframe_names])
+     keyframes = poses[spline_indices]
+     render_poses = generate_interpolated_path(
+         keyframes,
+         n_interp=config.render_spline_n_interp,
+         spline_degree=config.render_spline_degree,
+         smoothness=config.render_spline_smoothness,
+         rot_weight=.1)
+     if config.render_spline_interpolate_exposure:
+         if exposures is None:
+             raise ValueError('config.render_spline_interpolate_exposure is True but '
+                              'create_render_spline_path() was passed exposures=None.')
+         # Interpolate per-frame exposure value.
+         log_exposure = np.log(exposures[spline_indices])
+         # Use aggressive smoothing for exposure interpolation to avoid flickering.
+         log_exposure_interp = interpolate_1d(
+             log_exposure,
+             config.render_spline_n_interp,
+             spline_degree=5,
+             smoothness=20)
+         render_exposures = np.exp(log_exposure_interp)
+     else:
+         render_exposures = None
+     return spline_indices, render_poses, render_exposures
+
+
+ def intrinsic_matrix(fx, fy, cx, cy):
+     """Intrinsic matrix for a pinhole camera in OpenCV coordinate system."""
+     return np.array([
+         [fx, 0, cx],
+         [0, fy, cy],
+         [0, 0, 1.],
+     ])
+
+
+ def get_pixtocam(focal, width, height):
+     """Inverse intrinsic matrix for a perfect pinhole camera."""
+     camtopix = intrinsic_matrix(focal, focal, width * .5, height * .5)
+     return np.linalg.inv(camtopix)
+
+
+ def pixel_coordinates(width, height):
+     """Tuple of the x and y integer coordinates for a grid of pixels."""
+     return np.meshgrid(np.arange(width), np.arange(height), indexing='xy')
+
+
+ def _compute_residual_and_jacobian(x, y, xd, yd,
+                                    k1=0.0, k2=0.0, k3=0.0,
+                                    k4=0.0, p1=0.0, p2=0.0):
+     """Auxiliary function of radial_and_tangential_undistort()."""
+     # Adapted from https://github.com/google/nerfies/blob/main/nerfies/camera.py
+     # let r(x, y) = x^2 + y^2;
+     #     d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y)^2 + k3 * r(x, y)^3 +
+     #               k4 * r(x, y)^4;
+     r = x * x + y * y
+     d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4)))
+
+     # The perfect projection is:
+     # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
+     # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
+     #
+     # Let's define
+     #
+     # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
+     # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
+     #
+     # We are looking for a solution that satisfies
+     # fx(x, y) = fy(x, y) = 0;
+     fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
+     fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd
+
+     # Compute derivative of d over [x, y]
+     d_r = (k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4)))
+     d_x = 2.0 * x * d_r
+     d_y = 2.0 * y * d_r
+
+     # Compute derivative of fx over x and y.
+     fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
+     fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y
+
+     # Compute derivative of fy over x and y.
+     fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
+     fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y
+
+     return fx, fy, fx_x, fx_y, fy_x, fy_y
+
+
+ def _radial_and_tangential_undistort(xd, yd, k1=0, k2=0,
+                                      k3=0, k4=0, p1=0,
+                                      p2=0, eps=1e-9, max_iterations=10):
+     """Computes undistorted (x, y) from (xd, yd)."""
+     # From https://github.com/google/nerfies/blob/main/nerfies/camera.py
+     # Initialize from the distorted point.
+     x = np.copy(xd)
+     y = np.copy(yd)
+
+     for _ in range(max_iterations):
+         fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
+             x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, k4=k4, p1=p1, p2=p2)
+         denominator = fy_x * fx_y - fx_x * fy_y
+         x_numerator = fx * fy_y - fy * fx_y
+         y_numerator = fy * fx_x - fx * fy_x
+         step_x = np.where(
+             np.abs(denominator) > eps, x_numerator / denominator,
+             np.zeros_like(denominator))
+         step_y = np.where(
+             np.abs(denominator) > eps, y_numerator / denominator,
+             np.zeros_like(denominator))
+
+         x = x + step_x
+         y = y + step_y
+
+     return x, y
+
+
+ class ProjectionType(enum.Enum):
+     """Camera projection type (standard perspective pinhole or fisheye model)."""
+     PERSPECTIVE = 'perspective'
+     FISHEYE = 'fisheye'
+
+
+ def pixels_to_rays(pix_x_int, pix_y_int, pixtocams,
+                    camtoworlds,
+                    distortion_params=None,
+                    pixtocam_ndc=None,
+                    camtype=ProjectionType.PERSPECTIVE):
+     """Calculates rays given pixel coordinates, intrinsics, and extrinsics.
+
+     Given 2D pixel coordinates pix_x_int, pix_y_int for cameras with
+     inverse intrinsics pixtocams and extrinsics camtoworlds (and optional
+     distortion coefficients distortion_params and NDC space projection matrix
+     pixtocam_ndc), computes the corresponding 3D camera rays.
+
+     Vectorized over the leading dimensions of the first four arguments.
+
+     Args:
+       pix_x_int: int array, shape SH, x coordinates of image pixels.
+       pix_y_int: int array, shape SH, y coordinates of image pixels.
+       pixtocams: float array, broadcastable to SH + [3, 3], inverse intrinsics.
+       camtoworlds: float array, broadcastable to SH + [3, 4], camera extrinsics.
+       distortion_params: dict of floats, optional camera distortion parameters.
+       pixtocam_ndc: float array, [3, 3], optional inverse intrinsics for NDC.
+       camtype: camera_utils.ProjectionType, fisheye or perspective camera.
+
+     Returns:
+       origins: float array, shape SH + [3], ray origin points.
+       directions: float array, shape SH + [3], ray direction vectors.
+       viewdirs: float array, shape SH + [3], normalized ray direction vectors.
+       radii: float array, shape SH + [1], ray differential radii.
+       imageplane: float array, shape SH + [2], xy coordinates on the image plane.
+         If the image plane is at world space distance 1 from the pinhole, then
+         imageplane will be the xy coordinates of a pixel in that space (so the
+         camera ray direction at the origin would be (x, y, -1) in OpenGL coords).
+     """
+
+     # Must add half pixel offset to shoot rays through pixel centers.
+     def pix_to_dir(x, y):
+         return np.stack([x + .5, y + .5, np.ones_like(x)], axis=-1)
+
+     # We need the dx and dy rays to calculate ray radii for mip-NeRF cones.
+     pixel_dirs_stacked = np.stack([
+         pix_to_dir(pix_x_int, pix_y_int),
+         pix_to_dir(pix_x_int + 1, pix_y_int),
+         pix_to_dir(pix_x_int, pix_y_int + 1)
+     ], axis=0)
+
+     matmul = np.matmul
+     mat_vec_mul = lambda A, b: matmul(A, b[..., None])[..., 0]
+
+     # Apply inverse intrinsic matrices.
+     camera_dirs_stacked = mat_vec_mul(pixtocams, pixel_dirs_stacked)
+
+     if distortion_params is not None:
+         # Correct for distortion.
+         x, y = _radial_and_tangential_undistort(
+             camera_dirs_stacked[..., 0],
+             camera_dirs_stacked[..., 1],
+             **distortion_params)
+         camera_dirs_stacked = np.stack([x, y, np.ones_like(x)], -1)
+
+     if camtype == ProjectionType.FISHEYE:
+         theta = np.sqrt(np.sum(np.square(camera_dirs_stacked[..., :2]), axis=-1))
+         theta = np.minimum(np.pi, theta)
+
+         sin_theta_over_theta = np.sin(theta) / theta
+         camera_dirs_stacked = np.stack([
+             camera_dirs_stacked[..., 0] * sin_theta_over_theta,
+             camera_dirs_stacked[..., 1] * sin_theta_over_theta,
+             np.cos(theta),
+         ], axis=-1)
+
+     # Flip from OpenCV to OpenGL coordinate system.
+     camera_dirs_stacked = matmul(camera_dirs_stacked,
+                                  np.diag(np.array([1., -1., -1.])))
+
+     # Extract 2D image plane (x, y) coordinates.
+     imageplane = camera_dirs_stacked[0, ..., :2]
+
+     # Apply camera rotation matrices.
+     directions_stacked = mat_vec_mul(camtoworlds[..., :3, :3],
+                                      camera_dirs_stacked)
+     # Extract the offset rays.
+     directions, dx, dy = directions_stacked
+
+     origins = np.broadcast_to(camtoworlds[..., :3, -1], directions.shape)
+     viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
+
+     if pixtocam_ndc is None:
+         # Distance from each unit-norm direction vector to its neighbors.
+         dx_norm = np.linalg.norm(dx - directions, axis=-1)
+         dy_norm = np.linalg.norm(dy - directions, axis=-1)
+
+     else:
+         # Convert ray origins and directions into projective NDC space.
+         origins_dx, _ = convert_to_ndc(origins, dx, pixtocam_ndc)
+         origins_dy, _ = convert_to_ndc(origins, dy, pixtocam_ndc)
+         origins, directions = convert_to_ndc(origins, directions, pixtocam_ndc)
+
+         # In NDC space, we use the offset between origins instead of directions.
+         dx_norm = np.linalg.norm(origins_dx - origins, axis=-1)
+         dy_norm = np.linalg.norm(origins_dy - origins, axis=-1)
+
+     # Cut the distance in half, multiply it to match the variance of a uniform
+     # distribution the size of a pixel (1/12, see the original mipnerf paper).
+     radii = (0.5 * (dx_norm + dy_norm))[..., None] * 2 / np.sqrt(12)
+     return origins, directions, viewdirs, radii, imageplane
+
+
+ def cast_ray_batch(cameras, pixels, camtype):
+     """Maps from input cameras and Pixel batch to output Ray batch.
+
+     `cameras` is a Tuple of four sets of camera parameters.
+       pixtocams: 1 or N stacked [3, 3] inverse intrinsic matrices.
+       camtoworlds: 1 or N stacked [3, 4] extrinsic pose matrices.
+       distortion_params: optional, dict[str, float] containing pinhole model
+         distortion parameters.
+       pixtocam_ndc: optional, [3, 3] inverse intrinsic matrix for mapping to NDC.
+
+     Args:
+       cameras: described above.
+       pixels: integer pixel coordinates and camera indices, plus ray metadata.
+         These fields can be an arbitrary batch shape.
+       camtype: camera_utils.ProjectionType, fisheye or perspective camera.
+
+     Returns:
+       rays: Rays dataclass with computed 3D world space ray data.
+     """
+     pixtocams, camtoworlds, distortion_params, pixtocam_ndc = cameras
+
+     # pixels.cam_idx has shape [..., 1], remove this hanging dimension.
+     cam_idx = pixels['cam_idx'][..., 0]
+     batch_index = lambda arr: arr if arr.ndim == 2 else arr[cam_idx]
+
+     # Compute rays from pixel coordinates.
+     origins, directions, viewdirs, radii, imageplane = pixels_to_rays(
+         pixels['pix_x_int'],
+         pixels['pix_y_int'],
+         batch_index(pixtocams),
+         batch_index(camtoworlds),
+         distortion_params=distortion_params,
+         pixtocam_ndc=pixtocam_ndc,
+         camtype=camtype)
+
+     # Create Rays data structure.
+     return dict(
+         origins=origins,
+         directions=directions,
+         viewdirs=viewdirs,
+         radii=radii,
+         imageplane=imageplane,
+         lossmult=pixels.get('lossmult'),
+         near=pixels.get('near'),
+         far=pixels.get('far'),
+         cam_idx=pixels.get('cam_idx'),
+         exposure_idx=pixels.get('exposure_idx'),
+         exposure_values=pixels.get('exposure_values'),
+     )
+
+
+ def cast_pinhole_rays(camtoworld, height, width, focal, near, far):
+     """Wrapper for generating a pinhole camera ray batch (w/o distortion)."""
+
+     pix_x_int, pix_y_int = pixel_coordinates(width, height)
+     pixtocam = get_pixtocam(focal, width, height)
+
+     origins, directions, viewdirs, radii, imageplane = pixels_to_rays(pix_x_int, pix_y_int, pixtocam, camtoworld)
+
+     broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
+     ray_kwargs = {
+         'lossmult': broadcast_scalar(1.),
+         'near': broadcast_scalar(near),
+         'far': broadcast_scalar(far),
+         'cam_idx': broadcast_scalar(0),
+     }
+
+     return dict(origins=origins,
+                 directions=directions,
+                 viewdirs=viewdirs,
+                 radii=radii,
+                 imageplane=imageplane,
+                 **ray_kwargs)
+
+
+ def cast_spherical_rays(camtoworld, height, width, near, far):
+     """Generates a spherical camera ray batch."""
+
+     theta_vals = np.linspace(0, 2 * np.pi, width + 1)
+     phi_vals = np.linspace(0, np.pi, height + 1)
+     theta, phi = np.meshgrid(theta_vals, phi_vals, indexing='xy')
+
+     # Spherical coordinates in camera reference frame (y is up).
+     directions = np.stack([
+         -np.sin(phi) * np.sin(theta),
+         np.cos(phi),
+         np.sin(phi) * np.cos(theta),
+     ], axis=-1)
+
+     matmul = np.matmul
+     directions = matmul(camtoworld[:3, :3], directions[..., None])[..., 0]
+
+     dy = np.diff(directions[:, :-1], axis=0)
+     dx = np.diff(directions[:-1, :], axis=1)
+     directions = directions[:-1, :-1]
+     viewdirs = directions
+
+     origins = np.broadcast_to(camtoworld[:3, -1], directions.shape)
+
+     dx_norm = np.linalg.norm(dx, axis=-1)
+     dy_norm = np.linalg.norm(dy, axis=-1)
+     radii = (0.5 * (dx_norm + dy_norm))[..., None] * 2 / np.sqrt(12)
+
+     imageplane = np.zeros_like(directions[..., :2])
+
+     broadcast_scalar = lambda x: np.broadcast_to(x, radii.shape[:-1])[..., None]
+     ray_kwargs = {
+         'lossmult': broadcast_scalar(1.),
+         'near': broadcast_scalar(near),
+         'far': broadcast_scalar(far),
+         'cam_idx': broadcast_scalar(0),
+     }
+
+     return dict(origins=origins,
+                 directions=directions,
+                 viewdirs=viewdirs,
+                 radii=radii,
+                 imageplane=imageplane,
+                 **ray_kwargs)
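A quick editor's sanity check for the pinhole ray helper above (a sketch, not part of the commit): ray origins take the camera position, viewdirs are unit length, and each ray carries a per-pixel cone radius.

import numpy as np
from internal import camera_utils

rays = camera_utils.cast_pinhole_rays(
    camtoworld=np.eye(4)[:3], height=4, width=6, focal=5., near=.1, far=10.)
assert rays['origins'].shape == (4, 6, 3)
assert np.allclose(np.linalg.norm(rays['viewdirs'], axis=-1), 1.)
assert rays['radii'].shape == (4, 6, 1)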
internal/checkpoints.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ import shutil
+
+ import accelerate
+ import torch
+ import glob
+
+
+ def restore_checkpoint(
+         checkpoint_dir,
+         accelerator: accelerate.Accelerator,
+         logger=None
+ ):
+     dirs = glob.glob(os.path.join(checkpoint_dir, "*"))
+     dirs.sort()
+     path = dirs[-1] if len(dirs) > 0 else None
+     if path is None:
+         if logger is not None:
+             logger.info("Checkpoint does not exist. Starting a new training run.")
+         init_step = 0
+     else:
+         if logger is not None:
+             logger.info(f"Resuming from checkpoint {path}")
+         accelerator.load_state(path)
+         init_step = int(os.path.basename(path))
+     return init_step
+
+
+ def save_checkpoint(save_dir,
+                     accelerator: accelerate.Accelerator,
+                     step=0,
+                     total_limit=3):
+     if total_limit > 0:
+         # Keep at most `total_limit` checkpoints, counting the one about to be saved.
+         folders = glob.glob(os.path.join(save_dir, "*"))
+         folders.sort()
+         for folder in folders[: len(folders) + 1 - total_limit]:
+             shutil.rmtree(folder)
+     accelerator.save_state(os.path.join(save_dir, f"{step:06d}"))
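A minimal usage sketch for the two helpers above (assuming an accelerate.Accelerator that already has the model and optimizer registered; the paths and step counts are illustrative):

import accelerate
from internal import checkpoints

accelerator = accelerate.Accelerator()
init_step = checkpoints.restore_checkpoint('exp/test/checkpoints', accelerator)
for step in range(init_step, 25000):
    ...  # one optimization step
    if (step + 1) % 5000 == 0:
        checkpoints.save_checkpoint(
            'exp/test/checkpoints', accelerator, step=step + 1, total_limit=3)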
internal/configs.py ADDED
@@ -0,0 +1,177 @@
+ import dataclasses
+ import os
+ from typing import Any, Callable, Optional, Tuple, List
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from absl import flags
+ import gin
+ from internal import utils
+
+ gin.add_config_file_search_path('configs/')
+
+ configurables = {
+     'torch': [torch.reciprocal, torch.log, torch.log1p, torch.exp, torch.sqrt, torch.square],
+ }
+
+ for module, fns in configurables.items():
+     for fn in fns:
+         gin.config.external_configurable(fn, module=module)
+
+
+ @gin.configurable()
+ @dataclasses.dataclass
+ class Config:
+     """Configuration flags for everything."""
+     seed: int = 0
+     dataset_loader: str = 'llff'  # The type of dataset loader to use.
+     batching: str = 'all_images'  # Batch composition, [single_image, all_images].
+     batch_size: int = 2 ** 16  # The number of rays/pixels in each batch.
+     patch_size: int = 1  # Resolution of patches sampled for training batches.
+     factor: int = 4  # The downsample factor of images, 0 for no downsampling.
+     multiscale: bool = False  # Use multiscale data for training.
+     multiscale_levels: int = 4  # Number of multiscale levels.
+     forward_facing: bool = False  # Set to True for forward-facing LLFF captures.
+     render_path: bool = False  # If True, render a path. Used only by LLFF.
+     llffhold: int = 8  # Use every Nth image for the test set. Used only by LLFF.
+     # If true, use all input images for training.
+     llff_use_all_images_for_training: bool = False
+     llff_use_all_images_for_testing: bool = False
+     use_tiffs: bool = False  # If True, use 32-bit TIFFs. Used only by Blender.
+     compute_disp_metrics: bool = False  # If True, load and compute disparity MSE.
+     compute_normal_metrics: bool = False  # If True, load and compute normal MAE.
+     disable_multiscale_loss: bool = False  # If True, disable multiscale loss.
+     randomized: bool = True  # Use randomized stratified sampling.
+     near: float = 2.  # Near plane distance.
+     far: float = 6.  # Far plane distance.
+     exp_name: str = "test"  # Experiment name.
+     data_dir: Optional[str] = "/SSD_DISK/datasets/360_v2/bicycle"  # Input data directory.
+     vocab_tree_path: Optional[str] = None  # Path to vocab tree for COLMAP.
+     render_chunk_size: int = 65536  # Chunk size for whole-image renderings.
+     num_showcase_images: int = 5  # The number of test-set images to showcase.
+     deterministic_showcase: bool = True  # If True, showcase the same images.
+     vis_num_rays: int = 16  # The number of rays to visualize.
+     # Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage.
+     vis_decimate: int = 0
+
+     # Only used by train.py:
+     max_steps: int = 25000  # The number of optimization steps.
+     early_exit_steps: Optional[int] = None  # Early stopping, for debugging.
+     checkpoint_every: int = 5000  # The number of steps to save a checkpoint.
+     resume_from_checkpoint: bool = True  # Whether to resume from checkpoint.
+     checkpoints_total_limit: int = 1
+     gradient_scaling: bool = False  # If True, scale gradients as in https://gradient-scaling.github.io/.
+     print_every: int = 100  # The number of steps between reports to tensorboard.
+     train_render_every: int = 500  # Steps between test set renders when training.
+     data_loss_type: str = 'charb'  # What kind of loss to use ('mse' or 'charb').
+     charb_padding: float = 0.001  # The padding used for Charbonnier loss.
+     data_loss_mult: float = 1.0  # Mult for the finest data term in the loss.
+     data_coarse_loss_mult: float = 0.  # Multiplier for the coarser data terms.
+     interlevel_loss_mult: float = 0.0  # Mult. for the loss on the proposal MLP.
+     anti_interlevel_loss_mult: float = 0.01  # Mult. for the anti-aliased interlevel loss.
+     pulse_width: Tuple[float, ...] = (0.03, 0.003)  # Pulse widths for the anti-aliased interlevel loss.
+     orientation_loss_mult: float = 0.0  # Multiplier on the orientation loss.
+     orientation_coarse_loss_mult: float = 0.0  # Coarser orientation loss weights.
+     # What that loss is imposed on, options are 'normals' or 'normals_pred'.
+     orientation_loss_target: str = 'normals_pred'
+     predicted_normal_loss_mult: float = 0.0  # Mult. on the predicted normal loss.
+     # Mult. on the coarser predicted normal loss.
+     predicted_normal_coarse_loss_mult: float = 0.0
+     hash_decay_mults: float = 0.1
+
+     lr_init: float = 0.01  # The initial learning rate.
+     lr_final: float = 0.001  # The final learning rate.
+     lr_delay_steps: int = 5000  # The number of "warmup" learning steps.
+     lr_delay_mult: float = 1e-8  # How severe the "warmup" should be.
+     adam_beta1: float = 0.9  # Adam's beta1 hyperparameter.
+     adam_beta2: float = 0.99  # Adam's beta2 hyperparameter.
+     adam_eps: float = 1e-15  # Adam's epsilon hyperparameter.
+     grad_max_norm: float = 0.  # Gradient clipping magnitude, disabled if == 0.
+     grad_max_val: float = 0.  # Gradient clipping value, disabled if == 0.
+     distortion_loss_mult: float = 0.005  # Multiplier on the distortion loss.
+     opacity_loss_mult: float = 0.  # Multiplier on the opacity loss.
+
+     # Only used by eval.py:
+     eval_only_once: bool = True  # If True evaluate the model only once, otherwise loop.
+     eval_save_output: bool = True  # If True save predicted images to disk.
+     eval_save_ray_data: bool = False  # If True save individual ray traces.
+     eval_render_interval: int = 1  # The interval between images saved to disk.
+     eval_dataset_limit: int = np.iinfo(np.int32).max  # Num test images to eval.
+     eval_quantize_metrics: bool = True  # If True, run metrics on 8-bit images.
+     eval_crop_borders: int = 0  # Ignore c border pixels in eval (x[c:-c, c:-c]).
+
+     # Only used by render.py:
+     render_video_fps: int = 60  # Framerate in frames-per-second.
+     render_video_crf: int = 18  # Constant rate factor for ffmpeg video quality.
+     render_path_frames: int = 120  # Number of frames in render path.
+     z_variation: float = 0.  # How much height variation in render path.
+     z_phase: float = 0.  # Phase offset for height variation in render path.
+     render_dist_percentile: float = 0.5  # How much to trim from near/far planes.
+     render_dist_curve_fn: Callable[..., Any] = np.log  # How depth is curved.
+     render_path_file: Optional[str] = None  # Numpy render pose file to load.
+     render_resolution: Optional[Tuple[int, int]] = None  # Render resolution, as (width, height).
+     render_focal: Optional[float] = None  # Render focal length.
+     render_camtype: Optional[str] = None  # 'perspective', 'fisheye', or 'pano'.
+     render_spherical: bool = False  # Render spherical 360 panoramas.
+     render_save_async: bool = True  # Save to CNS using a separate thread.
+
+     render_spline_keyframes: Optional[str] = None  # Text file containing names of
+     # images to be used as spline keyframes, OR directory containing those images.
+     render_spline_n_interp: int = 30  # Num. frames to interpolate per keyframe.
+     render_spline_degree: int = 5  # Polynomial degree of B-spline interpolation.
+     render_spline_smoothness: float = .03  # B-spline smoothing factor, 0 for
+     # exact interpolation of keyframes.
+     # Interpolate per-frame exposure value from spline keyframes.
+     render_spline_interpolate_exposure: bool = False
+
+     # Flags for raw datasets.
+     rawnerf_mode: bool = False  # Load raw images and train in raw color space.
+     exposure_percentile: float = 97.  # Image percentile to expose as white.
+     num_border_pixels_to_mask: int = 0  # During training, discard N-pixel border
+     # around each input image.
+     apply_bayer_mask: bool = False  # During training, apply Bayer mosaic mask.
+     autoexpose_renders: bool = False  # During rendering, autoexpose each image.
+     # For raw test scenes, use affine raw-space color correction.
+     eval_raw_affine_cc: bool = False
+
+     zero_glo: bool = False
+
+     # Marching cubes:
+     valid_weight_thresh: float = 0.05
+     isosurface_threshold: float = 20
+     mesh_voxels: int = 512 ** 3
+     visibility_resolution: int = 512
+     mesh_radius: float = 1.0  # mesh radius * 2 = in contract space
+     mesh_max_radius: float = 10.0  # in world space
+     std_value: float = 0.0  # std of the sampled points
+     compute_visibility: bool = False
+     extract_visibility: bool = True
+     decimate_target: int = -1
+     vertex_color: bool = True
+     vertex_projection: bool = True
+
+     # TSDF:
+     tsdf_radius: float = 2.0
+     tsdf_resolution: int = 512
+     truncation_margin: float = 5.0
+     tsdf_max_radius: float = 10.0  # in world space
+
+
+ def define_common_flags():
+     # Define the flags used by both train.py and eval.py.
+     flags.DEFINE_string('mode', None, 'Required by GINXM, not used.')
+     flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.')
+     flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
+     flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.')
+
+
+ def load_config():
+     """Load the config, and optionally checkpoint it."""
+     gin.parse_config_files_and_bindings(
+         flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True)
+     config = Config()
+     return config
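Because Config is gin-configurable, any field above can be overridden from a .gin file or a binding string. A sketch of the equivalent of passing --gin_configs/--gin_bindings on the command line (assuming a config file such as configs/360.gin is present on the gin search path):

import gin
from internal import configs

gin.parse_config_files_and_bindings(
    ['configs/360.gin'],
    ["Config.data_dir = 'data/360_v2/bicycle'", 'Config.factor = 4'],
    skip_unknown=True)
config = configs.Config()
assert config.factor == 4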
internal/coord.py ADDED
@@ -0,0 +1,225 @@
+ from internal import math
+ from internal import utils
+ import numpy as np
+ import torch
+ # from torch.func import vmap, jacrev
+
+
+ def contract(x):
+     """Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077)."""
+     eps = torch.finfo(x.dtype).eps
+     # eps = 1e-3
+     # Clamping to eps prevents non-finite gradients when x == 0.
+     x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
+     z = torch.where(x_mag_sq <= 1, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
+     return z
+
+
+ def inv_contract(z):
+     """The inverse of contract()."""
+     eps = torch.finfo(z.dtype).eps
+
+     # Clamping to eps prevents non-finite gradients when z == 0.
+     z_mag_sq = torch.sum(z ** 2, dim=-1, keepdim=True).clamp_min(eps)
+     x = torch.where(z_mag_sq <= 1, z, z / (2 * torch.sqrt(z_mag_sq) - z_mag_sq).clamp_min(eps))
+     return x
+
+
+ def inv_contract_np(z):
+     """The inverse of contract(), in numpy."""
+     eps = np.finfo(z.dtype).eps
+
+     # Clamping to eps prevents non-finite gradients when z == 0.
+     z_mag_sq = np.maximum(np.sum(z ** 2, axis=-1, keepdims=True), eps)
+     x = np.where(z_mag_sq <= 1, z, z / np.maximum(2 * np.sqrt(z_mag_sq) - z_mag_sq, eps))
+     return x
+
+
+ def contract_tuple(x):
+     res = contract(x)
+     return res, res
+
+
+ def contract_mean_jacobi(x):
+     eps = torch.finfo(x.dtype).eps
+     # eps = 1e-3
+
+     # Clamping to eps prevents non-finite gradients when x == 0.
+     x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
+     x_mag_sqrt = torch.sqrt(x_mag_sq)
+     x_xT = math.matmul(x[..., None], x[..., None, :])
+     mask = x_mag_sq <= 1
+     z = torch.where(mask, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
+
+     eye = torch.broadcast_to(torch.eye(3, device=x.device), z.shape[:-1] + z.shape[-1:] * 2)
+     jacobi = (2 * x_xT * (1 - x_mag_sqrt[..., None]) + (2 * x_mag_sqrt[..., None] ** 3 - x_mag_sqrt[..., None] ** 2) * eye) / x_mag_sqrt[..., None] ** 4
+     jacobi = torch.where(mask[..., None], eye, jacobi)
+     return z, jacobi
+
+
+ def contract_mean_std(x, std):
+     eps = torch.finfo(x.dtype).eps
+     # eps = 1e-3
+     # Clamping to eps prevents non-finite gradients when x == 0.
+     x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
+     x_mag_sqrt = torch.sqrt(x_mag_sq)
+     mask = x_mag_sq <= 1
+     z = torch.where(mask, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
+     # det_13 = ((1 / x_mag_sq) * ((2 / x_mag_sqrt - 1 / x_mag_sq) ** 2)) ** (1 / 3)
+     det_13 = (torch.pow(2 * x_mag_sqrt - 1, 1 / 3) / x_mag_sqrt) ** 2
+
+     std = torch.where(mask[..., 0], std, det_13[..., 0] * std)
+     return z, std
+
+
+ @torch.no_grad()
+ def track_linearize(fn, mean, std):
+     """Apply function `fn` to a set of means and scales, ala a Kalman filter.
+
+     We can analytically transform a Gaussian parameterized by `mean` and `cov`
+     with a function `fn` by linearizing `fn` around `mean`, and taking advantage
+     of the fact that Covar[Ax + y] = A(Covar[x])A^T (see
+     https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details).
+
+     Args:
+       fn: the function applied to the Gaussians parameterized by (mean, std).
+       mean: a tensor of means, where the last axis is the dimension.
+       std: a tensor of isotropic scales, one per mean.
+
+     Returns:
+       fn_mean: the transformed means.
+       fn_std: the transformed scales.
+     """
+     if fn == 'contract':
+         fn = contract_mean_jacobi
+     else:
+         raise NotImplementedError
+
+     pre_shape = mean.shape[:-1]
+     mean = mean.reshape(-1, 3)
+     std = std.reshape(-1)
+
+     # jvp_1, mean_1 = vmap(jacrev(contract_tuple, has_aux=True))(mean)
+     # std_1 = std * torch.linalg.det(jvp_1) ** (1 / mean.shape[-1])
+     #
+     # mean_2, jvp_2 = fn(mean)
+     # std_2 = std * torch.linalg.det(jvp_2) ** (1 / mean.shape[-1])
+     #
+     # mean_3, std_3 = contract_mean_std(mean, std)  # calculate det explicitly by using eigenvalues
+     # torch.allclose(std_1, std_3, atol=1e-7)  # True
+     # torch.allclose(mean_1, mean_3)  # True
+     mean, std = contract_mean_std(mean, std)  # calculate det explicitly by using eigenvalues
+
+     mean = mean.reshape(*pre_shape, 3)
+     std = std.reshape(*pre_shape)
+     return mean, std
+
+
+ def power_transformation(x, lam):
+     """Power transformation for Eq(4) in zip-nerf."""
+     lam_1 = np.abs(lam - 1)
+     return lam_1 / lam * ((x / lam_1 + 1) ** lam - 1)
+
+
+ def inv_power_transformation(x, lam):
+     """The inverse of power_transformation()."""
+     lam_1 = np.abs(lam - 1)
+     eps = torch.finfo(x.dtype).eps  # may cause inf
+     # eps = 1e-3
+     return ((x * lam / lam_1 + 1 + eps) ** (1 / lam) - 1) * lam_1
+
+
+ def construct_ray_warps(fn, t_near, t_far, lam=None):
+     """Construct a bijection between metric distances and normalized distances.
+
+     See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
+     detailed explanation.
+
+     Args:
+       fn: the function used to warp ray distances.
+       t_near: a tensor of near-plane distances.
+       t_far: a tensor of far-plane distances.
+       lam: the lambda for Eq(4) in zip-nerf.
+
+     Returns:
+       t_to_s: a function that maps distances to normalized distances in [0, 1].
+       s_to_t: the inverse of t_to_s.
+     """
+     if fn is None:
+         fn_fwd = lambda x: x
+         fn_inv = lambda x: x
+     elif fn == 'piecewise':
+         # Piecewise spacing combining identity and 1/x functions to allow t_near=0.
+         fn_fwd = lambda x: torch.where(x < 1, .5 * x, 1 - .5 / x)
+         fn_inv = lambda x: torch.where(x < .5, 2 * x, .5 / (1 - x))
+     elif fn == 'power_transformation':
+         fn_fwd = lambda x: power_transformation(x * 2, lam=lam)
+         fn_inv = lambda y: inv_power_transformation(y, lam=lam) / 2
+     else:
+         inv_mapping = {
+             'reciprocal': torch.reciprocal,
+             'log': torch.exp,
+             'exp': torch.log,
+             'sqrt': torch.square,
+             'square': torch.sqrt,
+         }
+         fn_fwd = fn
+         fn_inv = inv_mapping[fn.__name__]
+
+     s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)]
+     t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near)
+     s_to_t = lambda s: fn_inv(s * s_far + (1 - s) * s_near)
+     return t_to_s, s_to_t
+
+
+ def expected_sin(mean, var):
+     """Compute the mean of sin(x), x ~ N(mean, var)."""
+     return torch.exp(-0.5 * var) * math.safe_sin(mean)  # large var -> small value.
+
+
+ def integrated_pos_enc(mean, var, min_deg, max_deg):
+     """Encode `x` with sinusoids scaled by 2^[min_deg, max_deg).
+
+     Args:
+       mean: tensor, the mean coordinates to be encoded
+       var: tensor, the variance of the coordinates to be encoded.
+       min_deg: int, the min degree of the encoding.
+       max_deg: int, the max degree of the encoding.
+
+     Returns:
+       encoded: tensor, encoded variables.
+     """
+     scales = 2 ** torch.arange(min_deg, max_deg, device=mean.device)
+     shape = mean.shape[:-1] + (-1,)
+     scaled_mean = (mean[..., None, :] * scales[:, None]).reshape(*shape)
+     scaled_var = (var[..., None, :] * scales[:, None] ** 2).reshape(*shape)
+
+     return expected_sin(
+         torch.cat([scaled_mean, scaled_mean + 0.5 * torch.pi], dim=-1),
+         torch.cat([scaled_var] * 2, dim=-1))
+
+
+ def lift_and_diagonalize(mean, cov, basis):
+     """Project `mean` and `cov` onto basis and diagonalize the projected cov."""
+     fn_mean = math.matmul(mean, basis)
+     fn_cov_diag = torch.sum(basis * math.matmul(cov, basis), dim=-2)
+     return fn_mean, fn_cov_diag
+
+
+ def pos_enc(x, min_deg, max_deg, append_identity=True):
+     """The positional encoding used by the original NeRF paper."""
+     scales = 2 ** torch.arange(min_deg, max_deg, device=x.device)
+     shape = x.shape[:-1] + (-1,)
+     scaled_x = (x[..., None, :] * scales[:, None]).reshape(*shape)
+     # Note that we're not using safe_sin, unlike IPE.
+     four_feat = torch.sin(
+         torch.cat([scaled_x, scaled_x + 0.5 * torch.pi], dim=-1))
+     if append_identity:
+         return torch.cat([x] + [four_feat], dim=-1)
+     else:
+         return four_feat
@@ -0,0 +1,1016 @@
 
+ import abc
+ import copy
+ import json
+ import os
+ import cv2
+ from internal import camera_utils
+ from internal import configs
+ from internal import image as lib_image
+ from internal import raw_utils
+ from internal import utils
+ from collections import defaultdict
+ import numpy as np
+ from PIL import Image
+ import torch
+ from tqdm import tqdm
+ # This is ugly, but it works.
+ import sys
+
+ sys.path.insert(0, 'internal/pycolmap')
+ sys.path.insert(0, 'internal/pycolmap/pycolmap')
+ import pycolmap
+
+
+ def load_dataset(split, train_dir, config: configs.Config):
+     """Loads a split of a dataset using the data_loader specified by `config`."""
+     if config.multiscale:
+         dataset_dict = {
+             'llff': MultiLLFF,
+         }
+     else:
+         dataset_dict = {
+             'blender': Blender,
+             'llff': LLFF,
+             'tat_nerfpp': TanksAndTemplesNerfPP,
+             'tat_fvs': TanksAndTemplesFVS,
+             'dtu': DTU,
+         }
+     return dataset_dict[config.dataset_loader](split, train_dir, config)
+
+
+ class NeRFSceneManager(pycolmap.SceneManager):
+     """COLMAP pose loader.
+
+     Minor NeRF-specific extension to the third_party Python COLMAP loader:
+     google3/third_party/py/pycolmap/scene_manager.py
+     """
+
+     def process(self):
+         """Applies NeRF-specific postprocessing to the loaded pose data.
+
+         Returns:
+           a tuple [image_names, poses, pixtocam, distortion_params].
+           image_names: contains only the basenames of the images.
+           poses: [N, 4, 4] array containing the camera to world matrices.
+           pixtocam: [N, 3, 3] array containing the camera to pixel space matrices.
+           distortion_params: mapping of distortion param name to distortion
+             parameters. Cameras share intrinsics. Valid keys are k1, k2, p1 and p2.
+         """
+
+         self.load_cameras()
+         self.load_images()
+         # self.load_points3D()  # For now, we do not need the point cloud data.
+
+         # Assume shared intrinsics between all cameras.
+         cam = self.cameras[1]
+
+         # Extract focal lengths and principal point parameters.
+         fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
+         pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))
+
+         # Extract extrinsic matrices in world-to-camera format.
+         imdata = self.images
+         w2c_mats = []
+         bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
+         for k in imdata:
+             im = imdata[k]
+             rot = im.R()
+             trans = im.tvec.reshape(3, 1)
+             w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
+             w2c_mats.append(w2c)
+         w2c_mats = np.stack(w2c_mats, axis=0)
+
+         # Convert extrinsics to camera-to-world.
+         c2w_mats = np.linalg.inv(w2c_mats)
+         poses = c2w_mats[:, :3, :4]
+
+         # Image names from COLMAP. No need for permuting the poses according to
+         # image names anymore.
+         names = [imdata[k].name for k in imdata]
+
+         # Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame.
+         poses = poses @ np.diag([1, -1, -1, 1])
+
+         # Get distortion parameters.
+         type_ = cam.camera_type
+
+         if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
+             params = None
+             camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+         elif type_ == 1 or type_ == 'PINHOLE':
+             params = None
+             camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+         elif type_ == 2 or type_ == 'SIMPLE_RADIAL':
+             params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
+             params['k1'] = cam.k1
+             camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+         elif type_ == 3 or type_ == 'RADIAL':
+             params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
+             params['k1'] = cam.k1
+             params['k2'] = cam.k2
+             camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+         elif type_ == 4 or type_ == 'OPENCV':
+             params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
+             params['k1'] = cam.k1
+             params['k2'] = cam.k2
+             params['p1'] = cam.p1
+             params['p2'] = cam.p2
+             camtype = camera_utils.ProjectionType.PERSPECTIVE
+
+         elif type_ == 5 or type_ == 'OPENCV_FISHEYE':
+             params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']}
+             params['k1'] = cam.k1
+             params['k2'] = cam.k2
+             params['k3'] = cam.k3
+             params['k4'] = cam.k4
+             camtype = camera_utils.ProjectionType.FISHEYE
+
+         return names, poses, pixtocam, params, camtype
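# Editor's sketch (not part of the commit): typical use of the loader above,
# assuming a standard COLMAP layout where the cameras/images binaries live
# under <scene>/sparse/0 (the path is illustrative):
#     manager = NeRFSceneManager('data/360_v2/bicycle/sparse/0')
#     names, poses, pixtocam, params, camtype = manager.process()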
+
+
+ def load_blender_posedata(data_dir, split=None):
+     """Load poses from `transforms.json` file, as used in Blender/NGP datasets."""
+     suffix = '' if split is None else f'_{split}'
+     pose_file = os.path.join(data_dir, f'transforms{suffix}.json')
+     with utils.open_file(pose_file, 'r') as fp:
+         meta = json.load(fp)
+     names = []
+     poses = []
+     for _, frame in enumerate(meta['frames']):
+         filepath = os.path.join(data_dir, frame['file_path'])
+         if utils.file_exists(filepath):
+             names.append(frame['file_path'].split('/')[-1])
+             poses.append(np.array(frame['transform_matrix'], dtype=np.float32))
+     poses = np.stack(poses, axis=0)
+
+     w = meta['w']
+     h = meta['h']
+     cx = meta['cx'] if 'cx' in meta else w / 2.
+     cy = meta['cy'] if 'cy' in meta else h / 2.
+     if 'fl_x' in meta:
+         fx = meta['fl_x']
+     else:
+         fx = 0.5 * w / np.tan(0.5 * float(meta['camera_angle_x']))
+     if 'fl_y' in meta:
+         fy = meta['fl_y']
+     else:
+         fy = 0.5 * h / np.tan(0.5 * float(meta['camera_angle_y']))
+     pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))
+     coeffs = ['k1', 'k2', 'p1', 'p2']
+     if not any([c in meta for c in coeffs]):
+         params = None
+     else:
+         params = {c: (meta[c] if c in meta else 0.) for c in coeffs}
+     camtype = camera_utils.ProjectionType.PERSPECTIVE
+     return names, poses, pixtocam, params, camtype
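# Editor's note (not part of the commit): the focal-length fallback above is
# the standard pinhole relation fl = 0.5 * w / tan(0.5 * fov). For example, an
# 800px-wide NeRF-synthetic render with camera_angle_x ~= 0.6911 gives
# fx = 0.5 * 800 / np.tan(0.5 * 0.6911) ~= 1111.1.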
171
+
172
+
173
+ class Dataset(torch.utils.data.Dataset):
174
+ """Dataset Base Class.
175
+
176
+ Base class for a NeRF dataset. Creates batches of ray and color data used for
177
+ training or rendering a NeRF model.
178
+
179
+ Each subclass is responsible for loading images and camera poses from disk by
180
+ implementing the _load_renderings() method. This data is used to generate
181
+ train and test batches of ray + color data for feeding through the NeRF model.
182
+ The ray parameters are calculated in _generate_rays().
183
+
184
+ The public interface mimics the behavior of a standard machine learning
185
+ pipeline dataset provider that can provide infinite batches of data to the
186
+ training/testing pipelines without exposing any details of how the batches are
187
+ loaded/created or how this is parallelized. Therefore, the initializer runs
188
+ all setup, including data loading from disk using _load_renderings(), and
189
+ begins the thread using its parent start() method. After the initializer
190
+ returns, the caller can request batches of data straight away.
191
+
192
+ The internal self._queue is initialized as queue.Queue(3), so the infinite
193
+ loop in run() will block on the call self._queue.put(self._next_fn()) once
194
+ there are 3 elements. The main thread training job runs in a loop that pops 1
195
+ element at a time off the front of the queue. The Dataset thread's run() loop
196
+ will populate the queue with 3 elements, then wait until a batch has been
197
+ removed and push one more onto the end.
198
+
199
+ This repeats indefinitely until the main thread's training loop completes
200
+ (typically hundreds of thousands of iterations), then the main thread will
201
+ exit and the Dataset thread will automatically be killed since it is a daemon.
202
+
203
+ Attributes:
204
+ alphas: np.ndarray, optional array of alpha channel data.
205
+ cameras: tuple summarizing all camera extrinsic/intrinsic/distortion params.
206
+ camtoworlds: np.ndarray, a list of extrinsic camera pose matrices.
207
+ camtype: camera_utils.ProjectionType, fisheye or perspective camera.
208
+ data_dir: str, location of the dataset on disk.
209
+ disp_images: np.ndarray, optional array of disparity (inverse depth) data.
210
+ distortion_params: dict, the camera distortion model parameters.
211
+ exposures: optional per-image exposure value (shutter * ISO / 1000).
212
+ far: float, far plane value for rays.
213
+ focal: float, focal length from camera intrinsics.
214
+ height: int, height of images.
215
+ images: np.ndarray, array of RGB image data.
216
+ metadata: dict, optional metadata for raw datasets.
217
+ near: float, near plane value for rays.
218
+ normal_images: np.ndarray, optional array of surface normal vector data.
219
+ pixtocams: np.ndarray, one or a list of inverse intrinsic camera matrices.
220
+ pixtocam_ndc: np.ndarray, the inverse intrinsic matrix used for NDC space.
221
+ poses: np.ndarray, optional array of auxiliary camera pose data.
222
+ rays: utils.Rays, ray data for every pixel in the dataset.
223
+ render_exposures: optional list of exposure values for the render path.
224
+ render_path: bool, indicates if a smooth camera path should be generated.
225
+ size: int, number of images in the dataset.
226
+ split: str, indicates if this is a "train" or "test" dataset.
227
+ width: int, width of images.
228
+ """
229
+
230
+ def __init__(self,
231
+ split: str,
232
+ data_dir: str,
233
+ config: configs.Config):
234
+ super().__init__()
235
+
236
+ # Initialize attributes
237
+ self._patch_size = max(config.patch_size, 1)
238
+ self._batch_size = config.batch_size // config.world_size
239
+ if self._patch_size ** 2 > self._batch_size:
240
+ raise ValueError(f'Patch size {self._patch_size}^2 too large for ' +
241
+ f'per-process batch size {self._batch_size}')
242
+ self._batching = utils.BatchingMethod(config.batching)
243
+ self._use_tiffs = config.use_tiffs
244
+ self._load_disps = config.compute_disp_metrics
245
+ self._load_normals = config.compute_normal_metrics
246
+ self._num_border_pixels_to_mask = config.num_border_pixels_to_mask
247
+ self._apply_bayer_mask = config.apply_bayer_mask
248
+ self._render_spherical = False
249
+
250
+ self.config = config
251
+ self.global_rank = config.global_rank
252
+ self.world_size = config.world_size
253
+ self.split = utils.DataSplit(split)
254
+ self.data_dir = data_dir
255
+ self.near = config.near
256
+ self.far = config.far
257
+ self.render_path = config.render_path
258
+ self.distortion_params = None
259
+ self.disp_images = None
260
+ self.normal_images = None
261
+ self.alphas = None
262
+ self.poses = None
263
+ self.pixtocam_ndc = None
264
+ self.metadata = None
265
+ self.camtype = camera_utils.ProjectionType.PERSPECTIVE
266
+ self.exposures = None
267
+ self.render_exposures = None
268
+
269
+ # Type annotations for these attributes; they must be correctly
270
+ # initialized by _load_renderings() (see docstring) in any subclass.
271
+ self.images: np.ndarray = None
272
+ self.camtoworlds: np.ndarray = None
273
+ self.pixtocams: np.ndarray = None
274
+ self.height: int = None
275
+ self.width: int = None
276
+
277
+ # Load data from disk using provided config parameters.
278
+ self._load_renderings(config)
279
+
280
+ if self.render_path:
281
+ if config.render_path_file is not None:
282
+ with utils.open_file(config.render_path_file, 'rb') as fp:
283
+ render_poses = np.load(fp)
284
+ self.camtoworlds = render_poses
285
+ if config.render_resolution is not None:
286
+ self.width, self.height = config.render_resolution
287
+ if config.render_focal is not None:
288
+ self.focal = config.render_focal
289
+ if config.render_camtype is not None:
290
+ if config.render_camtype == 'pano':
291
+ self._render_spherical = True
292
+ else:
293
+ self.camtype = camera_utils.ProjectionType(config.render_camtype)
294
+
295
+ self.distortion_params = None
296
+ self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
297
+ self.height)
298
+
299
+ self._n_examples = self.camtoworlds.shape[0]
300
+
301
+ self.cameras = (self.pixtocams,
302
+ self.camtoworlds,
303
+ self.distortion_params,
304
+ self.pixtocam_ndc)
305
+
306
+ # Select the batch-sampling function based on the split.
307
+ if self.split == utils.DataSplit.TRAIN and not config.compute_visibility:
308
+ self._next_fn = self._next_train
309
+ else:
310
+ self._next_fn = self._next_test
311
+
312
+ @property
313
+ def size(self):
314
+ return self._n_examples
315
+
316
+ def __len__(self):
317
+ if self.split == utils.DataSplit.TRAIN and not self.config.compute_visibility:
318
+ return 1000  # Arbitrary epoch length; train batches are sampled at random.
319
+ else:
320
+ return self._n_examples
321
+
322
+ @abc.abstractmethod
323
+ def _load_renderings(self, config):
324
+ """Load images and poses from disk.
325
+
326
+ Args:
327
+ config: utils.Config, user-specified config parameters.
328
+ In inherited classes, this method must set the following public attributes:
329
+ images: [N, height, width, 3] array for RGB images.
330
+ disp_images: [N, height, width] array for depth data (optional).
331
+ normal_images: [N, height, width, 3] array for normals (optional).
332
+ camtoworlds: [N, 3, 4] array of extrinsic pose matrices.
333
+ poses: [..., 3, 4] array of auxiliary pose data (optional).
334
+ pixtocams: [N, 3, 4] array of inverse intrinsic matrices.
335
+ distortion_params: dict, camera lens distortion model parameters.
336
+ height: int, height of images.
337
+ width: int, width of images.
338
+ focal: float, focal length to use for ideal pinhole rendering.
339
+ """
340
+
341
+ def _make_ray_batch(self,
342
+ pix_x_int,
343
+ pix_y_int,
344
+ cam_idx,
345
+ lossmult=None
346
+ ):
347
+ """Creates ray data batch from pixel coordinates and camera indices.
348
+
349
+ All arguments must have broadcastable shapes. If the arguments together
350
+ broadcast to a shape [a, b, c, ..., z] then the returned utils.Rays object
351
+ will have array attributes with shape [a, b, c, ..., z, N], where N=3 for
352
+ 3D vectors and N=1 for per-ray scalar attributes.
353
+
354
+ Args:
355
+ pix_x_int: int array, x coordinates of image pixels.
356
+ pix_y_int: int array, y coordinates of image pixels.
357
+ cam_idx: int or int array, camera indices.
358
+ lossmult: float array, weight to apply to each ray when computing loss fn.
359
+
360
+ Returns:
361
+ A dict mapping from strings to utils.Rays or arrays of image data.
362
+ This is the batch provided for one NeRF train or test iteration.
363
+ """
364
+
365
+ broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
366
+ ray_kwargs = {
367
+ 'lossmult': broadcast_scalar(1.) if lossmult is None else lossmult,
368
+ 'near': broadcast_scalar(self.near),
369
+ 'far': broadcast_scalar(self.far),
370
+ 'cam_idx': broadcast_scalar(cam_idx),
371
+ }
372
+ # Collect per-camera information needed for each ray.
373
+ if self.metadata is not None:
374
+ # Exposure index and relative shutter speed, needed for RawNeRF.
375
+ for key in ['exposure_idx', 'exposure_values']:
376
+ idx = 0 if self.render_path else cam_idx
377
+ ray_kwargs[key] = broadcast_scalar(self.metadata[key][idx])
378
+ if self.exposures is not None:
379
+ idx = 0 if self.render_path else cam_idx
380
+ ray_kwargs['exposure_values'] = broadcast_scalar(self.exposures[idx])
381
+ if self.render_path and self.render_exposures is not None:
382
+ ray_kwargs['exposure_values'] = broadcast_scalar(
383
+ self.render_exposures[cam_idx])
384
+
385
+ pixels = dict(pix_x_int=pix_x_int, pix_y_int=pix_y_int, **ray_kwargs)
386
+
387
+ # Slow path, do ray computation using numpy (on CPU).
388
+ batch = camera_utils.cast_ray_batch(self.cameras, pixels, self.camtype)
389
+ batch['cam_dirs'] = -self.camtoworlds[ray_kwargs['cam_idx'][..., 0]][..., :3, 2]
390
+
391
+ if not self.render_path:
399
+ batch['rgb'] = self.images[cam_idx, pix_y_int, pix_x_int]
400
+ if self._load_disps:
401
+ batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int]
402
+ if self._load_normals:
403
+ batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int]
404
+ batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int]
405
+ return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}
406
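+ # Shape example (a sketch): with pix_x_int/pix_y_int of shape [P, S, S] and a
+ # scalar cam_idx, the returned dict holds e.g. 'origins'/'directions' of
+ # shape [P, S, S, 3] and per-ray scalars such as 'near'/'far' of shape
+ # [P, S, S, 1].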
+
407
+ def _next_train(self, item):
408
+ """Sample next training batch (random rays)."""
409
+ # We assume all images in the dataset are the same resolution, so we can use
410
+ # the same width/height for sampling all pixel coordinates in the batch.
411
+ # Batch/patch sampling parameters.
412
+ num_patches = self._batch_size // self._patch_size ** 2
413
+ lower_border = self._num_border_pixels_to_mask
414
+ upper_border = self._num_border_pixels_to_mask + self._patch_size - 1
415
+ # Random pixel patch x-coordinates.
416
+ pix_x_int = np.random.randint(lower_border, self.width - upper_border,
417
+ (num_patches, 1, 1))
418
+ # Random pixel patch y-coordinates.
419
+ pix_y_int = np.random.randint(lower_border, self.height - upper_border,
420
+ (num_patches, 1, 1))
421
+ # Add patch coordinate offsets.
422
+ # Shape will broadcast to (num_patches, _patch_size, _patch_size).
423
+ patch_dx_int, patch_dy_int = camera_utils.pixel_coordinates(
424
+ self._patch_size, self._patch_size)
425
+ pix_x_int = pix_x_int + patch_dx_int
426
+ pix_y_int = pix_y_int + patch_dy_int
427
+ # Random camera indices.
428
+ if self._batching == utils.BatchingMethod.ALL_IMAGES:
429
+ cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1))
430
+ else:
431
+ cam_idx = np.random.randint(0, self._n_examples, (1,))
432
+
433
+ if self._apply_bayer_mask:
434
+ # Compute the Bayer mosaic mask for each pixel in the batch.
435
+ lossmult = raw_utils.pixels_to_bayer_mask(pix_x_int, pix_y_int)
436
+ else:
437
+ lossmult = None
438
+
439
+ return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx,
440
+ lossmult=lossmult)
441
+
442
+ def generate_ray_batch(self, cam_idx: int):
443
+ """Generate ray batch for a specified camera in the dataset."""
444
+ if self._render_spherical:
445
+ camtoworld = self.camtoworlds[cam_idx]
446
+ rays = camera_utils.cast_spherical_rays(
447
+ camtoworld, self.height, self.width, self.near, self.far)
448
+ return rays
449
+ else:
450
+ # Generate rays for all pixels in the image.
451
+ pix_x_int, pix_y_int = camera_utils.pixel_coordinates(
452
+ self.width, self.height)
453
+ return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx)
454
+
455
+ def _next_test(self, item):
456
+ """Sample next test batch (one full image)."""
457
+ return self.generate_ray_batch(item)
458
+
459
+ def collate_fn(self, item):
460
+ return self._next_fn(item[0])
461
+
462
+ def __getitem__(self, item):
463
+ return self._next_fn(item)
464
+
465
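+ # Usage sketch (hypothetical; the exact wiring lives in the training script):
+ # because collate_fn() maps an index to a full ray batch, the dataset can be
+ # driven by a DataLoader over raw indices, e.g.:
+ #
+ #   loader = torch.utils.data.DataLoader(
+ #       np.arange(len(dataset)), shuffle=True, batch_size=1,
+ #       num_workers=4, collate_fn=dataset.collate_fn)
+ #   for batch in loader:  # `batch` is a dict of per-ray tensors
+ #       ...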
+
466
+ class Blender(Dataset):
467
+ """Blender Dataset."""
468
+
469
+ def _load_renderings(self, config):
470
+ """Load images from disk."""
471
+ if config.render_path:
472
+ raise ValueError('render_path cannot be used for the blender dataset.')
473
+ pose_file = os.path.join(self.data_dir, f'transforms_{self.split.value}.json')
474
+ with utils.open_file(pose_file, 'r') as fp:
475
+ meta = json.load(fp)
476
+ images = []
477
+ disp_images = []
478
+ normal_images = []
479
+ cams = []
480
+ for idx, frame in enumerate(tqdm(meta['frames'], desc='Loading Blender dataset', disable=self.global_rank != 0, leave=False)):
481
+ fprefix = os.path.join(self.data_dir, frame['file_path'])
482
+
483
+ def get_img(f, fprefix=fprefix):
484
+ image = utils.load_img(fprefix + f)
485
+ if config.factor > 1:
486
+ image = lib_image.downsample(image, config.factor)
487
+ return image
488
+
489
+ if self._use_tiffs:
490
+ channels = [get_img(f'_{ch}.tiff') for ch in ['R', 'G', 'B', 'A']]
491
+ # Convert image to sRGB color space.
492
+ image = lib_image.linear_to_srgb_np(np.stack(channels, axis=-1))
493
+ else:
494
+ image = get_img('.png') / 255.
495
+ images.append(image)
496
+
497
+ if self._load_disps:
498
+ disp_image = get_img('_disp.tiff')
499
+ disp_images.append(disp_image)
500
+ if self._load_normals:
501
+ normal_image = get_img('_normal.png')[..., :3] * 2. / 255. - 1.
502
+ normal_images.append(normal_image)
503
+
504
+ cams.append(np.array(frame['transform_matrix'], dtype=np.float32))
505
+
506
+ self.images = np.stack(images, axis=0)
507
+ if self._load_disps:
508
+ self.disp_images = np.stack(disp_images, axis=0)
509
+ if self._load_normals:
510
+ self.normal_images = np.stack(normal_images, axis=0)
511
+ self.alphas = self.images[..., -1]
512
+
513
+ rgb, alpha = self.images[..., :3], self.images[..., -1:]
514
+ self.images = rgb * alpha + (1. - alpha) # Use a white background.
515
+ self.height, self.width = self.images.shape[1:3]
516
+ self.camtoworlds = np.stack(cams, axis=0)
517
+ self.focal = .5 * self.width / np.tan(.5 * float(meta['camera_angle_x']))
518
+ self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
519
+ self.height)
520
+
521
+
522
+ class LLFF(Dataset):
523
+ """LLFF Dataset."""
524
+
525
+ def _load_renderings(self, config):
526
+ """Load images from disk."""
527
+ # Set up scaling factor.
528
+ image_dir_suffix = ''
529
+ # Use downsampling factor (unless loading training split for raw dataset,
530
+ # we train raw at full resolution because of the Bayer mosaic pattern).
531
+ if config.factor > 0 and not (config.rawnerf_mode and
532
+ self.split == utils.DataSplit.TRAIN):
533
+ image_dir_suffix = f'_{config.factor}'
534
+ factor = config.factor
535
+ else:
536
+ factor = 1
537
+
538
+ # Path to the COLMAP sparse reconstruction.
539
+ colmap_dir = os.path.join(self.data_dir, 'sparse/0/')
540
+
541
+ # Load poses.
542
+ if utils.file_exists(colmap_dir):
543
+ pose_data = NeRFSceneManager(colmap_dir).process()
544
+ else:
545
+ # # Attempt to load Blender/NGP format if COLMAP data not present.
546
+ # pose_data = load_blender_posedata(self.data_dir)
547
+ raise ValueError('COLMAP data not found.')
548
+ image_names, poses, pixtocam, distortion_params, camtype = pose_data
549
+
550
+ # Previous NeRF results were generated with images sorted by filename;
+ # sort here so that metrics are reported on the same test set.
552
+ inds = np.argsort(image_names)
553
+ image_names = [image_names[i] for i in inds]
554
+ poses = poses[inds]
555
+
556
+ # Load bounds if possible (only used in forward facing scenes).
557
+ posefile = os.path.join(self.data_dir, 'poses_bounds.npy')
558
+ if utils.file_exists(posefile):
559
+ with utils.open_file(posefile, 'rb') as fp:
560
+ poses_arr = np.load(fp)
561
+ bounds = poses_arr[:, -2:]
562
+ else:
563
+ bounds = np.array([0.01, 1.])
564
+ self.colmap_to_world_transform = np.eye(4)
565
+
566
+ # Scale the inverse intrinsics matrix by the image downsampling factor.
567
+ pixtocam = pixtocam @ np.diag([factor, factor, 1.])
568
+ self.pixtocams = pixtocam.astype(np.float32)
569
+ self.focal = 1. / self.pixtocams[0, 0]
570
+ self.distortion_params = distortion_params
571
+ self.camtype = camtype
572
+
573
+ # Separate out 360 versus forward facing scenes.
574
+ if config.forward_facing:
575
+ # Set the projective matrix defining the NDC transformation.
576
+ self.pixtocam_ndc = self.pixtocams.reshape(-1, 3, 3)[0]
577
+ # Rescale according to a default bd factor.
578
+ scale = 1. / (bounds.min() * .75)
579
+ poses[:, :3, 3] *= scale
580
+ self.colmap_to_world_transform = np.diag([scale] * 3 + [1])
581
+ bounds *= scale
582
+ # Recenter poses.
583
+ poses, transform = camera_utils.recenter_poses(poses)
584
+ self.colmap_to_world_transform = (
585
+ transform @ self.colmap_to_world_transform)
586
+ # Forward-facing spiral render path.
587
+ self.render_poses = camera_utils.generate_spiral_path(
588
+ poses, bounds, n_frames=config.render_path_frames)
589
+ else:
590
+ # Rotate/scale poses to align ground with xy plane and fit to unit cube.
591
+ poses, transform = camera_utils.transform_poses_pca(poses)
592
+ self.colmap_to_world_transform = transform
593
+ if config.render_spline_keyframes is not None:
594
+ rets = camera_utils.create_render_spline_path(config, image_names,
595
+ poses, self.exposures)
596
+ self.spline_indices, self.render_poses, self.render_exposures = rets
597
+ else:
598
+ # Automatically generated inward-facing elliptical render path.
599
+ self.render_poses = camera_utils.generate_ellipse_path(
600
+ poses,
601
+ n_frames=config.render_path_frames,
602
+ z_variation=config.z_variation,
603
+ z_phase=config.z_phase)
604
+
605
+ # Select the split.
606
+ all_indices = np.arange(len(image_names))
607
+ if config.llff_use_all_images_for_training:
608
+ train_indices = all_indices
609
+ else:
610
+ train_indices = all_indices % config.llffhold != 0
611
+ if config.llff_use_all_images_for_testing:
612
+ test_indices = all_indices
613
+ else:
614
+ test_indices = all_indices % config.llffhold == 0
615
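+ # E.g. with llffhold=8, images 0, 8, 16, ... form the test split and the
+ # remaining images form the train split.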
+ split_indices = {
616
+ utils.DataSplit.TEST: all_indices[test_indices],
617
+ utils.DataSplit.TRAIN: all_indices[train_indices],
618
+ }
619
+ indices = split_indices[self.split]
620
+ image_names = [image_names[i] for i in indices]
621
+ poses = poses[indices]
622
+ # if self.split == utils.DataSplit.TRAIN:
623
+ # # load different training data on different rank
624
+ # local_indices = [i for i in range(len(image_names)) if (i + self.global_rank) % self.world_size == 0]
625
+ # image_names = [image_names[i] for i in local_indices]
626
+ # poses = poses[local_indices]
627
+ # indices = local_indices
628
+
629
+ raw_testscene = False
630
+ if config.rawnerf_mode:
631
+ # Load raw images and metadata.
632
+ images, metadata, raw_testscene = raw_utils.load_raw_dataset(
633
+ self.split,
634
+ self.data_dir,
635
+ image_names,
636
+ config.exposure_percentile,
637
+ factor)
638
+ self.metadata = metadata
639
+
640
+ else:
641
+ # Load images.
642
+ colmap_image_dir = os.path.join(self.data_dir, 'images')
643
+ image_dir = os.path.join(self.data_dir, 'images' + image_dir_suffix)
644
+ for d in [image_dir, colmap_image_dir]:
645
+ if not utils.file_exists(d):
646
+ raise ValueError(f'Image folder {d} does not exist.')
647
+ # Downsampled images may have different names vs images used for COLMAP,
648
+ # so we need to map between the two sorted lists of files.
649
+ colmap_files = sorted(utils.listdir(colmap_image_dir))
650
+ image_files = sorted(utils.listdir(image_dir))
651
+ colmap_to_image = dict(zip(colmap_files, image_files))
652
+ image_paths = [os.path.join(image_dir, colmap_to_image[f])
653
+ for f in image_names]
654
+ images = [utils.load_img(x) for x in tqdm(image_paths, desc='Loading LLFF dataset', disable=self.global_rank != 0, leave=False)]
655
+ images = np.stack(images, axis=0) / 255.
656
+
657
+ # EXIF data is usually only present in the original JPEG images.
658
+ jpeg_paths = [os.path.join(colmap_image_dir, f) for f in image_names]
659
+ exifs = [utils.load_exif(x) for x in jpeg_paths]
660
+ self.exifs = exifs
661
+ if 'ExposureTime' in exifs[0] and 'ISOSpeedRatings' in exifs[0]:
662
+ gather_exif_value = lambda k: np.array([float(x[k]) for x in exifs])
663
+ shutters = gather_exif_value('ExposureTime')
664
+ isos = gather_exif_value('ISOSpeedRatings')
665
+ self.exposures = shutters * isos / 1000.
666
+
667
+ if raw_testscene:
668
+ # For raw testscene, the first image sent to COLMAP has the same pose as
669
+ # the ground truth test image. The remaining images form the training set.
670
+ raw_testscene_poses = {
671
+ utils.DataSplit.TEST: poses[:1],
672
+ utils.DataSplit.TRAIN: poses[1:],
673
+ }
674
+ poses = raw_testscene_poses[self.split]
675
+
676
+ self.poses = poses
677
+ self.images = images
678
+ self.camtoworlds = self.render_poses if config.render_path else poses
679
+ self.height, self.width = images.shape[1:3]
680
+
681
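+ # Expected on-disk layout for this loader (inferred from the paths read
+ # above; `images_4/` is illustrative for factor=4):
+ #
+ #   <data_dir>/
+ #     images/            # full-resolution images used by COLMAP
+ #     images_4/          # optional downsampled copies (suffix = factor)
+ #     sparse/0/          # COLMAP reconstruction
+ #     poses_bounds.npy   # optional bounds for forward-facing scenes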
+
682
+ class TanksAndTemplesNerfPP(Dataset):
683
+ """Subset of Tanks and Temples Dataset as processed by NeRF++."""
684
+
685
+ def _load_renderings(self, config):
686
+ """Load images from disk."""
687
+ if config.render_path:
688
+ split_str = 'camera_path'
689
+ else:
690
+ split_str = self.split.value
691
+
692
+ basedir = os.path.join(self.data_dir, split_str)
693
+
694
+ # TODO: need to rewrite this to put different data on different rank
695
+ def load_files(dirname, load_fn, shape=None):
696
+ files = [
697
+ os.path.join(basedir, dirname, f)
698
+ for f in sorted(utils.listdir(os.path.join(basedir, dirname)))
699
+ ]
700
+ mats = np.array([load_fn(utils.open_file(f, 'rb')) for f in files])
701
+ if shape is not None:
702
+ mats = mats.reshape(mats.shape[:1] + shape)
703
+ return mats
704
+
705
+ poses = load_files('pose', np.loadtxt, (4, 4))
706
+ # Flip Y and Z axes to get correct coordinate frame.
707
+ poses = np.matmul(poses, np.diag(np.array([1, -1, -1, 1])))
708
+
709
+ # For now, ignore all but the first focal length in intrinsics
710
+ intrinsics = load_files('intrinsics', np.loadtxt, (4, 4))
711
+
712
+ if not config.render_path:
713
+ images = load_files('rgb', lambda f: np.array(Image.open(f))) / 255.
714
+ self.images = images
715
+ self.height, self.width = self.images.shape[1:3]
716
+
717
+ else:
718
+ # Hack to grab the image resolution from a test image
719
+ d = os.path.join(self.data_dir, 'test', 'rgb')
720
+ f = os.path.join(d, sorted(utils.listdir(d))[0])
721
+ shape = utils.load_img(f).shape
722
+ self.height, self.width = shape[:2]
723
+ self.images = None
724
+
725
+ self.camtoworlds = poses
726
+ self.focal = intrinsics[0, 0, 0]
727
+ self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
728
+ self.height)
729
+
730
+
731
+ class TanksAndTemplesFVS(Dataset):
732
+ """Subset of Tanks and Temples Dataset as processed by Free View Synthesis."""
733
+
734
+ def _load_renderings(self, config):
735
+ """Load images from disk."""
736
+ render_only = config.render_path and self.split == utils.DataSplit.TEST
737
+
738
+ basedir = os.path.join(self.data_dir, 'dense')
739
+ sizes = [f for f in sorted(utils.listdir(basedir)) if f.startswith('ibr3d')]
740
+ sizes = sizes[::-1]
741
+
742
+ if config.factor >= len(sizes):
743
+ raise ValueError(f'Factor {config.factor} must be less than {len(sizes)}')
744
+
745
+ basedir = os.path.join(basedir, sizes[config.factor])
746
+ open_fn = lambda f: utils.open_file(os.path.join(basedir, f), 'rb')
747
+
748
+ files = [f for f in sorted(utils.listdir(basedir)) if f.startswith('im_')]
749
+ if render_only:
750
+ files = files[:1]
751
+ images = np.array([np.array(Image.open(open_fn(f))) for f in files]) / 255.
752
+
753
+ names = ['Ks', 'Rs', 'ts']
754
+ intrinsics, rot, trans = (np.load(open_fn(f'{n}.npy')) for n in names)
755
+
756
+ # Convert poses from colmap world-to-cam into our cam-to-world.
757
+ w2c = np.concatenate([rot, trans[..., None]], axis=-1)
758
+ c2w_colmap = np.linalg.inv(camera_utils.pad_poses(w2c))[:, :3, :4]
759
+ c2w = c2w_colmap @ np.diag(np.array([1, -1, -1, 1]))
760
+
761
+ # Reorient poses so z-axis is up
762
+ poses, _ = camera_utils.transform_poses_pca(c2w)
763
+ self.poses = poses
764
+
765
+ self.images = images
766
+ self.height, self.width = self.images.shape[1:3]
767
+ self.camtoworlds = poses
768
+ # For now, ignore all but the first focal length in intrinsics
769
+ self.focal = intrinsics[0, 0, 0]
770
+ self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
771
+ self.height)
772
+
773
+ if render_only:
774
+ render_path = camera_utils.generate_ellipse_path(
775
+ poses,
776
+ config.render_path_frames,
777
+ z_variation=config.z_variation,
778
+ z_phase=config.z_phase)
779
+ self.images = None
780
+ self.camtoworlds = render_path
781
+ self.render_poses = render_path
782
+ else:
783
+ # Select the split.
784
+ all_indices = np.arange(images.shape[0])
785
+ indices = {
786
+ utils.DataSplit.TEST:
787
+ all_indices[all_indices % config.llffhold == 0],
788
+ utils.DataSplit.TRAIN:
789
+ all_indices[all_indices % config.llffhold != 0],
790
+ }[self.split]
791
+
792
+ self.images = self.images[indices]
793
+ self.camtoworlds = self.camtoworlds[indices]
794
+
795
+
796
+ class DTU(Dataset):
797
+ """DTU Dataset."""
798
+
799
+ def _load_renderings(self, config):
800
+ """Load images from disk."""
801
+ if config.render_path:
802
+ raise ValueError('render_path cannot be used for the DTU dataset.')
803
+
804
+ images = []
805
+ pixtocams = []
806
+ camtoworlds = []
807
+
808
+ # Find out whether the particular scan has 49 or 65 images.
809
+ n_images = len(utils.listdir(self.data_dir)) // 8
810
+
811
+ # Loop over all images.
812
+ for i in range(1, n_images + 1):
813
+ # Set light condition string accordingly.
814
+ if config.dtu_light_cond < 7:
815
+ light_str = f'{config.dtu_light_cond}_r' + ('5000'
816
+ if i < 50 else '7000')
817
+ else:
818
+ light_str = 'max'
819
+
820
+ # Load image.
821
+ fname = os.path.join(self.data_dir, f'rect_{i:03d}_{light_str}.png')
822
+ image = utils.load_img(fname) / 255.
823
+ if config.factor > 1:
824
+ image = lib_image.downsample(image, config.factor)
825
+ images.append(image)
826
+
827
+ # Load projection matrix from file.
828
+ fname = os.path.join(self.data_dir, f'../../cal18/pos_{i:03d}.txt')
829
+ with utils.open_file(fname, 'rb') as f:
830
+ projection = np.loadtxt(f, dtype=np.float32)
831
+
832
+ # Decompose projection matrix into pose and camera matrix.
833
+ camera_mat, rot_mat, t = cv2.decomposeProjectionMatrix(projection)[:3]
834
+ camera_mat = camera_mat / camera_mat[2, 2]
835
+ pose = np.eye(4, dtype=np.float32)
836
+ pose[:3, :3] = rot_mat.transpose()
837
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
838
+ pose = pose[:3]
839
+ camtoworlds.append(pose)
840
+
841
+ if config.factor > 0:
842
+ # Scale camera matrix according to downsampling factor.
843
+ camera_mat = np.diag([1. / config.factor, 1. / config.factor, 1.
844
+ ]).astype(np.float32) @ camera_mat
845
+ pixtocams.append(np.linalg.inv(camera_mat))
846
+
847
+ pixtocams = np.stack(pixtocams)
848
+ camtoworlds = np.stack(camtoworlds)
849
+ images = np.stack(images)
850
+
851
+ def rescale_poses(poses):
852
+ """Rescales camera poses according to maximum x/y/z value."""
853
+ s = np.max(np.abs(poses[:, :3, -1]))
854
+ out = np.copy(poses)
855
+ out[:, :3, -1] /= s
856
+ return out
857
+
858
+ # Center and scale poses.
859
+ camtoworlds, _ = camera_utils.recenter_poses(camtoworlds)
860
+ camtoworlds = rescale_poses(camtoworlds)
861
+ # Flip y and z axes to get poses in OpenGL coordinate system.
862
+ camtoworlds = camtoworlds @ np.diag([1., -1., -1., 1.]).astype(np.float32)
863
+
864
+ all_indices = np.arange(images.shape[0])
865
+ split_indices = {
866
+ utils.DataSplit.TEST: all_indices[all_indices % config.dtuhold == 0],
867
+ utils.DataSplit.TRAIN: all_indices[all_indices % config.dtuhold != 0],
868
+ }
869
+ indices = split_indices[self.split]
870
+
871
+ self.images = images[indices]
872
+ self.height, self.width = images.shape[1:3]
873
+ self.camtoworlds = camtoworlds[indices]
874
+ self.pixtocams = pixtocams[indices]
875
+
876
+
877
+ class Multicam(Dataset):
878
+ def __init__(self,
879
+ split: str,
880
+ data_dir: str,
881
+ config: configs.Config):
882
+ super().__init__(split, data_dir, config)
883
+
884
+ self.multiscale_levels = config.multiscale_levels
885
+
886
+ images, camtoworlds, pixtocams, pixtocam_ndc = \
887
+ self.images, self.camtoworlds, self.pixtocams, self.pixtocam_ndc
888
+ self.heights, self.widths, self.focals, self.images, self.camtoworlds, self.pixtocams, self.lossmults = [], [], [], [], [], [], []
889
+ if pixtocam_ndc is not None:
890
+ self.pixtocam_ndc = []
891
+ else:
892
+ self.pixtocam_ndc = None
893
+
894
+ for i in range(self._n_examples):
895
+ for j in range(self.multiscale_levels):
896
+ self.heights.append(self.height // 2 ** j)
897
+ self.widths.append(self.width // 2 ** j)
898
+
899
+ self.pixtocams.append(pixtocams @ np.diag([self.height / self.heights[-1],
900
+ self.width / self.widths[-1],
901
+ 1.]))
902
+ self.focals.append(1. / self.pixtocams[-1][0, 0])
903
+ if config.forward_facing:
904
+ # Set the projective matrix defining the NDC transformation.
905
+ self.pixtocam_ndc.append(pixtocams.reshape(3, 3))
906
+
907
+ self.camtoworlds.append(camtoworlds[i])
908
+ self.lossmults.append(2. ** j)
909
+ self.images.append(self.down2(images[i], (self.heights[-1], self.widths[-1])))
910
+ self.pixtocams = np.stack(self.pixtocams)
911
+ self.camtoworlds = np.stack(self.camtoworlds)
912
+ self.cameras = (self.pixtocams,
913
+ self.camtoworlds,
914
+ self.distortion_params,
915
+ np.stack(self.pixtocam_ndc) if self.pixtocam_ndc is not None else None)
916
+ self._generate_rays()
917
+
918
+ if self.split == utils.DataSplit.TRAIN:
919
+ # Always flatten out the height x width dimensions
920
+ def flatten(x):
921
+ if x[0] is not None:
922
+ x = [y.reshape([-1, y.shape[-1]]) for y in x]
923
+ if self._batching == utils.BatchingMethod.ALL_IMAGES:
924
+ # If global batching, also concatenate all data into one list
925
+ x = np.concatenate(x, axis=0)
926
+ return x
927
+ else:
928
+ return None
929
+
930
+ self.batches = {k: flatten(v) for k, v in self.batches.items()}
931
+ self._n_examples = len(self.camtoworlds)
932
+
933
+ # Seed the queue with one batch to avoid race condition.
934
+ if self.split == utils.DataSplit.TRAIN:
935
+ self._next_fn = self._next_train
936
+ else:
937
+ self._next_fn = self._next_test
938
+
939
+ def _generate_rays(self):
940
+ if self.global_rank == 0:
941
+ tbar = tqdm(range(len(self.camtoworlds)), desc='Generating rays', leave=False)
942
+ else:
943
+ tbar = range(len(self.camtoworlds))
944
+
945
+ self.batches = defaultdict(list)
946
+ for cam_idx in tbar:
947
+ pix_x_int, pix_y_int = camera_utils.pixel_coordinates(
948
+ self.widths[cam_idx], self.heights[cam_idx])
949
+ broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
950
+ ray_kwargs = {
951
+ 'lossmult': broadcast_scalar(self.lossmults[cam_idx]),
952
+ 'near': broadcast_scalar(self.near),
953
+ 'far': broadcast_scalar(self.far),
954
+ 'cam_idx': broadcast_scalar(cam_idx),
955
+ }
956
+
957
+ pixels = dict(pix_x_int=pix_x_int, pix_y_int=pix_y_int, **ray_kwargs)
958
+
959
+ batch = camera_utils.cast_ray_batch(self.cameras, pixels, self.camtype)
960
+ if not self.render_path:
961
+ batch['rgb'] = self.images[cam_idx]
962
+ if self._load_disps:
963
+ batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int]
964
+ if self._load_normals:
965
+ batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int]
966
+ batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int]
967
+ for k, v in batch.items():
968
+ self.batches[k].append(v)
969
+
970
+ def _next_train(self, item):
971
+ """Sample next training batch (random rays)."""
972
+ # We assume all images in the dataset are the same resolution, so we can use
973
+ # the same width/height for sampling all pixel coordinates in the batch.
974
+ # Batch/patch sampling parameters.
975
+ num_patches = self._batch_size // self._patch_size ** 2
976
+ # Random camera indices.
977
+ if self._batching == utils.BatchingMethod.ALL_IMAGES:
978
+ ray_indices = np.random.randint(0, self.batches['origins'].shape[0], (num_patches, 1, 1))
979
+ batch = {k: v[ray_indices] if v is not None else None for k, v in self.batches.items()}
980
+ else:
981
+ image_index = np.random.randint(0, self._n_examples, ())
982
+ ray_indices = np.random.randint(0, self.batches['origins'][image_index].shape[0], (num_patches,))
983
+ batch = {k: v[image_index][ray_indices] if v is not None else None for k, v in self.batches.items()}
984
+ batch['cam_dirs'] = -self.camtoworlds[batch['cam_idx'][..., 0]][..., 2]
985
+ return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}
986
+
987
+ def _next_test(self, item):
988
+ """Sample next test batch (one full image)."""
989
+ batch = {k: v[item] for k, v in self.batches.items()}
990
+ batch['cam_dirs'] = -self.camtoworlds[batch['cam_idx'][..., 0]][..., 2]
991
+ return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}
992
+
993
+ @staticmethod
994
+ def down2(img, sh):
+ """Resize `img` to spatial shape `sh` = (height, width) via bicubic resampling."""
+ return cv2.resize(img, sh[::-1], interpolation=cv2.INTER_CUBIC)
996
+
997
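+ # Behavior note: each source image yields `multiscale_levels` copies; level j
+ # halves the resolution j times and weights its rays with lossmult = 2**j,
+ # similar in spirit to mip-NeRF's multiscale training.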
+
998
+ class MultiLLFF(Multicam, LLFF):
999
+ pass
1000
+
1001
+
1002
+ if __name__ == '__main__':
1003
+ from internal import configs
1004
+ import accelerate
1005
+
1006
+ config = configs.Config()
1007
+ accelerator = accelerate.Accelerator()
1008
+ config.world_size = accelerator.num_processes
1009
+ config.global_rank = accelerator.process_index
1010
+ config.factor = 8
1011
+ dataset = LLFF('test', '/SSD_DISK/datasets/360_v2/bicycle', config)
1012
+ print(len(dataset))
1013
+ for _ in tqdm(dataset):
1014
+ pass
1015
+ print('done')
1016
internal/geopoly.py ADDED
@@ -0,0 +1,108 @@
1
+ import itertools
2
+ import numpy as np
3
+
4
+
5
+ def compute_sq_dist(mat0, mat1=None):
6
+ """Compute the squared Euclidean distance between all pairs of columns."""
7
+ if mat1 is None:
8
+ mat1 = mat0
9
+ # Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y.
10
+ sq_norm0 = np.sum(mat0 ** 2, 0)
11
+ sq_norm1 = np.sum(mat1 ** 2, 0)
12
+ sq_dist = sq_norm0[:, None] + sq_norm1[None, :] - 2 * mat0.T @ mat1
13
+ sq_dist = np.maximum(0, sq_dist) # Negative values must be numerical errors.
14
+ return sq_dist
15
+
16
+
17
+ def compute_tesselation_weights(v):
18
+ """Tesselate the vertices of a triangle by a factor of `v`."""
19
+ if v < 1:
20
+ raise ValueError(f'v {v} must be >= 1')
21
+ int_weights = []
22
+ for i in range(v + 1):
23
+ for j in range(v + 1 - i):
24
+ int_weights.append((i, j, v - (i + j)))
25
+ int_weights = np.array(int_weights)
26
+ weights = int_weights / v # Barycentric weights.
27
+ return weights
28
+
29
+
30
+ def tesselate_geodesic(base_verts, base_faces, v, eps=1e-4):
31
+ """Tesselate the vertices of a geodesic polyhedron.
32
+
33
+ Args:
34
+ base_verts: tensor of floats, the vertex coordinates of the geodesic.
35
+ base_faces: tensor of ints, the indices of the vertices of base_verts that
36
+ constitute eachface of the polyhedra.
37
+ v: int, the factor of the tesselation (v==1 is a no-op).
38
+ eps: float, a small value used to determine if two vertices are the same.
39
+
40
+ Returns:
41
+ verts: a tensor of floats, the coordinates of the tesselated vertices.
42
+ """
43
+ if not isinstance(v, int):
44
+ raise ValueError(f'v {v} must an integer')
45
+ tri_weights = compute_tesselation_weights(v)
46
+
47
+ verts = []
48
+ for base_face in base_faces:
49
+ new_verts = np.matmul(tri_weights, base_verts[base_face, :])
50
+ new_verts /= np.sqrt(np.sum(new_verts ** 2, 1, keepdims=True))
51
+ verts.append(new_verts)
52
+ verts = np.concatenate(verts, 0)
53
+
54
+ sq_dist = compute_sq_dist(verts.T)
55
+ assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist])
56
+ unique = np.unique(assignment)
57
+ verts = verts[unique, :]
58
+
59
+ return verts
60
+
61
+
62
+ def generate_basis(base_shape,
63
+ angular_tesselation,
64
+ remove_symmetries=True,
65
+ eps=1e-4):
66
+ """Generates a 3D basis by tesselating a geometric polyhedron.
67
+
68
+ Args:
69
+ base_shape: string, the name of the starting polyhedron, must be either
70
+ 'icosahedron' or 'octahedron'.
71
+ angular_tesselation: int, the number of times to tesselate the polyhedron,
72
+ must be >= 1 (a value of 1 is a no-op to the polyhedron).
73
+ remove_symmetries: bool, if True then remove the symmetric basis columns,
74
+ which is usually a good idea because otherwise projections onto the basis
75
+ will have redundant negative copies of each other.
76
+ eps: float, a small number used to determine symmetries.
77
+
78
+ Returns:
79
+ basis: a matrix with shape [3, n].
80
+ """
81
+ if base_shape == 'icosahedron':
82
+ a = (np.sqrt(5) + 1) / 2
83
+ verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1),
84
+ (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0),
85
+ (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2)
86
+ faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1),
87
+ (8, 10, 1), (8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3),
88
+ (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6),
89
+ (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5),
90
+ (7, 2, 11)])
91
+ verts = tesselate_geodesic(verts, faces, angular_tesselation)
92
+ elif base_shape == 'octahedron':
93
+ verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0),
94
+ (1, 0, 0)])
95
+ corners = np.array(list(itertools.product([-1, 1], repeat=3)))
96
+ pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)
97
+ faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)
98
+ verts = tesselate_geodesic(verts, faces, angular_tesselation)
99
+ else:
100
+ raise ValueError(f'base_shape {base_shape} not supported')
101
+
102
+ if remove_symmetries:
103
+ # Remove elements of `verts` that are reflections of each other.
104
+ match = compute_sq_dist(verts.T, -verts.T) < eps
105
+ verts = verts[np.any(np.triu(match), 1), :]
106
+
107
+ basis = verts[:, ::-1]
108
+ return basis
internal/image.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import numpy as np
3
+ from internal import math
4
+ from skimage.metrics import structural_similarity, peak_signal_noise_ratio
5
+ import cv2
6
+
7
+
8
+ def mse_to_psnr(mse):
9
+ """Compute PSNR given an MSE (we assume the maximum pixel value is 1)."""
10
+ return -10. / np.log(10.) * np.log(mse)
11
+
12
+
13
+ def psnr_to_mse(psnr):
14
+ """Compute MSE given a PSNR (we assume the maximum pixel value is 1)."""
15
+ return np.exp(-0.1 * np.log(10.) * psnr)
16
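+ # Round-trip example: mse_to_psnr(1e-4) == 40.0 (i.e. -10 * log10(1e-4)),
+ # and psnr_to_mse(40.0) recovers 1e-4.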
+
17
+
18
+ def ssim_to_dssim(ssim):
19
+ """Compute DSSIM given an SSIM."""
20
+ return (1 - ssim) / 2
21
+
22
+
23
+ def dssim_to_ssim(dssim):
24
+ """Compute DSSIM given an SSIM."""
25
+ return 1 - 2 * dssim
26
+
27
+
28
+ def linear_to_srgb(linear, eps=None):
29
+ """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
30
+ if eps is None:
31
+ eps = torch.finfo(linear.dtype).eps
32
+ # eps = 1e-3
33
+
34
+ srgb0 = 323 / 25 * linear
35
+ srgb1 = (211 * linear.clamp_min(eps) ** (5 / 12) - 11) / 200
36
+ return torch.where(linear <= 0.0031308, srgb0, srgb1)
37
+
38
+
39
+ def linear_to_srgb_np(linear, eps=None):
40
+ """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
41
+ if eps is None:
42
+ eps = np.finfo(linear.dtype).eps
43
+ srgb0 = 323 / 25 * linear
44
+ srgb1 = (211 * np.maximum(eps, linear) ** (5 / 12) - 11) / 200
45
+ return np.where(linear <= 0.0031308, srgb0, srgb1)
46
+
47
+
48
+ def srgb_to_linear(srgb, eps=None):
49
+ """Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
50
+ if eps is None:
51
+ eps = np.finfo(srgb.dtype).eps
52
+ linear0 = 25 / 323 * srgb
53
+ linear1 = np.maximum(eps, ((200 * srgb + 11) / (211))) ** (12 / 5)
54
+ return np.where(srgb <= 0.04045, linear0, linear1)
55
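+ # Example values (a sketch): linear_to_srgb_np(0.5) ~= 0.7354 and
+ # srgb_to_linear(0.7354) ~= 0.5; inputs below the 0.0031308 / 0.04045
+ # thresholds use the linear ramp instead of the power curve.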
+
56
+
57
+ def downsample(img, factor):
58
+ """Area downsample img (factor must evenly divide img height and width)."""
59
+ sh = img.shape
60
+ if not (sh[0] % factor == 0 and sh[1] % factor == 0):
61
+ raise ValueError(f'Downsampling factor {factor} does not '
62
+ f'evenly divide image shape {sh[:2]}')
63
+ img = img.reshape((sh[0] // factor, factor, sh[1] // factor, factor) + sh[2:])
64
+ img = img.mean((1, 3))
65
+ return img
66
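+ # E.g. an [8, 8, 3] image with factor=2 becomes [4, 4, 3], each output pixel
+ # the mean of a 2x2 input block; a factor that does not evenly divide the
+ # image shape raises the ValueError above.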
+
67
+
68
+ def color_correct(img, ref, num_iters=5, eps=0.5 / 255):
69
+ """Warp `img` to match the colors in `ref_img`."""
70
+ if img.shape[-1] != ref.shape[-1]:
71
+ raise ValueError(
72
+ f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match'
73
+ )
74
+ num_channels = img.shape[-1]
75
+ img_mat = img.reshape([-1, num_channels])
76
+ ref_mat = ref.reshape([-1, num_channels])
77
+ is_unclipped = lambda z: (z >= eps) & (z <= (1 - eps)) # z \in [eps, 1-eps].
78
+ mask0 = is_unclipped(img_mat)
79
+ # Because the set of saturated pixels may change after solving for a
80
+ # transformation, we repeatedly solve a system `num_iters` times and update
81
+ # our estimate of which pixels are saturated.
82
+ for _ in range(num_iters):
83
+ # Construct the left hand side of a linear system that contains a quadratic
84
+ # expansion of each pixel of `img`.
85
+ a_mat = []
86
+ for c in range(num_channels):
87
+ a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:]) # Quadratic term.
88
+ a_mat.append(img_mat) # Linear term.
89
+ a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term.
90
+ a_mat = torch.cat(a_mat, dim=-1)
91
+ warp = []
92
+ for c in range(num_channels):
93
+ # Construct the right hand side of a linear system containing each color
94
+ # of `ref`.
95
+ b = ref_mat[:, c]
96
+ # Ignore rows of the linear system that were saturated in the input or are
97
+ # saturated in the current corrected color estimate.
98
+ mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
99
+ ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
100
+ mb = torch.where(mask, b, torch.zeros_like(b))
101
+ w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
102
+ assert torch.all(torch.isfinite(w))
103
+ warp.append(w)
104
+ warp = torch.stack(warp, dim=-1)
105
+ # Apply the warp to update img_mat.
106
+ img_mat = torch.clip(math.matmul(a_mat, warp), 0, 1)
107
+ corrected_img = torch.reshape(img_mat, img.shape)
108
+ return corrected_img
109
+
110
+
111
+ class MetricHarness:
112
+ """A helper class for evaluating several error metrics."""
113
+
114
+ def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
115
+ """Evaluate the error between a predicted rgb image and the true image."""
116
+ rgb_pred = (rgb_pred * 255).astype(np.uint8)
117
+ rgb_gt = (rgb_gt * 255).astype(np.uint8)
118
+ rgb_pred_gray = cv2.cvtColor(rgb_pred, cv2.COLOR_RGB2GRAY)
119
+ rgb_gt_gray = cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2GRAY)
120
+ psnr = float(peak_signal_noise_ratio(rgb_pred, rgb_gt, data_range=255))
121
+ ssim = float(structural_similarity(rgb_pred_gray, rgb_gt_gray, data_range=255))
122
+
123
+ return {
124
+ name_fn('psnr'): psnr,
125
+ name_fn('ssim'): ssim,
126
+ }
internal/math.py ADDED
@@ -0,0 +1,133 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ @torch.jit.script
6
+ def erf(x):
7
+ return torch.sign(x) * torch.sqrt(1 - torch.exp(-4 / torch.pi * x ** 2))
8
+
9
+
10
+ def matmul(a, b):
+ """Batched matmul via a broadcasted multiply-and-sum.
+
+ torch.matmul can produce NaNs under fp16 here, so instead broadcast
+ (e.g. [B, 3, 4, 1] * [B, 1, 4, 3]) and sum over the shared dimension.
+ """
+ return (a[..., None] * b[..., None, :, :]).sum(dim=-2)
16
+
17
+
18
+ def safe_trig_helper(x, fn, t=100 * torch.pi):
19
+ """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
20
+ return fn(torch.where(torch.abs(x) < t, x, x % t))
21
+
22
+
23
+ def safe_cos(x):
24
+ return safe_trig_helper(x, torch.cos)
25
+
26
+
27
+ def safe_sin(x):
28
+ return safe_trig_helper(x, torch.sin)
29
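+ # E.g. safe_sin(torch.tensor(1e12)) first reduces 1e12 modulo 100*pi, which
+ # avoids the precision blow-up of calling sin() directly on huge arguments.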
+
30
+
31
+ def safe_exp(x):
32
+ return torch.exp(x.clamp_max(88.))
33
+
34
+
35
+ def safe_exp_jvp(primals, tangents):
36
+ """Override safe_exp()'s gradient so that it's large when inputs are large."""
37
+ x, = primals
38
+ x_dot, = tangents
39
+ exp_x = safe_exp(x)
40
+ exp_x_dot = exp_x * x_dot
41
+ return exp_x, exp_x_dot
42
+
43
+
44
+ def log_lerp(t, v0, v1):
45
+ """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
46
+ if v0 <= 0 or v1 <= 0:
47
+ raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
48
+ lv0 = np.log(v0)
49
+ lv1 = np.log(v1)
50
+ return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)
51
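+ # E.g. log_lerp(0.5, 1e-2, 1e-4) == 1e-3, the geometric mean of the two
+ # endpoints, since interpolation is linear in log space.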
+
52
+
53
+ def learning_rate_decay(step,
54
+ lr_init,
55
+ lr_final,
56
+ max_steps,
57
+ lr_delay_steps=0,
58
+ lr_delay_mult=1):
59
+ """Continuous learning rate decay function.
60
+
61
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
62
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
63
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
64
+ function of lr_delay_mult, such that the initial learning rate is
65
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
66
+ to the normal learning rate when steps>lr_delay_steps.
67
+
68
+ Args:
69
+ step: int, the current optimization step.
70
+ lr_init: float, the initial learning rate.
71
+ lr_final: float, the final learning rate.
72
+ max_steps: int, the number of steps during optimization.
73
+ lr_delay_steps: int, the number of steps to delay the full learning rate.
74
+ lr_delay_mult: float, the multiplier on the rate when delaying it.
75
+
76
+ Returns:
77
+ lr: the learning rate for the current step 'step'.
78
+ """
79
+ if lr_delay_steps > 0:
80
+ # A kind of reverse cosine decay.
81
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
82
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
83
+ else:
84
+ delay_rate = 1.
85
+ return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
86
+
87
+
88
+ def sorted_interp(x, xp, fp):
89
+ """A TPU-friendly version of interp(), where xp and fp must be sorted."""
90
+
91
+ # Identify the location in `xp` that corresponds to each `x`.
92
+ # The final `True` index in `mask` is the start of the matching interval.
93
+ mask = x[..., None, :] >= xp[..., :, None]
94
+
95
+ def find_interval(x):
96
+ # Grab the value where `mask` switches from True to False, and vice versa.
97
+ # This approach takes advantage of the fact that `x` is sorted.
98
+ x0 = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2).values
99
+ x1 = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2).values
100
+ return x0, x1
101
+
102
+ fp0, fp1 = find_interval(fp)
103
+ xp0, xp1 = find_interval(xp)
104
+
105
+ offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
106
+ ret = fp0 + offset * (fp1 - fp0)
107
+ return ret
108
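+ # E.g. with x = [0.5], xp = [0., 1.], fp = [0., 10.] (broadcast over any
+ # leading batch dims) this returns [5.0], matching np.interp on sorted input.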
+
109
+
110
+ def sorted_interp_quad(x, xp, fpdf, fcdf):
111
+ """interp in quadratic"""
112
+
113
+ # Identify the location in `xp` that corresponds to each `x`.
114
+ # The final `True` index in `mask` is the start of the matching interval.
115
+ mask = x[..., None, :] >= xp[..., :, None]
116
+
117
+ def find_interval(x, return_idx=False):
118
+ # Grab the value where `mask` switches from True to False, and vice versa.
119
+ # This approach takes advantage of the fact that `x` is sorted.
120
+ x0, x0_idx = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2)
121
+ x1, x1_idx = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2)
122
+ if return_idx:
123
+ return x0, x1, x0_idx, x1_idx
124
+ return x0, x1
125
+
126
+ fcdf0, fcdf1, fcdf0_idx, fcdf1_idx = find_interval(fcdf, return_idx=True)
127
+ fpdf0 = fpdf.take_along_dim(fcdf0_idx, dim=-1)
128
+ fpdf1 = fpdf.take_along_dim(fcdf1_idx, dim=-1)
129
+ xp0, xp1 = find_interval(xp)
130
+
131
+ offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
132
+ ret = fcdf0 + (x - xp0) * (fpdf0 + fpdf1 * offset + fpdf0 * (1 - offset)) / 2
133
+ return ret
internal/models.py ADDED
@@ -0,0 +1,740 @@
1
+ import accelerate
2
+ import gin
3
+ from internal import coord
4
+ from internal import geopoly
5
+ from internal import image
6
+ from internal import math
7
+ from internal import ref_utils
8
+ from internal import train_utils
9
+ from internal import render
10
+ from internal import stepfun
11
+ from internal import utils
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.utils._pytree import tree_map
17
+ from tqdm import tqdm
18
+ from gridencoder import GridEncoder
19
+ from torch_scatter import segment_coo
20
+
21
+ gin.config.external_configurable(math.safe_exp, module='math')
22
+
23
+
24
+ def set_kwargs(self, kwargs):
25
+ for k, v in kwargs.items():
26
+ setattr(self, k, v)
27
+
28
+
29
+ @gin.configurable
30
+ class Model(nn.Module):
31
+ """A mip-Nerf360 model containing all MLPs."""
32
+ num_prop_samples: int = 64 # The number of samples for each proposal level.
33
+ num_nerf_samples: int = 32  # The number of samples for the final nerf level.
34
+ num_levels: int = 3 # The number of sampling levels (3==2 proposals, 1 nerf).
35
+ bg_intensity_range = (1., 1.) # The range of background colors.
36
+ anneal_slope: float = 10 # Higher = more rapid annealing.
37
+ stop_level_grad: bool = True # If True, don't backprop across levels.
38
+ use_viewdirs: bool = True # If True, use view directions as input.
39
+ raydist_fn = None # The curve used for ray dists.
40
+ single_jitter: bool = True # If True, jitter whole rays instead of samples.
41
+ dilation_multiplier: float = 0.5 # How much to dilate intervals relatively.
42
+ dilation_bias: float = 0.0025 # How much to dilate intervals absolutely.
43
+ num_glo_features: int = 0 # GLO vector length, disabled if 0.
44
+ num_glo_embeddings: int = 1000 # Upper bound on max number of train images.
45
+ learned_exposure_scaling: bool = False # Learned exposure scaling (RawNeRF).
46
+ near_anneal_rate = None # How fast to anneal in near bound.
47
+ near_anneal_init: float = 0.95 # Where to initialize near bound (in [0, 1]).
48
+ single_mlp: bool = False # Use the NerfMLP for all rounds of sampling.
49
+ distinct_prop: bool = True  # Use a distinct PropMLP for each proposal level.
50
+ resample_padding: float = 0.0 # Dirichlet/alpha "padding" on the histogram.
51
+ opaque_background: bool = False # If true, make the background opaque.
52
+ power_lambda: float = -1.5
53
+ std_scale: float = 0.5
54
+ prop_desired_grid_size = [512, 2048]
55
+
56
+ def __init__(self, config=None, **kwargs):
57
+ super().__init__()
58
+ set_kwargs(self, kwargs)
59
+ self.config = config
60
+
61
+ # Construct MLPs. WARNING: Construction order may matter, if MLP weights are
62
+ # being regularized.
63
+ self.nerf_mlp = NerfMLP(num_glo_features=self.num_glo_features,
64
+ num_glo_embeddings=self.num_glo_embeddings)
65
+ if self.single_mlp:
66
+ self.prop_mlp = self.nerf_mlp
67
+ elif not self.distinct_prop:
68
+ self.prop_mlp = PropMLP()
69
+ else:
70
+ for i in range(self.num_levels - 1):
71
+ self.register_module(f'prop_mlp_{i}', PropMLP(grid_disired_resolution=self.prop_desired_grid_size[i]))
72
+ if self.num_glo_features > 0 and not config.zero_glo:
73
+ # Construct/grab GLO vectors for the cameras of each input ray.
74
+ self.glo_vecs = nn.Embedding(self.num_glo_embeddings, self.num_glo_features)
75
+
76
+ if self.learned_exposure_scaling:
77
+ # Setup learned scaling factors for output colors.
78
+ max_num_exposures = self.num_glo_embeddings
79
+ # Initialize the learned scaling offsets at 0.
80
+ self.exposure_scaling_offsets = nn.Embedding(max_num_exposures, 3)
81
+ torch.nn.init.zeros_(self.exposure_scaling_offsets.weight)
82
+
83
+ def forward(
84
+ self,
85
+ rand,
86
+ batch,
87
+ train_frac,
88
+ compute_extras,
89
+ zero_glo=True,
90
+ ):
91
+ """The mip-NeRF Model.
92
+
93
+ Args:
94
+ rand: random number generator (or None for deterministic output).
95
+ batch: util.Rays, a pytree of ray origins, directions, and viewdirs.
96
+ train_frac: float in [0, 1], what fraction of training is complete.
97
+ compute_extras: bool, if True, compute extra quantities besides color.
98
+ zero_glo: bool, if True, when using GLO pass in vector of zeros.
99
+
100
+ Returns:
101
+ ret: list, [*(rgb, distance, acc)]
102
+ """
103
+ device = batch['origins'].device
104
+ if self.num_glo_features > 0:
105
+ if not zero_glo:
106
+ # Construct/grab GLO vectors for the cameras of each input ray.
107
+ cam_idx = batch['cam_idx'][..., 0]
108
+ glo_vec = self.glo_vecs(cam_idx.long())
109
+ else:
110
+ glo_vec = torch.zeros(batch['origins'].shape[:-1] + (self.num_glo_features,), device=device)
111
+ else:
112
+ glo_vec = None
113
+
114
+ # Define the mapping from normalized to metric ray distance.
115
+ _, s_to_t = coord.construct_ray_warps(self.raydist_fn, batch['near'], batch['far'], self.power_lambda)
116
+
117
+ # Initialize the range of (normalized) distances for each ray to [0, 1],
118
+ # and assign that single interval a weight of 1. These distances and weights
119
+ # will be repeatedly updated as we proceed through sampling levels.
120
+ # `near_anneal_rate` can be used to anneal in the near bound at the start
121
+ # of training, eg. 0.1 anneals in the bound over the first 10% of training.
122
+ if self.near_anneal_rate is None:
123
+ init_s_near = 0.
124
+ else:
125
+ init_s_near = np.clip(1 - train_frac / self.near_anneal_rate, 0,
126
+ self.near_anneal_init)
127
+ init_s_far = 1.
128
+ sdist = torch.cat([
129
+ torch.full_like(batch['near'], init_s_near),
130
+ torch.full_like(batch['far'], init_s_far)
131
+ ], dim=-1)
132
+ weights = torch.ones_like(batch['near'])
133
+ prod_num_samples = 1
134
+
135
+ ray_history = []
136
+ renderings = []
137
+ for i_level in range(self.num_levels):
138
+ is_prop = i_level < (self.num_levels - 1)
139
+ num_samples = self.num_prop_samples if is_prop else self.num_nerf_samples
140
+
141
+ # Dilate by some multiple of the expected span of each current interval,
142
+ # with some bias added in.
143
+ dilation = self.dilation_bias + self.dilation_multiplier * (
144
+ init_s_far - init_s_near) / prod_num_samples
145
+
146
+ # Record the product of the number of samples seen so far.
147
+ prod_num_samples *= num_samples
148
+
149
+ # After the first level (where dilation would be a no-op) optionally
150
+ # dilate the interval weights along each ray slightly so that they're
151
+ # overestimates, which can reduce aliasing.
152
+ use_dilation = self.dilation_bias > 0 or self.dilation_multiplier > 0
153
+ if i_level > 0 and use_dilation:
154
+ sdist, weights = stepfun.max_dilate_weights(
155
+ sdist,
156
+ weights,
157
+ dilation,
158
+ domain=(init_s_near, init_s_far),
159
+ renormalize=True)
160
+ sdist = sdist[..., 1:-1]
161
+ weights = weights[..., 1:-1]
162
+
163
+ # Optionally anneal the weights as a function of training iteration.
164
+ if self.anneal_slope > 0:
165
+ # Schlick's bias function, see https://arxiv.org/abs/2010.09714
166
+ bias = lambda x, s: (s * x) / ((s - 1) * x + 1)
167
+ anneal = bias(train_frac, self.anneal_slope)
168
+ else:
169
+ anneal = 1.
170
+
171
+ # A slightly more stable way to compute weights**anneal. If the distance
172
+ # between adjacent intervals is zero then its weight is fixed to 0.
173
+ logits_resample = torch.where(
174
+ sdist[..., 1:] > sdist[..., :-1],
175
+ anneal * torch.log(weights + self.resample_padding),
176
+ torch.full_like(sdist[..., :-1], -torch.inf))
177
+
178
+ # Draw sampled intervals from each ray's current weights.
179
+ sdist = stepfun.sample_intervals(
180
+ rand,
181
+ sdist,
182
+ logits_resample,
183
+ num_samples,
184
+ single_jitter=self.single_jitter,
185
+ domain=(init_s_near, init_s_far))
186
+
187
+ # Optimization will usually go nonlinear if you propagate gradients
188
+ # through sampling.
189
+ if self.stop_level_grad:
190
+ sdist = sdist.detach()
191
+
192
+ # Convert normalized distances to metric distances.
193
+ tdist = s_to_t(sdist)
194
+
195
+ # Cast our rays, by turning our distance intervals into Gaussians.
196
+ means, stds, ts = render.cast_rays(
197
+ tdist,
198
+ batch['origins'],
199
+ batch['directions'],
200
+ batch['cam_dirs'],
201
+ batch['radii'],
202
+ rand,
203
+ std_scale=self.std_scale)
204
+
205
+ # Push our Gaussians through one of our two MLPs.
206
+ mlp = (self.get_submodule(
207
+ f'prop_mlp_{i_level}') if self.distinct_prop else self.prop_mlp) if is_prop else self.nerf_mlp
208
+ ray_results = mlp(
209
+ rand,
210
+ means, stds,
211
+ viewdirs=batch['viewdirs'] if self.use_viewdirs else None,
212
+ imageplane=batch.get('imageplane'),
213
+ glo_vec=None if is_prop else glo_vec,
214
+ exposure=batch.get('exposure_values'),
215
+ )
216
+ if self.config.gradient_scaling:
217
+ ray_results['rgb'], ray_results['density'] = train_utils.GradientScaler.apply(
218
+ ray_results['rgb'], ray_results['density'], ts.mean(dim=-1))
219
+
220
+ # Get the weights used by volumetric rendering (and our other losses).
221
+ weights = render.compute_alpha_weights(
222
+ ray_results['density'],
223
+ tdist,
224
+ batch['directions'],
225
+ opaque_background=self.opaque_background,
226
+ )[0]
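
The returned weights follow the standard alpha-compositing recurrence w_i = alpha_i * prod_{j<i}(1 - alpha_j). A toy sketch starting directly from alphas (the repo's compute_alpha_weights first derives alphas from density and interval length):

import torch

alphas = torch.tensor([0.2, 0.5, 0.9])
trans = torch.cumprod(torch.cat([torch.ones(1), 1 - alphas[:-1]]), dim=0)
weights = alphas * trans
print(weights, weights.sum())  # tensor([0.2000, 0.4000, 0.3600]) tensor(0.9600)
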
227
+
228
+ # Define or sample the background color for each ray.
229
+ if self.bg_intensity_range[0] == self.bg_intensity_range[1]:
230
+ # If the min and max of the range are equal, just take it.
231
+ bg_rgbs = self.bg_intensity_range[0]
232
+ elif rand is None:
233
+ # If rendering is deterministic, use the midpoint of the range.
234
+ bg_rgbs = (self.bg_intensity_range[0] + self.bg_intensity_range[1]) / 2
235
+ else:
236
+ # Sample RGB values from the range for each ray.
237
+ minval = self.bg_intensity_range[0]
238
+ maxval = self.bg_intensity_range[1]
239
+ bg_rgbs = torch.rand(weights.shape[:-1] + (3,), device=device) * (maxval - minval) + minval
240
+
241
+ # RawNeRF exposure logic.
242
+ if batch.get('exposure_idx') is not None:
243
+ # Scale output colors by the exposure.
244
+ ray_results['rgb'] *= batch['exposure_values'][..., None, :]
245
+ if self.learned_exposure_scaling:
246
+ exposure_idx = batch['exposure_idx'][..., 0]
247
+ # Force scaling offset to always be zero when exposure_idx is 0.
248
+ # This constraint fixes a reference point for the scene's brightness.
249
+ mask = exposure_idx > 0
250
+ # Scaling is parameterized as an offset from 1.
251
+ scaling = 1 + mask[..., None] * self.exposure_scaling_offsets(exposure_idx.long())
252
+ ray_results['rgb'] *= scaling[..., None, :]
253
+
254
+ # Render each ray.
255
+ rendering = render.volumetric_rendering(
256
+ ray_results['rgb'],
257
+ weights,
258
+ tdist,
259
+ bg_rgbs,
260
+ batch['far'],
261
+ compute_extras,
262
+ extras={
263
+ k: v
264
+ for k, v in ray_results.items()
265
+ if k.startswith('normals') or k in ['roughness']
266
+ })
267
+
268
+ if compute_extras:
269
+ # Collect some rays to visualize directly. By naming these quantities
270
+ # with `ray_` they get treated differently downstream --- they're
271
+ # treated as bags of rays, rather than image chunks.
272
+ n = self.config.vis_num_rays
273
+ rendering['ray_sdist'] = sdist.reshape([-1, sdist.shape[-1]])[:n, :]
274
+ rendering['ray_weights'] = (
275
+ weights.reshape([-1, weights.shape[-1]])[:n, :])
276
+ rgb = ray_results['rgb']
277
+ rendering['ray_rgbs'] = (rgb.reshape((-1,) + rgb.shape[-2:]))[:n, :, :]
278
+
279
+ if self.training:
280
+ # Compute the hash decay loss for this level.
281
+ idx = mlp.encoder.idx
282
+ param = mlp.encoder.embeddings
283
+ loss_hash_decay = segment_coo(param ** 2,
284
+ idx,
285
+ torch.zeros(idx.max() + 1, param.shape[-1], device=param.device),
286
+ reduce='mean'
287
+ ).mean()
288
+ ray_results['loss_hash_decay'] = loss_hash_decay
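
A toy illustration of the segment_coo reduction above (assumes torch_scatter is installed and, as here, a sorted index): embedding rows sharing a level index are mean-reduced into one squared-magnitude value per hash level.

import torch
from torch_scatter import segment_coo

param = torch.tensor([[1.], [2.], [3.], [4.]])  # embedding rows
idx = torch.tensor([0, 0, 1, 1])                # sorted level id per row
out = torch.zeros(2, 1)
print(segment_coo(param ** 2, idx, out, reduce='mean'))
# tensor([[ 2.5000],
#         [12.5000]])
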
289
+
290
+ renderings.append(rendering)
291
+ ray_results['sdist'] = sdist.clone()
292
+ ray_results['weights'] = weights.clone()
293
+ ray_history.append(ray_results)
294
+
295
+ if compute_extras:
296
+ # Because the proposal network doesn't produce meaningful colors, for
297
+ # easier visualization we replace their colors with the final average
298
+ # color.
299
+ weights = [r['ray_weights'] for r in renderings]
300
+ rgbs = [r['ray_rgbs'] for r in renderings]
301
+ final_rgb = torch.sum(rgbs[-1] * weights[-1][..., None], dim=-2)
302
+ avg_rgbs = [
303
+ torch.broadcast_to(final_rgb[:, None, :], r.shape) for r in rgbs[:-1]
304
+ ]
305
+ for i in range(len(avg_rgbs)):
306
+ renderings[i]['ray_rgbs'] = avg_rgbs[i]
307
+
308
+ return renderings, ray_history
309
+
310
+
311
+ class MLP(nn.Module):
312
+ """A PosEnc MLP."""
313
+ bottleneck_width: int = 256 # The width of the bottleneck vector.
314
+ net_depth_viewdirs: int = 2 # The depth of the second part of MLP.
315
+ net_width_viewdirs: int = 256 # The width of the second part of MLP.
316
+ skip_layer_dir: int = 0 # Add a skip connection to 2nd MLP after Nth layers.
317
+ num_rgb_channels: int = 3 # The number of RGB channels.
318
+ deg_view: int = 4 # Degree of encoding for viewdirs or refdirs.
319
+ use_reflections: bool = False # If True, use refdirs instead of viewdirs.
320
+ use_directional_enc: bool = False # If True, use IDE to encode directions.
321
+ # If False and if use_directional_enc is True, use zero roughness in IDE.
322
+ enable_pred_roughness: bool = False
323
+ roughness_bias: float = -1. # Shift added to raw roughness pre-activation.
324
+ use_diffuse_color: bool = False # If True, predict diffuse & specular colors.
325
+ use_specular_tint: bool = False # If True, predict tint.
326
+ use_n_dot_v: bool = False # If True, feed dot(n, viewdir) to 2nd MLP.
327
+ bottleneck_noise: float = 0.0 # Std. deviation of noise added to bottleneck.
328
+ density_bias: float = -1. # Shift added to raw densities pre-activation.
329
+ density_noise: float = 0. # Standard deviation of noise added to raw density.
330
+ rgb_premultiplier: float = 1. # Premultiplier on RGB before activation.
331
+ rgb_bias: float = 0. # The shift added to raw colors pre-activation.
332
+ rgb_padding: float = 0.001 # Padding added to the RGB outputs.
333
+ enable_pred_normals: bool = False # If True compute predicted normals.
334
+ disable_density_normals: bool = False # If True don't compute normals.
335
+ disable_rgb: bool = False # If True don't output RGB.
336
+ warp_fn = 'contract'
337
+ num_glo_features: int = 0 # GLO vector length, disabled if 0.
338
+ num_glo_embeddings: int = 1000 # Upper bound on max number of train images.
339
+ scale_featurization: bool = False
340
+ grid_num_levels: int = 10
341
+ grid_level_interval: int = 2
342
+ grid_level_dim: int = 4
343
+ grid_base_resolution: int = 16
344
+ grid_disired_resolution: int = 8192
345
+ grid_log2_hashmap_size: int = 21
346
+ net_width_glo: int = 128 # The width of the GLO MLP.
347
+ net_depth_glo: int = 2 # The depth of the GLO MLP.
348
+
349
+ def __init__(self, **kwargs):
350
+ super().__init__()
351
+ set_kwargs(self, kwargs)
352
+ # Make sure that normals are computed if reflection direction is used.
353
+ if self.use_reflections and not (self.enable_pred_normals or
354
+ not self.disable_density_normals):
355
+ raise ValueError('Normals must be computed for reflection directions.')
356
+
357
+ # Precompute and define viewdir or refdir encoding function.
358
+ if self.use_directional_enc:
359
+ self.dir_enc_fn = ref_utils.generate_ide_fn(self.deg_view)
360
+ dim_dir_enc = self.dir_enc_fn(torch.zeros(1, 3), torch.zeros(1, 1)).shape[-1]
361
+ else:
362
+
363
+ def dir_enc_fn(direction, _):
364
+ return coord.pos_enc(
365
+ direction, min_deg=0, max_deg=self.deg_view, append_identity=True)
366
+
367
+ self.dir_enc_fn = dir_enc_fn
368
+ dim_dir_enc = self.dir_enc_fn(torch.zeros(1, 3), None).shape[-1]
369
+ self.grid_num_levels = int(
370
+ np.log(self.grid_disired_resolution / self.grid_base_resolution) / np.log(self.grid_level_interval)) + 1
371
+ self.encoder = GridEncoder(input_dim=3,
372
+ num_levels=self.grid_num_levels,
373
+ level_dim=self.grid_level_dim,
374
+ base_resolution=self.grid_base_resolution,
375
+ desired_resolution=self.grid_disired_resolution,
376
+ log2_hashmap_size=self.grid_log2_hashmap_size,
377
+ gridtype='hash',
378
+ align_corners=False)
379
+ last_dim = self.encoder.output_dim
380
+ if self.scale_featurization:
381
+ last_dim += self.encoder.num_levels
382
+ self.density_layer = nn.Sequential(nn.Linear(last_dim, 64),
383
+ nn.ReLU(),
384
+ nn.Linear(64,
385
+ 1 if self.disable_rgb else self.bottleneck_width)) # 1 channel if RGB is disabled, else the bottleneck width.
386
+ last_dim = 1 if self.disable_rgb and not self.enable_pred_normals else self.bottleneck_width
387
+ if self.enable_pred_normals:
388
+ self.normal_layer = nn.Linear(last_dim, 3)
389
+
390
+ if not self.disable_rgb:
391
+ if self.use_diffuse_color:
392
+ self.diffuse_layer = nn.Linear(last_dim, self.num_rgb_channels)
393
+
394
+ if self.use_specular_tint:
395
+ self.specular_layer = nn.Linear(last_dim, 3)
396
+
397
+ if self.enable_pred_roughness:
398
+ self.roughness_layer = nn.Linear(last_dim, 1)
399
+
400
+ # Output of the first part of MLP.
401
+ if self.bottleneck_width > 0:
402
+ last_dim_rgb = self.bottleneck_width
403
+ else:
404
+ last_dim_rgb = 0
405
+
406
+ last_dim_rgb += dim_dir_enc
407
+
408
+ if self.use_n_dot_v:
409
+ last_dim_rgb += 1
410
+
411
+ if self.num_glo_features > 0:
412
+ last_dim_glo = self.num_glo_features
413
+ for i in range(self.net_depth_glo - 1):
414
+ self.register_module(f"lin_glo_{i}", nn.Linear(last_dim_glo, self.net_width_glo))
415
+ last_dim_glo = self.net_width_glo
416
+ self.register_module(f"lin_glo_{self.net_depth_glo - 1}",
417
+ nn.Linear(last_dim_glo, self.bottleneck_width * 2))
418
+
419
+ input_dim_rgb = last_dim_rgb
420
+ for i in range(self.net_depth_viewdirs):
421
+ lin = nn.Linear(last_dim_rgb, self.net_width_viewdirs)
422
+ torch.nn.init.kaiming_uniform_(lin.weight)
423
+ self.register_module(f"lin_second_stage_{i}", lin)
424
+ last_dim_rgb = self.net_width_viewdirs
425
+ if i == self.skip_layer_dir:
426
+ last_dim_rgb += input_dim_rgb
427
+ self.rgb_layer = nn.Linear(last_dim_rgb, self.num_rgb_channels)
428
+
429
+ def predict_density(self, means, stds, rand=False, no_warp=False):
430
+ """Helper function to output density."""
431
+ # Encode input positions
432
+ if self.warp_fn is not None and not no_warp:
433
+ means, stds = coord.track_linearize(self.warp_fn, means, stds)
434
+ # contract [-2, 2] to [-1, 1]
435
+ bound = 2
436
+ means = means / bound
437
+ stds = stds / bound
438
+ features = self.encoder(means, bound=1).unflatten(-1, (self.encoder.num_levels, -1))
439
+ weights = torch.erf(1 / torch.sqrt(8 * stds[..., None] ** 2 * self.encoder.grid_sizes ** 2))
440
+ features = (features * weights[..., None]).mean(dim=-3).flatten(-2, -1)
441
+ if self.scale_featurization:
442
+ with torch.no_grad():
443
+ vl2mean = segment_coo((self.encoder.embeddings ** 2).sum(-1),
444
+ self.encoder.idx,
445
+ torch.zeros(self.grid_num_levels, device=weights.device),
446
+ self.grid_num_levels,
447
+ reduce='mean'
448
+ )
449
+ featurized_w = (2 * weights.mean(dim=-2) - 1) * (self.encoder.init_std ** 2 + vl2mean).sqrt()
450
+ features = torch.cat([features, featurized_w], dim=-1)
451
+ x = self.density_layer(features)
452
+ raw_density = x[..., 0] # Hardcoded to a single channel.
453
+ # Add noise to regularize the density predictions if needed.
454
+ if rand and (self.density_noise > 0):
455
+ raw_density += self.density_noise * torch.randn_like(raw_density)
456
+ return raw_density, x, means.mean(dim=-2)
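
A standalone sketch of the erf featurization weight above: a level's features are suppressed once a sample's std is large relative to that level's cell size (stds * grid_size >> 1), which is how ZipNeRF anti-aliases the hash grid. Grid sizes below are illustrative.

import torch

stds = torch.tensor([0.001, 0.01, 0.1])
grid_sizes = torch.tensor([16., 256., 8192.])  # coarse -> fine levels
w = torch.erf(1 / torch.sqrt(8 * stds[:, None] ** 2 * grid_sizes ** 2))
print(w)  # rows: stds; cols: levels; fine levels fade first as std grows
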
457
+
458
+ def forward(self,
459
+ rand,
460
+ means, stds,
461
+ viewdirs=None,
462
+ imageplane=None,
463
+ glo_vec=None,
464
+ exposure=None,
465
+ no_warp=False):
466
+ """Evaluate the MLP.
467
+
468
+ Args:
469
+ rand: bool, if True add regularizing noise to the density (and bottleneck).
470
+ means: [..., n, 3], coordinate means.
471
+ stds: [..., n], coordinate stds.
472
+ viewdirs: [..., 3], if not None, this variable will
473
+ be part of the input to the second part of the MLP concatenated with the
474
+ output vector of the first part of the MLP. If None, only the first part
475
+ of the MLP will be used with input x. In the original paper, this
476
+ variable is the view direction.
477
+ imageplane: [batch, 2], xy image plane coordinates
478
+ for each ray in the batch. Useful for image plane operations such as a
479
+ learned vignette mapping.
480
+ glo_vec: [..., num_glo_features], The GLO vector for each ray.
481
+ exposure: [..., 1], exposure value (shutter_speed * ISO) for each ray.
482
+
483
+ Returns:
484
+ rgb: [..., num_rgb_channels].
485
+ density: [...].
486
+ normals: [..., 3], or None.
487
+ normals_pred: [..., 3], or None.
488
+ roughness: [..., 1], or None.
489
+ """
490
+ if self.disable_density_normals:
491
+ raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp)
492
+ raw_grad_density = None
493
+ normals = None
494
+ else:
495
+ with torch.enable_grad():
496
+ means.requires_grad_(True)
497
+ raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp)
498
+ d_output = torch.ones_like(raw_density, requires_grad=False, device=raw_density.device)
499
+ raw_grad_density = torch.autograd.grad(
500
+ outputs=raw_density,
501
+ inputs=means,
502
+ grad_outputs=d_output,
503
+ create_graph=True,
504
+ retain_graph=True,
505
+ only_inputs=True)[0]
506
+ raw_grad_density = raw_grad_density.mean(-2)
507
+ # Compute normal vectors as negative normalized density gradient.
508
+ # We normalize the gradient of raw (pre-activation) density because
509
+ # it's the same as post-activation density, but is more numerically stable
510
+ # when the activation function has a steep or flat gradient.
511
+ normals = -ref_utils.l2_normalize(raw_grad_density)
512
+
513
+ if self.enable_pred_normals:
514
+ grad_pred = self.normal_layer(x)
515
+
516
+ # Normalize negative predicted gradients to get predicted normal vectors.
517
+ normals_pred = -ref_utils.l2_normalize(grad_pred)
518
+ normals_to_use = normals_pred
519
+ else:
520
+ grad_pred = None
521
+ normals_pred = None
522
+ normals_to_use = normals
523
+
524
+ # Apply bias and activation to raw density
525
+ density = F.softplus(raw_density + self.density_bias)
526
+
527
+ roughness = None
528
+ if self.disable_rgb:
529
+ rgb = torch.zeros(density.shape + (3,), device=density.device)
530
+ else:
531
+ if viewdirs is not None:
532
+ # Predict diffuse color.
533
+ if self.use_diffuse_color:
534
+ raw_rgb_diffuse = self.diffuse_layer(x)
535
+
536
+ if self.use_specular_tint:
537
+ tint = torch.sigmoid(self.specular_layer(x))
538
+
539
+ if self.enable_pred_roughness:
540
+ raw_roughness = self.roughness_layer(x)
541
+ roughness = (F.softplus(raw_roughness + self.roughness_bias))
542
+
543
+ # Output of the first part of MLP.
544
+ if self.bottleneck_width > 0:
545
+ bottleneck = x
546
+ # Add bottleneck noise.
547
+ if rand and (self.bottleneck_noise > 0):
548
+ bottleneck += self.bottleneck_noise * torch.randn_like(bottleneck)
549
+
550
+ # Append GLO vector if used.
551
+ if glo_vec is not None:
552
+ for i in range(self.net_depth_glo):
553
+ glo_vec = self.get_submodule(f"lin_glo_{i}")(glo_vec)
554
+ if i != self.net_depth_glo - 1:
555
+ glo_vec = F.relu(glo_vec)
556
+ glo_vec = torch.broadcast_to(glo_vec[..., None, :],
557
+ bottleneck.shape[:-1] + glo_vec.shape[-1:])
558
+ scale, shift = glo_vec.chunk(2, dim=-1)
559
+ bottleneck = bottleneck * torch.exp(scale) + shift
560
+
561
+ x = [bottleneck]
562
+ else:
563
+ x = []
564
+
565
+ # Encode view (or reflection) directions.
566
+ if self.use_reflections:
567
+ # Compute reflection directions. Note that we flip viewdirs before
568
+ # reflecting, because they point from the camera to the point,
569
+ # whereas ref_utils.reflect() assumes they point toward the camera.
570
+ # Returned refdirs then point from the point to the environment.
571
+ refdirs = ref_utils.reflect(-viewdirs[..., None, :], normals_to_use)
572
+ # Encode reflection directions.
573
+ dir_enc = self.dir_enc_fn(refdirs, roughness)
574
+ else:
575
+ # Encode view directions.
576
+ dir_enc = self.dir_enc_fn(viewdirs, roughness)
577
+ dir_enc = torch.broadcast_to(
578
+ dir_enc[..., None, :],
579
+ bottleneck.shape[:-1] + (dir_enc.shape[-1],))
580
+
581
+ # Append view (or reflection) direction encoding to bottleneck vector.
582
+ x.append(dir_enc)
583
+
584
+ # Append dot product between normal vectors and view directions.
585
+ if self.use_n_dot_v:
586
+ dotprod = torch.sum(
587
+ normals_to_use * viewdirs[..., None, :], dim=-1, keepdim=True)
588
+ x.append(dotprod)
589
+
590
+ # Concatenate bottleneck, directional encoding, and GLO.
591
+ x = torch.cat(x, dim=-1)
592
+ # Output of the second part of MLP.
593
+ inputs = x
594
+ for i in range(self.net_depth_viewdirs):
595
+ x = self.get_submodule(f"lin_second_stage_{i}")(x)
596
+ x = F.relu(x)
597
+ if i == self.skip_layer_dir:
598
+ x = torch.cat([x, inputs], dim=-1)
599
+ # If using diffuse/specular colors, then `rgb` is treated as linear
600
+ # specular color. Otherwise it's treated as the color itself.
601
+ rgb = torch.sigmoid(self.rgb_premultiplier *
602
+ self.rgb_layer(x) +
603
+ self.rgb_bias)
604
+
605
+ if self.use_diffuse_color:
606
+ # Initialize linear diffuse color around 0.25, so that the combined
607
+ # linear color is initialized around 0.5.
608
+ diffuse_linear = torch.sigmoid(raw_rgb_diffuse - np.log(3.0))
609
+ if self.use_specular_tint:
610
+ specular_linear = tint * rgb
611
+ else:
612
+ specular_linear = 0.5 * rgb
613
+
614
+ # Combine specular and diffuse components and tone map to sRGB.
615
+ rgb = torch.clip(image.linear_to_srgb(specular_linear + diffuse_linear), 0.0, 1.0)
616
+
617
+ # Apply padding, mapping color to [-rgb_padding, 1+rgb_padding].
618
+ rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding
619
+
620
+ return dict(
621
+ coord=means_contract,
622
+ density=density,
623
+ rgb=rgb,
624
+ raw_grad_density=raw_grad_density,
625
+ grad_pred=grad_pred,
626
+ normals=normals,
627
+ normals_pred=normals_pred,
628
+ roughness=roughness,
629
+ )
630
+
631
+
632
+ @gin.configurable
633
+ class NerfMLP(MLP):
634
+ pass
635
+
636
+
637
+ @gin.configurable
638
+ class PropMLP(MLP):
639
+ pass
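
These empty subclasses exist so gin can bind separate hyperparameters to the final NeRF MLP and the proposal MLPs. A hypothetical .gin snippet (values illustrative, not copied from the shipped configs):

NerfMLP.net_width_viewdirs = 256
PropMLP.disable_rgb = True  # proposal levels only need density
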
640
+
641
+
642
+ @torch.no_grad()
643
+ def render_image(model,
644
+ accelerator: accelerate.Accelerator,
645
+ batch,
646
+ rand,
647
+ train_frac,
648
+ config,
649
+ verbose=True,
650
+ return_weights=False):
651
+ """Render all the pixels of an image (in test mode).
652
+
653
+ Args:
654
+ render_fn: function, jit-ed render function mapping (rand, batch) -> pytree.
655
+ accelerator: used for DDP.
656
+ batch: a `Rays` pytree, the rays to be rendered.
657
+ rand: if random
658
+ config: A Config class.
659
+
660
+ Returns:
661
+ rendering: dict of rendered buffers (rgb, distance, acc, ...), each
662
+ reshaped to the image resolution; 'ray_*' entries hold per-level
663
+ ray bundles kept for visualization.
664
+ """
665
+ model.eval()
666
+
667
+ height, width = batch['origins'].shape[:2]
668
+ num_rays = height * width
669
+ batch = {k: v.reshape((num_rays, -1)) for k, v in batch.items() if v is not None}
670
+
671
+ global_rank = accelerator.process_index
672
+ chunks = []
673
+ idx0s = tqdm(range(0, num_rays, config.render_chunk_size),
674
+ desc="Rendering chunk", leave=False,
675
+ disable=not (accelerator.is_main_process and verbose))
676
+
677
+ for i_chunk, idx0 in enumerate(idx0s):
678
+ chunk_batch = tree_map(lambda r: r[idx0:idx0 + config.render_chunk_size], batch)
679
+ actual_chunk_size = chunk_batch['origins'].shape[0]
680
+ rays_remaining = actual_chunk_size % accelerator.num_processes
681
+ if rays_remaining != 0:
682
+ padding = accelerator.num_processes - rays_remaining
683
+ chunk_batch = tree_map(lambda v: torch.cat([v, torch.zeros_like(v[-padding:])], dim=0), chunk_batch)
684
+ else:
685
+ padding = 0
686
+ # After padding the number of chunk_rays is always divisible by host_count.
687
+ rays_per_host = chunk_batch['origins'].shape[0] // accelerator.num_processes
688
+ start, stop = global_rank * rays_per_host, (global_rank + 1) * rays_per_host
689
+ chunk_batch = tree_map(lambda r: r[start:stop], chunk_batch)
690
+
691
+ with accelerator.autocast():
692
+ chunk_renderings, ray_history = model(rand,
693
+ chunk_batch,
694
+ train_frac=train_frac,
695
+ compute_extras=True,
696
+ zero_glo=True)
697
+
698
+ gather = lambda v: accelerator.gather(v.contiguous())[:-padding] \
699
+ if padding > 0 else accelerator.gather(v.contiguous())
700
+ # Unshard the renderings.
701
+ chunk_renderings = tree_map(gather, chunk_renderings)
702
+
703
+ # Gather the final pass for 2D buffers and all passes for ray bundles.
704
+ chunk_rendering = chunk_renderings[-1]
705
+ for k in chunk_renderings[0]:
706
+ if k.startswith('ray_'):
707
+ chunk_rendering[k] = [r[k] for r in chunk_renderings]
708
+
709
+ if return_weights:
710
+ chunk_rendering['weights'] = gather(ray_history[-1]['weights'])
711
+ chunk_rendering['coord'] = gather(ray_history[-1]['coord'])
712
+ chunks.append(chunk_rendering)
713
+
714
+ # Concatenate all chunks within each leaf of a single pytree.
715
+ rendering = {}
716
+ for k in chunks[0].keys():
717
+ if isinstance(chunks[0][k], list):
718
+ rendering[k] = []
719
+ for i in range(len(chunks[0][k])):
720
+ rendering[k].append(torch.cat([item[k][i] for item in chunks]))
721
+ else:
722
+ rendering[k] = torch.cat([item[k] for item in chunks])
723
+
724
+ for k, z in rendering.items():
725
+ if not k.startswith('ray_'):
726
+ # Reshape 2D buffers into original image shape.
727
+ rendering[k] = z.reshape((height, width) + z.shape[1:])
728
+
729
+ # After all of the ray bundles have been concatenated together, extract a
730
+ # new random bundle (deterministically) from the concatenation that is the
731
+ # same size as one of the individual bundles.
732
+ keys = [k for k in rendering if k.startswith('ray_')]
733
+ if keys:
734
+ num_rays = rendering[keys[0]][0].shape[0]
735
+ ray_idx = torch.randperm(num_rays)
736
+ ray_idx = ray_idx[:config.vis_num_rays]
737
+ for k in keys:
738
+ rendering[k] = [r[ray_idx] for r in rendering[k]]
739
+ model.train()
740
+ return rendering
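
A minimal sketch of driving render_image from an eval loop (the Accelerator setup, model, test_batch, and config are assumed; this is not the repo's eval.py verbatim):

import accelerate

accelerator = accelerate.Accelerator()
model = accelerator.prepare(model)  # a constructed Model (hypothetical setup)
rendering = render_image(model, accelerator, test_batch,
                         rand=False, train_frac=1.0, config=config)
rgb = rendering['rgb']  # (height, width, 3), ready for metrics or saving
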
internal/pycolmap/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pyc
2
+ *.sw*
internal/pycolmap/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2018 True Price, UNC Chapel Hill
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
internal/pycolmap/README.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # pycolmap
2
+ Python interface for COLMAP reconstructions, plus some convenient scripts for loading/modifying/converting reconstructions.
3
+
4
+ This code does not, however, run reconstruction -- it only provides a convenient interface for handling COLMAP's output.
internal/pycolmap/pycolmap/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from camera import Camera
2
+ from database import COLMAPDatabase
3
+ from image import Image
4
+ from scene_manager import SceneManager
5
+ from rotation import Quaternion, DualQuaternion
internal/pycolmap/pycolmap/camera.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: True Price <jtprice at cs.unc.edu>
2
+
3
+ import numpy as np
4
+
5
+ from scipy.optimize import root
6
+
7
+
8
+ #-------------------------------------------------------------------------------
9
+ #
10
+ # camera distortion functions for arrays of size (..., 2)
11
+ #
12
+ #-------------------------------------------------------------------------------
13
+
14
+ def simple_radial_distortion(camera, x):
15
+ return x * (1. + camera.k1 * np.square(x).sum(axis=-1, keepdims=True))
16
+
17
+ def radial_distortion(camera, x):
18
+ r_sq = np.square(x).sum(axis=-1, keepdims=True)
19
+ return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq))
20
+
21
+ def opencv_distortion(camera, x):
22
+ x_sq = np.square(x)
23
+ xy = np.prod(x, axis=-1, keepdims=True)
24
+ r_sq = x_sq.sum(axis=-1, keepdims=True)
25
+
26
+ return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq)) + np.concatenate((
27
+ 2. * camera.p1 * xy + camera.p2 * (r_sq + 2. * x_sq[..., :1]),
28
+ camera.p1 * (r_sq + 2. * x_sq[..., 1:]) + 2. * camera.p2 * xy),
29
+ axis=-1)
30
+
31
+
32
+ #-------------------------------------------------------------------------------
33
+ #
34
+ # Camera
35
+ #
36
+ #-------------------------------------------------------------------------------
37
+
38
+ class Camera:
39
+ @staticmethod
40
+ def GetNumParams(type_):
41
+ if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
42
+ return 3
43
+ if type_ == 1 or type_ == 'PINHOLE':
44
+ return 4
45
+ if type_ == 2 or type_ == 'SIMPLE_RADIAL':
46
+ return 4
47
+ if type_ == 3 or type_ == 'RADIAL':
48
+ return 5
49
+ if type_ == 4 or type_ == 'OPENCV':
50
+ return 8
51
+ #if type_ == 5 or type_ == 'OPENCV_FISHEYE':
52
+ # return 8
53
+ #if type_ == 6 or type_ == 'FULL_OPENCV':
54
+ # return 12
55
+ #if type_ == 7 or type_ == 'FOV':
56
+ # return 5
57
+ #if type_ == 8 or type_ == 'SIMPLE_RADIAL_FISHEYE':
58
+ # return 4
59
+ #if type_ == 9 or type_ == 'RADIAL_FISHEYE':
60
+ # return 5
61
+ #if type_ == 10 or type_ == 'THIN_PRISM_FISHEYE':
62
+ # return 12
63
+
64
+ # TODO: not supporting other camera types, currently
65
+ raise Exception('Camera type not supported')
66
+
67
+
68
+ #---------------------------------------------------------------------------
69
+
70
+ @staticmethod
71
+ def GetNameFromType(type_):
72
+ if type_ == 0: return 'SIMPLE_PINHOLE'
73
+ if type_ == 1: return 'PINHOLE'
74
+ if type_ == 2: return 'SIMPLE_RADIAL'
75
+ if type_ == 3: return 'RADIAL'
76
+ if type_ == 4: return 'OPENCV'
77
+ #if type_ == 5: return 'OPENCV_FISHEYE'
78
+ #if type_ == 6: return 'FULL_OPENCV'
79
+ #if type_ == 7: return 'FOV'
80
+ #if type_ == 8: return 'SIMPLE_RADIAL_FISHEYE'
81
+ #if type_ == 9: return 'RADIAL_FISHEYE'
82
+ #if type_ == 10: return 'THIN_PRISM_FISHEYE'
83
+
84
+ raise Exception('Camera type not supported')
85
+
86
+
87
+ #---------------------------------------------------------------------------
88
+
89
+ def __init__(self, type_, width_, height_, params):
90
+ self.width = width_
91
+ self.height = height_
92
+
93
+ if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
94
+ self.fx, self.cx, self.cy = params
95
+ self.fy = self.fx
96
+ self.distortion_func = None
97
+ self.camera_type = 0
98
+
99
+ elif type_ == 1 or type_ == 'PINHOLE':
100
+ self.fx, self.fy, self.cx, self.cy = params
101
+ self.distortion_func = None
102
+ self.camera_type = 1
103
+
104
+ elif type_ == 2 or type_ == 'SIMPLE_RADIAL':
105
+ self.fx, self.cx, self.cy, self.k1 = params
106
+ self.fy = self.fx
107
+ self.distortion_func = simple_radial_distortion
108
+ self.camera_type = 2
109
+
110
+ elif type_ == 3 or type_ == 'RADIAL':
111
+ self.fx, self.cx, self.cy, self.k1, self.k2 = params
112
+ self.fy = self.fx
113
+ self.distortion_func = radial_distortion
114
+ self.camera_type = 3
115
+
116
+ elif type_ == 4 or type_ == 'OPENCV':
117
+ self.fx, self.fy, self.cx, self.cy = params[:4]
118
+ self.k1, self.k2, self.p1, self.p2 = params[4:]
119
+ self.distortion_func = opencv_distortion
120
+ self.camera_type = 4
121
+
122
+ else:
123
+ raise Exception('Camera type not supported')
124
+
125
+
126
+ #---------------------------------------------------------------------------
127
+
128
+ def __str__(self):
129
+ s = (self.GetNameFromType(self.camera_type) +
130
+ ' {} {} {}'.format(self.width, self.height, self.fx))
131
+
132
+ if self.camera_type in (1, 4): # PINHOLE, OPENCV
133
+ s += ' {}'.format(self.fy)
134
+
135
+ s += ' {} {}'.format(self.cx, self.cy)
136
+
137
+ if self.camera_type == 2: # SIMPLE_RADIAL
138
+ s += ' {}'.format(self.k1)
139
+
140
+ elif self.camera_type == 3: # RADIAL
141
+ s += ' {} {}'.format(self.k1, self.k2)
142
+
143
+ elif self.camera_type == 4: # OPENCV
144
+ s += ' {} {} {} {}'.format(self.k1, self.k2, self.p1, self.p2)
145
+
146
+ return s
147
+
148
+
149
+ #---------------------------------------------------------------------------
150
+
151
+ # return the camera parameters in the same order as the colmap output format
152
+ def get_params(self):
153
+ if self.camera_type == 0:
154
+ return np.array((self.fx, self.cx, self.cy))
155
+ if self.camera_type == 1:
156
+ return np.array((self.fx, self.fy, self.cx, self.cy))
157
+ if self.camera_type == 2:
158
+ return np.array((self.fx, self.cx, self.cy, self.k1))
159
+ if self.camera_type == 3:
160
+ return np.array((self.fx, self.cx, self.cy, self.k1, self.k2))
161
+ if self.camera_type == 4:
162
+ return np.array((self.fx, self.fy, self.cx, self.cy, self.k1,
163
+ self.k2, self.p1, self.p2))
164
+
165
+
166
+ #---------------------------------------------------------------------------
167
+
168
+ def get_camera_matrix(self):
169
+ return np.array(
170
+ ((self.fx, 0, self.cx), (0, self.fy, self.cy), (0, 0, 1)))
171
+
172
+ def get_inverse_camera_matrix(self):
173
+ return np.array(
174
+ ((1. / self.fx, 0, -self.cx / self.fx),
175
+ (0, 1. / self.fy, -self.cy / self.fy),
176
+ (0, 0, 1)))
177
+
178
+ @property
179
+ def K(self):
180
+ return self.get_camera_matrix()
181
+
182
+ @property
183
+ def K_inv(self):
184
+ return self.get_inverse_camera_matrix()
185
+
186
+ #---------------------------------------------------------------------------
187
+
188
+ # return the inverse camera matrix
189
+ def get_inv_camera_matrix(self):
190
+ inv_fx, inv_fy = 1. / self.fx, 1. / self.fy
191
+ return np.array(((inv_fx, 0, -inv_fx * self.cx),
192
+ (0, inv_fy, -inv_fy * self.cy),
193
+ (0, 0, 1)))
194
+
195
+
196
+ #---------------------------------------------------------------------------
197
+
198
+ # return an (x, y) pixel coordinate grid for this camera
199
+ def get_image_grid(self):
200
+ xmin = (0.5 - self.cx) / self.fx
201
+ xmax = (self.width - 0.5 - self.cx) / self.fx
202
+ ymin = (0.5 - self.cy) / self.fy
203
+ ymax = (self.height - 0.5 - self.cy) / self.fy
204
+ return np.meshgrid(np.linspace(xmin, xmax, self.width),
205
+ np.linspace(ymin, ymax, self.height))
206
+
207
+
208
+ #---------------------------------------------------------------------------
209
+
210
+ # x: array of shape (N,2) or (2,)
211
+ # normalized: False if the input points are in pixel coordinates
212
+ # denormalize: True if the points should be put back into pixel coordinates
213
+ def distort_points(self, x, normalized=True, denormalize=True):
214
+ x = np.atleast_2d(x)
215
+
216
+ # put the points into normalized camera coordinates
217
+ if not normalized:
218
+ x -= np.array([[self.cx, self.cy]])
219
+ x /= np.array([[self.fx, self.fy]])
220
+
221
+ # distort, if necessary
222
+ if self.distortion_func is not None:
223
+ x = self.distortion_func(self, x)
224
+
225
+ if denormalize:
226
+ x *= np.array([[self.fx, self.fy]])
227
+ x += np.array([[self.cx, self.cy]])
228
+
229
+ return x
230
+
231
+
232
+ #---------------------------------------------------------------------------
233
+
234
+ # x: array of shape (N1,N2,...,2), (N,2), or (2,)
235
+ # normalized: False if the input points are in pixel coordinates
236
+ # denormalize: True if the points should be put back into pixel coordinates
237
+ def undistort_points(self, x, normalized=False, denormalize=True):
238
+ x = np.atleast_2d(x)
239
+
240
+ # put the points into normalized camera coordinates
241
+ if not normalized:
242
+ x = x - np.array([self.cx, self.cy]) # creates a copy
243
+ x /= np.array([self.fx, self.fy])
244
+
245
+ # undistort, if necessary
246
+ if self.distortion_func is not None:
247
+ def objective(xu):
248
+ return (x - self.distortion_func(self, xu.reshape(*x.shape))
249
+ ).ravel()
250
+
251
+ xu = root(objective, x).x.reshape(*x.shape)
252
+ else:
253
+ xu = x
254
+
255
+ if denormalize:
256
+ xu *= np.array([[self.fx, self.fy]])
257
+ xu += np.array([[self.cx, self.cy]])
258
+
259
+ return xu
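
A quick round trip through the two helpers above, using a hypothetical SIMPLE_RADIAL camera; undistorting and then re-distorting should recover the input pixels up to solver tolerance:

import numpy as np

cam = Camera('SIMPLE_RADIAL', 640, 480, (500., 320., 240., 0.05))
pts = np.array([[100., 200.], [320., 240.]])
undist = cam.undistort_points(pts)                    # pixels -> undistorted pixels
redist = cam.distort_points(undist, normalized=False)
print(np.allclose(redist, pts, atol=1e-4))            # ideally True
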
internal/pycolmap/pycolmap/database.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import sqlite3
4
+
5
+
6
+ #-------------------------------------------------------------------------------
7
+ # convert SQLite BLOBs to/from numpy arrays
8
+
9
+ def array_to_blob(arr):
10
+ return np.asarray(arr).tobytes() # np.getbuffer was removed in Python 3
11
+
12
+ def blob_to_array(blob, dtype, shape=(-1,)):
13
+ return np.frombuffer(blob, dtype).reshape(*shape)
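
A tiny round trip for the blob helpers (with tobytes, which replaces the Python-2-only np.getbuffer):

import numpy as np

arr = np.array([[1., 2.], [3., 4.]])
blob = array_to_blob(arr)
print(np.allclose(blob_to_array(blob, np.float64, (-1, 2)), arr))  # True
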
14
+
15
+
16
+ #-------------------------------------------------------------------------------
17
+ # convert to/from image pair ids
18
+
19
+ MAX_IMAGE_ID = 2**31 - 1
20
+
21
+ def get_pair_id(image_id1, image_id2):
22
+ if image_id1 > image_id2:
23
+ image_id1, image_id2 = image_id2, image_id1
24
+ return image_id1 * MAX_IMAGE_ID + image_id2
25
+
26
+
27
+ def get_image_ids_from_pair_id(pair_id):
28
+ image_id2 = pair_id % MAX_IMAGE_ID
29
+ return (pair_id - image_id2) // MAX_IMAGE_ID, image_id2
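
A worked example of the packing: ids are ordered so the pairing is symmetric, then encoded base MAX_IMAGE_ID.

pid = get_pair_id(7, 3)             # reorders to (3, 7)
assert pid == 3 * MAX_IMAGE_ID + 7
assert get_image_ids_from_pair_id(pid) == (3, 7)
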
30
+
31
+
32
+ #-------------------------------------------------------------------------------
33
+ # create table commands
34
+
35
+ CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
36
+ camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
37
+ model INTEGER NOT NULL,
38
+ width INTEGER NOT NULL,
39
+ height INTEGER NOT NULL,
40
+ params BLOB,
41
+ prior_focal_length INTEGER NOT NULL)"""
42
+
43
+ CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
44
+ image_id INTEGER PRIMARY KEY NOT NULL,
45
+ rows INTEGER NOT NULL,
46
+ cols INTEGER NOT NULL,
47
+ data BLOB,
48
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
49
+
50
+ CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
51
+ image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
52
+ name TEXT NOT NULL UNIQUE,
53
+ camera_id INTEGER NOT NULL,
54
+ prior_qw REAL,
55
+ prior_qx REAL,
56
+ prior_qy REAL,
57
+ prior_qz REAL,
58
+ prior_tx REAL,
59
+ prior_ty REAL,
60
+ prior_tz REAL,
61
+ CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647),
62
+ FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))"""
63
+
64
+ CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries (
65
+ pair_id INTEGER PRIMARY KEY NOT NULL,
66
+ rows INTEGER NOT NULL,
67
+ cols INTEGER NOT NULL,
68
+ data BLOB,
69
+ config INTEGER NOT NULL,
70
+ F BLOB,
71
+ E BLOB,
72
+ H BLOB)"""
73
+
74
+ CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
75
+ image_id INTEGER PRIMARY KEY NOT NULL,
76
+ rows INTEGER NOT NULL,
77
+ cols INTEGER NOT NULL,
78
+ data BLOB,
79
+ FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""
80
+
81
+ CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
82
+ pair_id INTEGER PRIMARY KEY NOT NULL,
83
+ rows INTEGER NOT NULL,
84
+ cols INTEGER NOT NULL,
85
+ data BLOB)"""
86
+
87
+ CREATE_NAME_INDEX = \
88
+ "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"
89
+
90
+ CREATE_ALL = "; ".join([CREATE_CAMERAS_TABLE, CREATE_DESCRIPTORS_TABLE,
91
+ CREATE_IMAGES_TABLE, CREATE_INLIER_MATCHES_TABLE, CREATE_KEYPOINTS_TABLE,
92
+ CREATE_MATCHES_TABLE, CREATE_NAME_INDEX])
93
+
94
+
95
+ #-------------------------------------------------------------------------------
96
+ # functional interface for adding objects
97
+
98
+ def add_camera(db, model, width, height, params, prior_focal_length=False,
99
+ camera_id=None):
100
+ # TODO: Parameter count checks
101
+ params = np.asarray(params, np.float64)
102
+ db.execute("INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
103
+ (camera_id, model, width, height, array_to_blob(params),
104
+ prior_focal_length))
105
+
106
+
107
+ def add_descriptors(db, image_id, descriptors):
108
+ descriptors = np.ascontiguousarray(descriptors, np.uint8)
109
+ db.execute("INSERT INTO descriptors VALUES (?, ?, ?, ?)",
110
+ (image_id,) + descriptors.shape + (array_to_blob(descriptors),))
111
+
112
+
113
+ def add_image(db, name, camera_id, prior_q=np.zeros(4), prior_t=np.zeros(3),
114
+ image_id=None):
115
+ db.execute("INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
116
+ (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
117
+ prior_q[3], prior_t[0], prior_t[1], prior_t[2]))
118
+
119
+
120
+ # config: defaults to fundamental matrix
121
+ def add_inlier_matches(db, image_id1, image_id2, matches, config=2, F=None,
122
+ E=None, H=None):
123
+ assert(len(matches.shape) == 2)
124
+ assert(matches.shape[1] == 2)
125
+
126
+ if image_id1 > image_id2:
127
+ matches = matches[:,::-1]
128
+
129
+ if F is not None:
130
+ F = np.asarray(F, np.float64)
131
+ if E is not None:
132
+ E = np.asarray(E, np.float64)
133
+ if H is not None:
134
+ H = np.asarray(H, np.float64)
135
+
136
+ pair_id = get_pair_id(image_id1, image_id2)
137
+ matches = np.asarray(matches, np.uint32)
138
+ db.execute("INSERT INTO inlier_matches VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
139
+ (pair_id,) + matches.shape + (array_to_blob(matches), config, F, E, H))
140
+
141
+
142
+ def add_keypoints(db, image_id, keypoints):
143
+ assert(len(keypoints.shape) == 2)
144
+ assert(keypoints.shape[1] in [2, 4, 6])
145
+
146
+ keypoints = np.asarray(keypoints, np.float32)
147
+ db.execute("INSERT INTO keypoints VALUES (?, ?, ?, ?)",
148
+ (image_id,) + keypoints.shape + (array_to_blob(keypoints),))
149
+
150
+
151
+ # raw matches, without two-view geometry verification
152
+ def add_matches(db, image_id1, image_id2, matches):
153
+ assert(len(matches.shape) == 2)
154
+ assert(matches.shape[1] == 2)
155
+
156
+ if image_id1 > image_id2:
157
+ matches = matches[:,::-1]
158
+
159
+ pair_id = get_pair_id(image_id1, image_id2)
160
+ matches = np.asarray(matches, np.uint32)
161
+ db.execute("INSERT INTO matches VALUES (?, ?, ?, ?)",
162
+ (pair_id,) + matches.shape + (array_to_blob(matches),))
163
+
164
+
165
+ #-------------------------------------------------------------------------------
166
+ # simple functional interface
167
+
168
+ class COLMAPDatabase(sqlite3.Connection):
169
+ @staticmethod
170
+ def connect(database_path):
171
+ return sqlite3.connect(database_path, factory=COLMAPDatabase)
172
+
173
+
174
+ def __init__(self, *args, **kwargs):
175
+ super(COLMAPDatabase, self).__init__(*args, **kwargs)
176
+
177
+ self.initialize_tables = lambda: self.executescript(CREATE_ALL)
178
+
179
+ self.initialize_cameras = \
180
+ lambda: self.executescript(CREATE_CAMERAS_TABLE)
181
+ self.initialize_descriptors = \
182
+ lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
183
+ self.initialize_images = \
184
+ lambda: self.executescript(CREATE_IMAGES_TABLE)
185
+ self.initialize_inlier_matches = \
186
+ lambda: self.executescript(CREATE_INLIER_MATCHES_TABLE)
187
+ self.initialize_keypoints = \
188
+ lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
189
+ self.initialize_matches = \
190
+ lambda: self.executescript(CREATE_MATCHES_TABLE)
191
+
192
+ self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)
193
+
194
+
195
+ add_camera = add_camera
196
+ add_descriptors = add_descriptors
197
+ add_image = add_image
198
+ add_inlier_matches = add_inlier_matches
199
+ add_keypoints = add_keypoints
200
+ add_matches = add_matches
201
+
202
+
203
+ #-------------------------------------------------------------------------------
204
+
205
+ def main(args):
206
+ import os
207
+
208
+ if os.path.exists(args.database_path):
209
+ print("Error: database path already exists -- will not modify it.")
210
+ exit()
211
+
212
+ db = COLMAPDatabase.connect(args.database_path)
213
+
214
+ #
215
+ # for convenience, try creating all the tables upfront
216
+ #
217
+
218
+ db.initialize_tables()
219
+
220
+
221
+ #
222
+ # create dummy cameras
223
+ #
224
+
225
+ model1, w1, h1, params1 = 0, 1024, 768, np.array((1024., 512., 384.))
226
+ model2, w2, h2, params2 = 2, 1024, 768, np.array((1024., 512., 384., 0.1))
227
+
228
+ db.add_camera(model1, w1, h1, params1)
229
+ db.add_camera(model2, w2, h2, params2)
230
+
231
+
232
+ #
233
+ # create dummy images
234
+ #
235
+
236
+ db.add_image("image1.png", 0)
237
+ db.add_image("image2.png", 0)
238
+ db.add_image("image3.png", 2)
239
+ db.add_image("image4.png", 2)
240
+
241
+
242
+ #
243
+ # create dummy keypoints; note that COLMAP supports 2D keypoints (x, y),
244
+ # 4D keypoints (x, y, theta, scale), and 6D affine keypoints
245
+ # (x, y, a_11, a_12, a_21, a_22)
246
+ #
247
+
248
+ N = 1000
249
+ kp1 = np.random.rand(N, 2) * (1024., 768.)
250
+ kp2 = np.random.rand(N, 2) * (1024., 768.)
251
+ kp3 = np.random.rand(N, 2) * (1024., 768.)
252
+ kp4 = np.random.rand(N, 2) * (1024., 768.)
253
+
254
+ db.add_keypoints(1, kp1)
255
+ db.add_keypoints(2, kp2)
256
+ db.add_keypoints(3, kp3)
257
+ db.add_keypoints(4, kp4)
258
+
259
+
260
+ #
261
+ # create dummy matches
262
+ #
263
+
264
+ M = 50
265
+ m12 = np.random.randint(N, size=(M, 2))
266
+ m23 = np.random.randint(N, size=(M, 2))
267
+ m34 = np.random.randint(N, size=(M, 2))
268
+
269
+ db.add_matches(1, 2, m12)
270
+ db.add_matches(2, 3, m23)
271
+ db.add_matches(3, 4, m34)
272
+
273
+
274
+ #
275
+ # check cameras
276
+ #
277
+
278
+ rows = db.execute("SELECT * FROM cameras")
279
+
280
+ camera_id, model, width, height, params, prior = next(rows)
281
+ params = blob_to_array(params, np.float64)
282
+ assert model == model1 and width == w1 and height == h1
283
+ assert np.allclose(params, params1)
284
+
285
+ camera_id, model, width, height, params, prior = next(rows)
286
+ params = blob_to_array(params, np.float64)
287
+ assert model == model2 and width == w2 and height == h2
288
+ assert np.allclose(params, params2)
289
+
290
+
291
+ #
292
+ # check keypoints
293
+ #
294
+
295
+ kps = dict(
296
+ (image_id, blob_to_array(data, np.float32, (-1, 2)))
297
+ for image_id, data in db.execute(
298
+ "SELECT image_id, data FROM keypoints"))
299
+
300
+ assert np.allclose(kps[1], kp1)
301
+ assert np.allclose(kps[2], kp2)
302
+ assert np.allclose(kps[3], kp3)
303
+ assert np.allclose(kps[4], kp4)
304
+
305
+
306
+ #
307
+ # check matches
308
+ #
309
+
310
+ pair_ids = [get_pair_id(*pair) for pair in [(1, 2), (2, 3), (3, 4)]]
311
+
312
+ matches = dict(
313
+ (get_image_ids_from_pair_id(pair_id),
314
+ blob_to_array(data, np.uint32, (-1, 2)))
315
+ for pair_id, data in db.execute("SELECT pair_id, data FROM matches"))
316
+
317
+ assert np.all(matches[(1, 2)] == m12)
318
+ assert np.all(matches[(2, 3)] == m23)
319
+ assert np.all(matches[(3, 4)] == m34)
320
+
321
+ #
322
+ # clean up
323
+ #
324
+
325
+ db.close()
326
+ os.remove(args.database_path)
327
+
328
+ #-------------------------------------------------------------------------------
329
+
330
+ if __name__ == "__main__":
331
+ import argparse
332
+
333
+ parser = argparse.ArgumentParser(
334
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
335
+
336
+ parser.add_argument("--database_path", type=str, default="database.db")
337
+
338
+ args = parser.parse_args()
339
+
340
+ main(args)
internal/pycolmap/pycolmap/image.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: True Price <jtprice at cs.unc.edu>
2
+
3
+ import numpy as np
4
+
5
+ #-------------------------------------------------------------------------------
6
+ #
7
+ # Image
8
+ #
9
+ #-------------------------------------------------------------------------------
10
+
11
+ class Image:
12
+ def __init__(self, name_, camera_id_, q_, tvec_):
13
+ self.name = name_
14
+ self.camera_id = camera_id_
15
+ self.q = q_
16
+ self.tvec = tvec_
17
+
18
+ self.points2D = np.empty((0, 2), dtype=np.float64)
19
+ self.point3D_ids = np.empty((0,), dtype=np.uint64)
20
+
21
+ #---------------------------------------------------------------------------
22
+
23
+ def R(self):
24
+ return self.q.ToR()
25
+
26
+ #---------------------------------------------------------------------------
27
+
28
+ def C(self):
29
+ return -self.R().T.dot(self.tvec)
30
+
31
+ #---------------------------------------------------------------------------
32
+
33
+ @property
34
+ def t(self):
35
+ return self.tvec
internal/pycolmap/pycolmap/rotation.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: True Price <jtprice at cs.unc.edu>
2
+
3
+ import numpy as np
4
+
5
+ #-------------------------------------------------------------------------------
6
+ #
7
+ # Axis-Angle Functions
8
+ #
9
+ #-------------------------------------------------------------------------------
10
+
11
+ # returns the cross product matrix representation of a 3-vector v
12
+ def cross_prod_matrix(v):
13
+ return np.array(((0., -v[2], v[1]), (v[2], 0., -v[0]), (-v[1], v[0], 0.)))
14
+
15
+ #-------------------------------------------------------------------------------
16
+
17
+ # www.euclideanspace.com/maths/geometry/rotations/conversions/angleToMatrix/
18
+ # if angle is None, assume ||axis|| == angle, in radians
19
+ # if angle is not None, assume that axis is a unit vector
20
+ def axis_angle_to_rotation_matrix(axis, angle=None):
21
+ if angle is None:
22
+ angle = np.linalg.norm(axis)
23
+ if np.abs(angle) > np.finfo('float').eps:
24
+ axis = axis / angle
25
+
26
+ cp_axis = cross_prod_matrix(axis)
27
+ return np.eye(3) + (
28
+ np.sin(angle) * cp_axis + (1. - np.cos(angle)) * cp_axis.dot(cp_axis))
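
A quick sanity check of the Rodrigues formula above: a rotation of pi/2 about z sends x to y.

import numpy as np

R = axis_angle_to_rotation_matrix(np.array([0., 0., np.pi / 2]))
print(np.allclose(R.dot([1., 0., 0.]), [0., 1., 0.]))  # True
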
29
+
30
+ #-------------------------------------------------------------------------------
31
+
32
+ # after some deliberation, I've decided the easiest way to do this is to use
33
+ # quaternions as an intermediary
34
+ def rotation_matrix_to_axis_angle(R):
35
+ return Quaternion.FromR(R).ToAxisAngle()
36
+
37
+ #-------------------------------------------------------------------------------
38
+ #
39
+ # Quaternion
40
+ #
41
+ #-------------------------------------------------------------------------------
42
+
43
+ class Quaternion:
44
+ # create a quaternion from an existing rotation matrix
45
+ # euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/
46
+ @staticmethod
47
+ def FromR(R):
48
+ trace = np.trace(R)
49
+
50
+ if trace > 0:
51
+ qw = 0.5 * np.sqrt(1. + trace)
52
+ qx = (R[2,1] - R[1,2]) * 0.25 / qw
53
+ qy = (R[0,2] - R[2,0]) * 0.25 / qw
54
+ qz = (R[1,0] - R[0,1]) * 0.25 / qw
55
+ elif R[0,0] > R[1,1] and R[0,0] > R[2,2]:
56
+ s = 2. * np.sqrt(1. + R[0,0] - R[1,1] - R[2,2])
57
+ qw = (R[2,1] - R[1,2]) / s
58
+ qx = 0.25 * s
59
+ qy = (R[0,1] + R[1,0]) / s
60
+ qz = (R[0,2] + R[2,0]) / s
61
+ elif R[1,1] > R[2,2]:
62
+ s = 2. * np.sqrt(1. + R[1,1] - R[0,0] - R[2,2])
63
+ qw = (R[0,2] - R[2,0]) / s
64
+ qx = (R[0,1] + R[1,0]) / s
65
+ qy = 0.25 * s
66
+ qz = (R[1,2] + R[2,1]) / s
67
+ else:
68
+ s = 2. * np.sqrt(1. + R[2,2] - R[0,0] - R[1,1])
69
+ qw = (R[1,0] - R[0,1]) / s
70
+ qx = (R[0,2] + R[2,0]) / s
71
+ qy = (R[1,2] + R[2,1]) / s
72
+ qz = 0.25 * s
73
+
74
+ return Quaternion(np.array((qw, qx, qy, qz)))
75
+
76
+ # if angle is None, assume ||axis|| == angle, in radians
77
+ # if angle is not None, assume that axis is a unit vector
78
+ @staticmethod
79
+ def FromAxisAngle(axis, angle=None):
80
+ if angle is None:
81
+ angle = np.linalg.norm(axis)
82
+ if np.abs(angle) > np.finfo('float').eps:
83
+ axis = axis / angle
84
+
85
+ qw = np.cos(0.5 * angle)
86
+ axis = axis * np.sin(0.5 * angle)
87
+
88
+ return Quaternion(np.array((qw, axis[0], axis[1], axis[2])))
89
+
90
+ #---------------------------------------------------------------------------
91
+
92
+ def __init__(self, q=np.array((1., 0., 0., 0.))):
93
+ if isinstance(q, Quaternion):
94
+ self.q = q.q.copy()
95
+ else:
96
+ q = np.asarray(q)
97
+ if q.size == 4:
98
+ self.q = q.copy()
99
+ elif q.size == 3: # convert from a 3-vector to a quaternion
100
+ self.q = np.empty(4)
101
+ self.q[0], self.q[1:] = 0., q.ravel()
102
+ else:
103
+ raise Exception('Input quaternion should be a 3- or 4-vector')
104
+
105
+ def __add__(self, other):
106
+ return Quaternion(self.q + other.q)
107
+
108
+ def __iadd__(self, other):
109
+ self.q += other.q
110
+ return self
111
+
112
+ # conjugation via the ~ operator
113
+ def __invert__(self):
114
+ return Quaternion(
115
+ np.array((self.q[0], -self.q[1], -self.q[2], -self.q[3])))
116
+
117
+ # returns: self.q * other.q if other is a Quaternion; otherwise performs
118
+ # scalar multiplication
119
+ def __mul__(self, other):
120
+ if isinstance(other, Quaternion): # quaternion multiplication
121
+ return Quaternion(np.array((
122
+ self.q[0] * other.q[0] - self.q[1] * other.q[1] -
123
+ self.q[2] * other.q[2] - self.q[3] * other.q[3],
124
+ self.q[0] * other.q[1] + self.q[1] * other.q[0] +
125
+ self.q[2] * other.q[3] - self.q[3] * other.q[2],
126
+ self.q[0] * other.q[2] - self.q[1] * other.q[3] +
127
+ self.q[2] * other.q[0] + self.q[3] * other.q[1],
128
+ self.q[0] * other.q[3] + self.q[1] * other.q[2] -
129
+ self.q[2] * other.q[1] + self.q[3] * other.q[0])))
130
+ else: # scalar multiplication (assumed)
131
+ return Quaternion(other * self.q)
132
+
133
+ def __rmul__(self, other):
134
+ return self * other
135
+
136
+ def __imul__(self, other):
137
+ self.q[:] = (self * other).q
138
+ return self
139
+
140
+ def __irmul__(self, other):
141
+ self.q[:] = (self * other).q
142
+ return self
143
+
144
+ def __neg__(self):
145
+ return Quaternion(-self.q)
146
+
147
+ def __sub__(self, other):
148
+ return Quaternion(self.q - other.q)
149
+
150
+ def __isub__(self, other):
151
+ self.q -= other.q
152
+ return self
153
+
154
+ def __str__(self):
155
+ return str(self.q)
156
+
157
+ def copy(self):
158
+ return Quaternion(self)
159
+
160
+ def dot(self, other):
161
+ return self.q.dot(other.q)
162
+
163
+ # assume the quaternion is nonzero!
164
+ def inverse(self):
165
+ return Quaternion((~self).q / self.q.dot(self.q))
166
+
167
+ def norm(self):
168
+ return np.linalg.norm(self.q)
169
+
170
+ def normalize(self):
171
+ self.q /= np.linalg.norm(self.q)
172
+ return self
173
+
174
+ # assume x is a Nx3 numpy array or a numpy 3-vector
175
+ def rotate_points(self, x):
176
+ x = np.atleast_2d(x)
177
+ return x.dot(self.ToR().T)
178
+
179
+ # convert to a rotation matrix
180
+ def ToR(self):
181
+ return np.eye(3) + 2 * np.array((
182
+ (-self.q[2] * self.q[2] - self.q[3] * self.q[3],
183
+ self.q[1] * self.q[2] - self.q[3] * self.q[0],
184
+ self.q[1] * self.q[3] + self.q[2] * self.q[0]),
185
+ ( self.q[1] * self.q[2] + self.q[3] * self.q[0],
186
+ -self.q[1] * self.q[1] - self.q[3] * self.q[3],
187
+ self.q[2] * self.q[3] - self.q[1] * self.q[0]),
188
+ ( self.q[1] * self.q[3] - self.q[2] * self.q[0],
189
+ self.q[2] * self.q[3] + self.q[1] * self.q[0],
190
+ -self.q[1] * self.q[1] - self.q[2] * self.q[2])))
191
+
192
+ # convert to axis-angle representation, with angle encoded by the length
193
+ def ToAxisAngle(self):
194
+ # recall that for axis-angle representation (a, angle), with "a" unit:
195
+ # q = (cos(angle/2), a * sin(angle/2))
196
+ # below, for readability, "theta" actually means half of the angle
197
+
198
+ sin_sq_theta = self.q[1:].dot(self.q[1:])
199
+
200
+ # if theta is non-zero, then we can compute a unique rotation
201
+ if np.abs(sin_sq_theta) > np.finfo('float').eps:
202
+ sin_theta = np.sqrt(sin_sq_theta)
203
+ cos_theta = self.q[0]
204
+
205
+ # atan2 is more stable, so we use it to compute theta
206
+ # note that we multiply by 2 to get the actual angle
207
+ angle = 2. * (
208
+ np.arctan2(-sin_theta, -cos_theta) if cos_theta < 0. else
209
+ np.arctan2(sin_theta, cos_theta))
210
+
211
+ return self.q[1:] * (angle / sin_theta)
212
+
213
+ # otherwise, the result is singular, and we avoid dividing by
214
+ # sin(angle/2) = 0
215
+ return np.zeros(3)
216
+
217
+ # euclideanspace.com/maths/geometry/rotations/conversions/quaternionToEuler
218
+ # this assumes the quaternion is non-zero
219
+ # returns yaw, pitch, roll, with application in that order
220
+ def ToEulerAngles(self):
221
+ qsq = self.q**2
222
+ k = 2. * (self.q[0] * self.q[3] + self.q[1] * self.q[2]) / qsq.sum()
223
+
224
+ if (1. - k) < np.finfo('float').eps: # north pole singularity
225
+ return 2. * np.arctan2(self.q[1], self.q[0]), 0.5 * np.pi, 0.
226
+ if (1. + k) < np.finfo('float').eps: # south pole singularity
227
+ return -2. * np.arctan2(self.q[1], self.q[0]), -0.5 * np.pi, 0.
228
+
229
+ yaw = np.arctan2(2. * (self.q[0] * self.q[2] - self.q[1] * self.q[3]),
230
+ qsq[0] + qsq[1] - qsq[2] - qsq[3])
231
+ pitch = np.arcsin(k)
232
+ roll = np.arctan2(2. * (self.q[0] * self.q[1] - self.q[2] * self.q[3]),
233
+ qsq[0] - qsq[1] + qsq[2] - qsq[3])
234
+
235
+ return yaw, pitch, roll
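
A small round-trip check for the conversions above (pi/2 about z):

import numpy as np

q = Quaternion.FromAxisAngle(np.array([0., 0., 1.]), np.pi / 2)
R = q.ToR()
print(np.allclose(R.dot([1., 0., 0.]), [0., 1., 0.]))  # True: x -> y
print(np.allclose(Quaternion.FromR(R).q, q.q))         # True
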
236
+
237
+ #-------------------------------------------------------------------------------
238
+ #
239
+ # DualQuaternion
240
+ #
241
+ #-------------------------------------------------------------------------------
242
+
243
+ class DualQuaternion:
244
+ # DualQuaternion from an existing rotation + translation
245
+ @staticmethod
246
+ def FromQT(q, t):
247
+ return DualQuaternion(qe=(0.5 * np.asarray(t))) * DualQuaternion(q)
248
+
249
+ def __init__(self, q0=np.array((1., 0., 0., 0.)), qe=np.zeros(4)):
250
+ self.q0, self.qe = Quaternion(q0), Quaternion(qe)
251
+
252
+ def __add__(self, other):
253
+ return DualQuaternion(self.q0 + other.q0, self.qe + other.qe)
254
+
255
+ def __iadd__(self, other):
256
+ self.q0 += other.q0
257
+ self.qe += other.qe
258
+ return self
259
+
260
+ # conjugation via the ~ operator
261
+ def __invert__(self):
262
+ return DualQuaternion(~self.q0, ~self.qe)
263
+
264
+ def __mul__(self, other):
265
+ if isinstance(other, DualQuaternion):
266
+ return DualQuaternion(
267
+ self.q0 * other.q0,
268
+ self.q0 * other.qe + self.qe * other.q0)
269
+ elif isinstance(other, complex): # multiplication by a dual number
270
+ return DualQuaternion(
271
+ self.q0 * other.real,
272
+ self.q0 * other.imag + self.qe * other.real)
273
+ else: # scalar multiplication (assumed)
274
+ return DualQuaternion(other * self.q0, other * self.qe)
275
+
276
+ def __rmul__(self, other):
277
+ return self.__mul__(other)
278
+
279
+ def __imul__(self, other):
280
+ tmp = self * other
281
+ self.q0, self.qe = tmp.q0, tmp.qe
282
+ return self
283
+
284
+ def __neg__(self):
285
+ return DualQuaternion(-self.q0, -self.qe)
286
+
287
+ def __sub__(self, other):
288
+ return DualQuaternion(self.q0 - other.q0, self.qe - other.qe)
289
+
290
+ def __isub__(self, other):
291
+ self.q0 -= other.q0
292
+ self.qe -= other.qe
293
+ return self
294
+
295
+ # q^-1 = q* / ||q||^2
296
+ # assume that q0 is nonzero!
297
+ def inverse(self):
298
+ normsq = complex(self.q0.dot(self.q0), 2. * self.q0.q.dot(self.qe.q))
299
+ inv_len_real = 1. / normsq.real
300
+ return ~self * complex(
301
+ inv_len_real, -normsq.imag * inv_len_real * inv_len_real)
302
+
303
+ # returns a complex representation of the real and imaginary parts of the norm
304
+ # assume that q0 is nonzero!
305
+ def norm(self):
306
+ q0_norm = self.q0.norm()
307
+ return complex(q0_norm, self.q0.dot(self.qe) / q0_norm)
308
+
309
+ # assume that q0 is nonzero!
310
+ def normalize(self):
311
+ # current length is ||q0|| + eps * (<q0, qe> / ||q0||)
312
+ # writing this as a + eps * b, the inverse is
313
+ # 1/||q|| = 1/a - eps * b / a^2
314
+ norm = self.norm()
315
+ inv_len_real = 1. / norm.real
316
+ self *= complex(inv_len_real, -norm.imag * inv_len_real * inv_len_real)
317
+ return self
318
+
319
+ # return the translation vector for this dual quaternion
320
+ def getT(self):
321
+ return 2 * (self.qe * ~self.q0).q[1:]
322
+
323
+ def ToQT(self):
324
+ return self.q0, self.getT()
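
A minimal usage sketch of the rotation API above (a sketch only: it assumes the pycolmap package directory is importable, as the bundled tools arrange via sys.path, and that ToR is exposed as used by SceneManager.get_viewed_points; the values are illustrative):

    import numpy as np
    from pycolmap import Quaternion

    # A 90-degree rotation about z in COLMAP's (qw, qx, qy, qz) order.
    q = Quaternion(np.array([np.cos(np.pi / 4), 0., 0., np.sin(np.pi / 4)]))
    R = q.ToR()             # 3x3 rotation matrix
    aa = q.ToAxisAngle()    # approximately [0, 0, pi/2]: axis scaled by angle
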
internal/pycolmap/pycolmap/scene_manager.py ADDED
@@ -0,0 +1,670 @@
1
+ # Author: True Price <jtprice at cs.unc.edu>
2
+
3
+ import array
4
+ import numpy as np
5
+ import os
6
+ import struct
7
+
8
+ from collections import OrderedDict, defaultdict
9
+ from itertools import combinations
10
+
11
+ from .camera import Camera
12
+ from .image import Image
13
+ from .rotation import Quaternion
14
+
15
+ #-------------------------------------------------------------------------------
16
+ #
17
+ # SceneManager
18
+ #
19
+ #-------------------------------------------------------------------------------
20
+
21
+ class SceneManager:
22
+ INVALID_POINT3D = np.uint64(-1)
23
+
24
+ def __init__(self, colmap_results_folder, image_path=None):
25
+ self.folder = colmap_results_folder
26
+ if not self.folder.endswith('/'):
27
+ self.folder += '/'
28
+
29
+ self.image_path = None
30
+ self.load_colmap_project_file(image_path=image_path)
31
+
32
+ self.cameras = OrderedDict()
33
+ self.images = OrderedDict()
34
+ self.name_to_image_id = dict()
35
+
36
+ self.last_camera_id = 0
37
+ self.last_image_id = 0
38
+
39
+ # Nx3 array of point3D xyz's
40
+ self.points3D = np.zeros((0, 3))
41
+
42
+ # for each element in points3D, stores the id of the point
43
+ self.point3D_ids = np.empty(0)
44
+
45
+ # point3D_id => index in self.points3D
46
+ self.point3D_id_to_point3D_idx = dict()
47
+
48
+ # point3D_id => [(image_id, point2D idx in image)]
49
+ self.point3D_id_to_images = dict()
50
+
51
+ self.point3D_colors = np.zeros((0, 3), dtype=np.uint8)
52
+ self.point3D_errors = np.zeros(0)
53
+
54
+ #---------------------------------------------------------------------------
55
+
56
+ def load_colmap_project_file(self, project_file=None, image_path=None):
57
+ if project_file is None:
58
+ project_file = self.folder + 'project.ini'
59
+
60
+ self.image_path = image_path
61
+
62
+ if self.image_path is None:
63
+ try:
64
+ with open(project_file, 'r') as f:
65
+ for line in iter(f.readline, ''):
66
+ if line.startswith('image_path'):
67
+ self.image_path = line[11:].strip()
68
+ break
69
+ except:
70
+ pass
71
+
72
+ if self.image_path is None:
73
+ print('Warning: image_path not found for reconstruction')
74
+ elif not self.image_path.endswith('/'):
75
+ self.image_path += '/'
76
+
77
+ #---------------------------------------------------------------------------
78
+
79
+ def load(self):
80
+ self.load_cameras()
81
+ self.load_images()
82
+ self.load_points3D()
83
+
84
+ #---------------------------------------------------------------------------
85
+
86
+ def load_cameras(self, input_file=None):
87
+ if input_file is None:
88
+ input_file = self.folder + 'cameras.bin'
89
+ if os.path.exists(input_file):
90
+ self._load_cameras_bin(input_file)
91
+ else:
92
+ input_file = self.folder + 'cameras.txt'
93
+ if os.path.exists(input_file):
94
+ self._load_cameras_txt(input_file)
95
+ else:
96
+ raise IOError('no cameras file found')
97
+
98
+ def _load_cameras_bin(self, input_file):
99
+ self.cameras = OrderedDict()
100
+
101
+ with open(input_file, 'rb') as f:
102
+ num_cameras = struct.unpack('L', f.read(8))[0]
103
+
104
+ for _ in range(num_cameras):
105
+ camera_id, camera_type, w, h = struct.unpack('IiLL', f.read(24))
106
+ num_params = Camera.GetNumParams(camera_type)
107
+ params = struct.unpack('d' * num_params, f.read(8 * num_params))
108
+ self.cameras[camera_id] = Camera(camera_type, w, h, params)
109
+ self.last_camera_id = max(self.last_camera_id, camera_id)
110
+
111
+ def _load_cameras_txt(self, input_file):
112
+ self.cameras = OrderedDict()
113
+
114
+ with open(input_file, 'r') as f:
115
+ for line in iter(lambda: f.readline().strip(), ''):
116
+ if not line or line.startswith('#'):
117
+ continue
118
+
119
+ data = line.split()
120
+ camera_id = int(data[0])
121
+ self.cameras[camera_id] = Camera(
122
+ data[1], int(data[2]), int(data[3]), [float(x) for x in data[4:]])
123
+ self.last_camera_id = max(self.last_camera_id, camera_id)
124
+
125
+ #---------------------------------------------------------------------------
126
+
127
+ def load_images(self, input_file=None):
128
+ if input_file is None:
129
+ input_file = self.folder + 'images.bin'
130
+ if os.path.exists(input_file):
131
+ self._load_images_bin(input_file)
132
+ else:
133
+ input_file = self.folder + 'images.txt'
134
+ if os.path.exists(input_file):
135
+ self._load_images_txt(input_file)
136
+ else:
137
+ raise IOError('no images file found')
138
+
139
+ def _load_images_bin(self, input_file):
140
+ self.images = OrderedDict()
141
+
142
+ with open(input_file, 'rb') as f:
143
+ num_images = struct.unpack('L', f.read(8))[0]
144
+ image_struct = struct.Struct('<I 4d 3d I')
145
+ for _ in range(num_images):
146
+ data = image_struct.unpack(f.read(image_struct.size))
147
+ image_id = data[0]
148
+ q = Quaternion(np.array(data[1:5]))
149
+ t = np.array(data[5:8])
150
+ camera_id = data[8]
151
+ name = b''.join(c for c in iter(lambda: f.read(1), b'\x00')).decode()
152
+
153
+ image = Image(name, camera_id, q, t)
154
+ num_points2D = struct.unpack('Q', f.read(8))[0]
155
+
156
+ # Optimized code below.
157
+ # Read all elements as double first, then convert to array, slice it
158
+ # into points2d and ids, and convert ids back to unsigned long longs
159
+ # ('Q'). This is significantly faster than using O(num_points2D) f.read
160
+ # calls, experiments show >7x improvements in 60 image model, 23s -> 3s.
161
+ points_array = array.array('d')
162
+ points_array.fromfile(f, 3 * num_points2D)
163
+ points_elements = np.array(points_array).reshape((num_points2D, 3))
164
+ image.points2D = points_elements[:, :2]
165
+
166
+ ids_array = array.array('Q')
167
+ ids_array.frombytes(points_elements[:, 2].tobytes())
168
+ image.point3D_ids = np.array(ids_array, dtype=np.uint64).reshape(
169
+ (num_points2D,))
170
+
171
+ # automatically remove points without an associated 3D point
172
+ #mask = (image.point3D_ids != SceneManager.INVALID_POINT3D)
173
+ #image.points2D = image.points2D[mask]
174
+ #image.point3D_ids = image.point3D_ids[mask]
175
+
176
+ self.images[image_id] = image
177
+ self.name_to_image_id[image.name] = image_id
178
+
179
+ self.last_image_id = max(self.last_image_id, image_id)
180
+
181
+ def _load_images_txt(self, input_file):
182
+ self.images = OrderedDict()
183
+
184
+ with open(input_file, 'r') as f:
185
+ is_camera_description_line = False
186
+
187
+ for line in iter(lambda: f.readline().strip(), ''):
188
+ if not line or line.startswith('#'):
189
+ continue
190
+
191
+ is_camera_description_line = not is_camera_description_line
192
+
193
+ data = line.split()
194
+
195
+ if is_camera_description_line:
196
+ image_id = int(data[0])
197
+ image = Image(data[-1], int(data[-2]),
198
+ Quaternion(np.array([float(x) for x in data[1:5]])),
199
+ np.array([float(x) for x in data[5:8]]))
200
+ else:
201
+ image.points2D = np.array(
202
+ [[float(x) for x in data[::3]], [float(x) for x in data[1::3]]]).T
203
+ image.point3D_ids = np.array([np.uint64(x) for x in data[2::3]])
204
+
205
+ # automatically remove points without an associated 3D point
206
+ #mask = (image.point3D_ids != SceneManager.INVALID_POINT3D)
207
+ #image.points2D = image.points2D[mask]
208
+ #image.point3D_ids = image.point3D_ids[mask]
209
+
210
+ self.images[image_id] = image
211
+ self.name_to_image_id[image.name] = image_id
212
+
213
+ self.last_image_id = max(self.last_image_id, image_id)
214
+
215
+ #---------------------------------------------------------------------------
216
+
217
+ def load_points3D(self, input_file=None):
218
+ if input_file is None:
219
+ input_file = self.folder + 'points3D.bin'
220
+ if os.path.exists(input_file):
221
+ self._load_points3D_bin(input_file)
222
+ else:
223
+ input_file = self.folder + 'points3D.txt'
224
+ if os.path.exists(input_file):
225
+ self._load_points3D_txt(input_file)
226
+ else:
227
+ raise IOError('no points3D file found')
228
+
229
+ def _load_points3D_bin(self, input_file):
230
+ with open(input_file, 'rb') as f:
231
+ num_points3D = struct.unpack('L', f.read(8))[0]
232
+
233
+ self.points3D = np.empty((num_points3D, 3))
234
+ self.point3D_ids = np.empty(num_points3D, dtype=np.uint64)
235
+ self.point3D_colors = np.empty((num_points3D, 3), dtype=np.uint8)
236
+ self.point3D_id_to_point3D_idx = dict()
237
+ self.point3D_id_to_images = dict()
238
+ self.point3D_errors = np.empty(num_points3D)
239
+
240
+ data_struct = struct.Struct('<Q 3d 3B d Q')
241
+
242
+ for i in range(num_points3D):
243
+ data = data_struct.unpack(f.read(data_struct.size))
244
+ self.point3D_ids[i] = data[0]
245
+ self.points3D[i] = data[1:4]
246
+ self.point3D_colors[i] = data[4:7]
247
+ self.point3D_errors[i] = data[7]
248
+ track_len = data[8]
249
+
250
+ self.point3D_id_to_point3D_idx[self.point3D_ids[i]] = i
251
+
252
+ data = struct.unpack(f'{2*track_len}I', f.read(2 * track_len * 4))
253
+
254
+ self.point3D_id_to_images[self.point3D_ids[i]] = \
255
+ np.array(data, dtype=np.uint32).reshape(track_len, 2)
256
+
257
+ def _load_points3D_txt(self, input_file):
258
+ self.points3D = []
259
+ self.point3D_ids = []
260
+ self.point3D_colors = []
261
+ self.point3D_id_to_point3D_idx = dict()
262
+ self.point3D_id_to_images = dict()
263
+ self.point3D_errors = []
264
+
265
+ with open(input_file, 'r') as f:
266
+ for line in iter(lambda: f.readline().strip(), ''):
267
+ if not line or line.startswith('#'):
268
+ continue
269
+
270
+ data = line.split()
271
+ point3D_id = np.uint64(data[0])
272
+
273
+ self.point3D_ids.append(point3D_id)
274
+ self.point3D_id_to_point3D_idx[point3D_id] = len(self.points3D)
275
+ self.points3D.append([np.float64(x) for x in data[1:4]])
276
+ self.point3D_colors.append([np.uint8(x) for x in data[4:7]])
277
+ self.point3D_errors.append(np.float64(data[7]))
278
+
279
+ # load (image id, point2D idx) pairs
280
+ self.point3D_id_to_images[point3D_id] = \
281
+ np.array([np.uint32(x) for x in data[8:]]).reshape(-1, 2)
282
+
283
+ self.points3D = np.array(self.points3D)
284
+ self.point3D_ids = np.array(self.point3D_ids)
285
+ self.point3D_colors = np.array(self.point3D_colors)
286
+ self.point3D_errors = np.array(self.point3D_errors)
287
+
288
+ #---------------------------------------------------------------------------
289
+
290
+ def save(self, output_folder, binary=True):
291
+ self.save_cameras(output_folder, binary=binary)
292
+ self.save_images(output_folder, binary=binary)
293
+ self.save_points3D(output_folder, binary=binary)
294
+
295
+ #---------------------------------------------------------------------------
296
+
297
+ def save_cameras(self, output_folder, output_file=None, binary=True):
298
+ if not os.path.exists(output_folder):
299
+ os.makedirs(output_folder)
300
+
301
+ if output_file is None:
302
+ output_file = 'cameras.bin' if binary else 'cameras.txt'
303
+
304
+ output_file = os.path.join(output_folder, output_file)
305
+
306
+ if binary:
307
+ self._save_cameras_bin(output_file)
308
+ else:
309
+ self._save_cameras_txt(output_file)
310
+
311
+ def _save_cameras_bin(self, output_file):
312
+ with open(output_file, 'wb') as fid:
313
+ fid.write(struct.pack('L', len(self.cameras)))
314
+
315
+ camera_struct = struct.Struct('IiLL')
316
+
317
+ for camera_id, camera in sorted(self.cameras.items()):
318
+ fid.write(camera_struct.pack(
319
+ camera_id, camera.camera_type, camera.width, camera.height))
320
+ # TODO (True): should move this into the Camera class
321
+ fid.write(camera.get_params().tobytes())
322
+
323
+ def _save_cameras_txt(self, output_file):
324
+ with open(output_file, 'w') as fid:
325
+ print('# Camera list with one line of data per camera:', file=fid)
326
+ print('# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]', file=fid)
327
+ print('# Number of cameras:', len(self.cameras), file=fid)
328
+
329
+ for camera_id, camera in sorted(self.cameras.items()):
330
+ print(camera_id, camera, file=fid)
331
+
332
+ #---------------------------------------------------------------------------
333
+
334
+ def save_images(self, output_folder, output_file=None, binary=True):
335
+ if not os.path.exists(output_folder):
336
+ os.makedirs(output_folder)
337
+
338
+ if output_file is None:
339
+ output_file = 'images.bin' if binary else 'images.txt'
340
+
341
+ output_file = os.path.join(output_folder, output_file)
342
+
343
+ if binary:
344
+ self._save_images_bin(output_file)
345
+ else:
346
+ self._save_images_txt(output_file)
347
+
348
+ def _save_images_bin(self, output_file):
349
+ with open(output_file, 'wb') as fid:
350
+ fid.write(struct.pack('L', len(self.images)))
351
+
352
+ for image_id, image in self.images.items():
353
+ fid.write(struct.pack('I', image_id))
354
+ fid.write(image.q.q.tobytes())
355
+ fid.write(image.tvec.tobytes())
356
+ fid.write(struct.pack('I', image.camera_id))
357
+ fid.write(image.name.encode() + b'\x00')
358
+ fid.write(struct.pack('L', len(image.points2D)))
359
+ data = np.rec.fromarrays(
360
+ (image.points2D[:,0], image.points2D[:,1], image.point3D_ids))
361
+ fid.write(data.tobytes())
362
+
363
+ def _save_images_txt(self, output_file):
364
+ with open(output_file, 'w') as fid:
365
+ print('# Image list with two lines of data per image:', file=fid)
366
+ print('# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME', file=fid)
367
+ print('# POINTS2D[] as (X, Y, POINT3D_ID)', file=fid)
368
+ print('# Number of images: {},'.format(len(self.images)), end=' ', file=fid)
369
+ print('mean observations per image: unknown', file=fid)
370
+
371
+ for image_id, image in self.images.items():
372
+ print(image_id, end=' ', file=fid)
373
+ print(' '.join(str(qi) for qi in image.q.q), end=' ', file=fid)
374
+ print(' '.join(str(ti) for ti in image.tvec), end=' ', file=fid)
375
+ print(image.camera_id, image.name, file=fid)
376
+
377
+ data = np.rec.fromarrays(
378
+ (image.points2D[:,0], image.points2D[:,1],
379
+ image.point3D_ids.astype(np.int64)))
380
+ if len(data) > 0:
381
+ # text-mode files cannot seek backwards in Python 3, so build the
+ # row string directly instead of trimming a trailing space
+ fid.write(' '.join(
382
+ '%.2f %.2f %d' % tuple(row) for row in data))
383
+ fid.write('\n')
384
+
385
+ #---------------------------------------------------------------------------
386
+
387
+ def save_points3D(self, output_folder, output_file=None, binary=True):
388
+ if not os.path.exists(output_folder):
389
+ os.makedirs(output_folder)
390
+
391
+ if output_file is None:
392
+ output_file = 'points3D.bin' if binary else 'points3D.txt'
393
+
394
+ output_file = os.path.join(output_folder, output_file)
395
+
396
+ if binary:
397
+ self._save_points3D_bin(output_file)
398
+ else:
399
+ self._save_points3D_txt(output_file)
400
+
401
+ def _save_points3D_bin(self, output_file):
402
+ num_valid_points3D = sum(
403
+ 1 for point3D_idx in self.point3D_id_to_point3D_idx.values()
404
+ if point3D_idx != SceneManager.INVALID_POINT3D)
405
+
406
+ iter_point3D_id_to_point3D_idx = \
407
+ self.point3D_id_to_point3D_idx.items()
408
+
409
+ with open(output_file, 'wb') as fid:
410
+ fid.write(struct.pack('L', num_valid_points3D))
411
+
412
+ for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx:
413
+ if point3D_idx == SceneManager.INVALID_POINT3D:
414
+ continue
415
+
416
+ fid.write(struct.pack('L', point3D_id))
417
+ fid.write(self.points3D[point3D_idx].tobytes())
418
+ fid.write(self.point3D_colors[point3D_idx].tobytes())
419
+ fid.write(self.point3D_errors[point3D_idx].tobytes())
420
+ fid.write(
421
+ struct.pack('L', len(self.point3D_id_to_images[point3D_id])))
422
+ fid.write(self.point3D_id_to_images[point3D_id].tobytes())
423
+
424
+ def _save_points3D_txt(self, output_file):
425
+ num_valid_points3D = sum(
426
+ 1 for point3D_idx in self.point3D_id_to_point3D_idx.values()
427
+ if point3D_idx != SceneManager.INVALID_POINT3D)
428
+
429
+ array_to_string = lambda arr: ' '.join(str(x) for x in arr)
430
+
431
+ iter_point3D_id_to_point3D_idx = \
432
+ self.point3D_id_to_point3D_idx.items()
433
+
434
+ with open(output_file, 'w') as fid:
435
+ print('# 3D point list with one line of data per point:', file=fid)
436
+ print('# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as ', end='', file=fid)
437
+ print('(IMAGE_ID, POINT2D_IDX)', file=fid)
438
+ print('# Number of points: {},'.format(num_valid_points3D), end=' ', file=fid)
439
+ print('mean track length: unknown', file=fid)
440
+
441
+ for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx:
442
+ if point3D_idx == SceneManager.INVALID_POINT3D:
443
+ continue
444
+
445
+ print(point3D_id, end=' ', file=fid)
446
+ print(array_to_string(self.points3D[point3D_idx]), end=' ', file=fid)
447
+ print(array_to_string(self.point3D_colors[point3D_idx]), end=' ', file=fid)
448
+ print(self.point3D_errors[point3D_idx], end=' ', file=fid)
449
+ print(array_to_string(
450
+ self.point3D_id_to_images[point3D_id].flat), file=fid)
451
+
452
+ #---------------------------------------------------------------------------
453
+
454
+ # return the image id and Image associated with a given image filename
455
+ def get_image_from_name(self, image_name):
456
+ image_id = self.name_to_image_id[image_name]
457
+ return image_id, self.images[image_id]
458
+
459
+ #---------------------------------------------------------------------------
460
+
461
+ def get_camera(self, camera_id):
462
+ return self.cameras[camera_id]
463
+
464
+ #---------------------------------------------------------------------------
465
+
466
+ def get_points3D(self, image_id, return_points2D=True, return_colors=False):
467
+ image = self.images[image_id]
468
+
469
+ mask = (image.point3D_ids != SceneManager.INVALID_POINT3D)
470
+
471
+ point3D_idxs = np.array([
472
+ self.point3D_id_to_point3D_idx[point3D_id]
473
+ for point3D_id in image.point3D_ids[mask]])
474
+ # detect filtered points
475
+ filter_mask = (point3D_idxs != SceneManager.INVALID_POINT3D)
476
+ point3D_idxs = point3D_idxs[filter_mask]
477
+ result = [self.points3D[point3D_idxs,:]]
478
+
479
+ if return_points2D:
480
+ mask[mask] &= filter_mask
481
+ result += [image.points2D[mask]]
482
+ if return_colors:
483
+ result += [self.point3D_colors[point3D_idxs,:]]
484
+
485
+ return result if len(result) > 1 else result[0]
486
+
487
+ #---------------------------------------------------------------------------
488
+
489
+ def point3D_valid(self, point3D_id):
490
+ return (self.point3D_id_to_point3D_idx[point3D_id] !=
491
+ SceneManager.INVALID_POINT3D)
492
+
493
+ #---------------------------------------------------------------------------
494
+
495
+ def get_filtered_points3D(self, return_colors=False):
496
+ point3D_idxs = [
497
+ idx for idx in self.point3D_id_to_point3D_idx.values()
498
+ if idx != SceneManager.INVALID_POINT3D]
499
+ result = [self.points3D[point3D_idxs,:]]
500
+
501
+ if return_colors:
502
+ result += [self.point3D_colors[point3D_idxs,:]]
503
+
504
+ return result if len(result) > 1 else result[0]
505
+
506
+ #---------------------------------------------------------------------------
507
+
508
+ # return 3D points shared by two images
509
+ def get_shared_points3D(self, image_id1, image_id2):
510
+ point3D_ids = (
511
+ set(self.images[image_id1].point3D_ids) &
512
+ set(self.images[image_id2].point3D_ids))
513
+ point3D_ids.discard(SceneManager.INVALID_POINT3D)
514
+
515
+ point3D_idxs = np.array([self.point3D_id_to_point3D_idx[point3D_id]
516
+ for point3D_id in point3D_ids])
517
+
518
+ return self.points3D[point3D_idxs,:]
519
+
520
+ #---------------------------------------------------------------------------
521
+
522
+ # project *all* 3D points into image, return their projection coordinates,
523
+ # as well as their 3D positions
524
+ def get_viewed_points(self, image_id):
525
+ image = self.images[image_id]
526
+
527
+ # get unfiltered points
528
+ point3D_idxs = set(self.point3D_id_to_point3D_idx.values())
529
+ point3D_idxs.discard(SceneManager.INVALID_POINT3D)
530
+ point3D_idxs = list(point3D_idxs)
531
+ points3D = self.points3D[point3D_idxs,:]
532
+
533
+ # orient points relative to camera
534
+ R = image.q.ToR()
535
+ points3D = points3D.dot(R.T) + image.tvec[np.newaxis,:]
536
+ points3D = points3D[points3D[:,2] > 0,:] # keep points with positive z
537
+
538
+ # put points into image coordinates
539
+ camera = self.cameras[image.camera_id]
540
+ points2D = points3D.dot(camera.get_camera_matrix().T)
541
+ points2D = points2D[:,:2] / points2D[:,2][:,np.newaxis]
542
+
543
+ # keep points that are within the image
544
+ mask = (
545
+ (points2D[:,0] >= 0) &
546
+ (points2D[:,1] >= 0) &
547
+ (points2D[:,0] < camera.width - 1) &
548
+ (points2D[:,1] < camera.height - 1))
549
+
550
+ return points2D[mask,:], points3D[mask,:]
551
+
552
+ #---------------------------------------------------------------------------
553
+
554
+ def add_camera(self, camera):
555
+ self.last_camera_id += 1
556
+ self.cameras[self.last_camera_id] = camera
557
+ return self.last_camera_id
558
+
559
+ #---------------------------------------------------------------------------
560
+
561
+ def add_image(self, image):
562
+ self.last_image_id += 1
563
+ self.images[self.last_image_id] = image
564
+ return self.last_image_id
565
+
566
+ #---------------------------------------------------------------------------
567
+
568
+ def delete_images(self, image_list):
569
+ # delete specified images
570
+ for image_id in image_list:
571
+ if image_id in self.images:
572
+ del self.images[image_id]
573
+
574
+ keep_set = set(self.images.keys())
575
+
576
+ # delete references to specified images, and ignore any points that are
577
+ # invalidated
578
+ iter_point3D_id_to_point3D_idx = \
579
+ self.point3D_id_to_point3D_idx.items()
580
+
581
+ for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx:
582
+ if point3D_idx == SceneManager.INVALID_POINT3D:
583
+ continue
584
+
585
+ mask = np.array([
586
+ image_id in keep_set
587
+ for image_id in self.point3D_id_to_images[point3D_id][:,0]])
588
+ if np.any(mask):
589
+ self.point3D_id_to_images[point3D_id] = \
590
+ self.point3D_id_to_images[point3D_id][mask]
591
+ else:
592
+ self.point3D_id_to_point3D_idx[point3D_id] = \
593
+ SceneManager.INVALID_POINT3D
594
+
595
+ #---------------------------------------------------------------------------
596
+
597
+ # image_set: set of image ids whose points we'd like to keep
598
+ # min/max triangulation angle: in degrees
599
+ def filter_points3D(self,
600
+ min_track_len=0, max_error=np.inf, min_tri_angle=0,
601
+ max_tri_angle=180, image_set=set()):
602
+
603
+ image_set = set(image_set)
604
+
605
+ check_triangulation_angles = (min_tri_angle > 0 or max_tri_angle < 180)
606
+ if check_triangulation_angles:
607
+ max_tri_prod = np.cos(np.radians(min_tri_angle))
608
+ min_tri_prod = np.cos(np.radians(max_tri_angle))
609
+
610
+ iter_point3D_id_to_point3D_idx = \
611
+ self.point3D_id_to_point3D_idx.items()
612
+
613
+ image_ids = []
614
+
615
+ for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx:
616
+ if point3D_idx == SceneManager.INVALID_POINT3D:
617
+ continue
618
+
619
+ if image_set or min_track_len > 0:
620
+ image_ids = set(self.point3D_id_to_images[point3D_id][:,0])
621
+
622
+ # check if error and min track length are sufficient, or if none of
623
+ # the selected cameras see the point
624
+ if (len(image_ids) < min_track_len or
625
+ self.point3D_errors[point3D_idx] > max_error or
626
+ image_set and image_set.isdisjoint(image_ids)):
627
+ self.point3D_id_to_point3D_idx[point3D_id] = \
628
+ SceneManager.INVALID_POINT3D
629
+
630
+ # find dot product between all camera viewing rays
631
+ elif check_triangulation_angles:
632
+ xyz = self.points3D[point3D_idx,:]
633
+ tvecs = np.array(
634
+ [(self.images[image_id].tvec - xyz)
635
+ for image_id in image_ids])
636
+ tvecs /= np.linalg.norm(tvecs, axis=-1)[:,np.newaxis]
637
+
638
+ cos_theta = np.array(
639
+ [u.dot(v) for u,v in combinations(tvecs, 2)])
640
+
641
+ # min_prod = cos(maximum viewing angle), and vice versa
642
+ # if maximum viewing angle is too small or too large,
643
+ # don't add this point
644
+ if (np.min(cos_theta) > max_tri_prod or
645
+ np.max(cos_theta) < min_tri_prod):
646
+ self.point3D_id_to_point3D_idx[point3D_id] = \
647
+ SceneManager.INVALID_POINT3D
648
+
649
+ # apply the filters to the image point3D_ids
650
+ for image in self.images.values():
651
+ mask = np.array([
652
+ self.point3D_id_to_point3D_idx.get(point3D_id, 0) \
653
+ == SceneManager.INVALID_POINT3D
654
+ for point3D_id in image.point3D_ids])
655
+ image.point3D_ids[mask] = SceneManager.INVALID_POINT3D
656
+
657
+ #---------------------------------------------------------------------------
658
+
659
+ # scene graph: {image_id: [image_id: #shared points]}
660
+ def build_scene_graph(self):
661
+ self.scene_graph = defaultdict(lambda: defaultdict(int))
662
+ point3D_iter = self.point3D_id_to_images.items()
663
+
664
+ for i, (point3D_id, images) in enumerate(point3D_iter):
665
+ if not self.point3D_valid(point3D_id):
666
+ continue
667
+
668
+ for image_id1, image_id2 in combinations(images[:,0], 2):
669
+ self.scene_graph[image_id1][image_id2] += 1
670
+ self.scene_graph[image_id2][image_id1] += 1
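
A minimal SceneManager usage sketch against the API above (the reconstruction path is illustrative):

    from pycolmap import SceneManager

    sm = SceneManager('scene/sparse/0')  # folder with cameras/images/points3D
    sm.load()                            # prefers .bin files, falls back to .txt

    # Points that survive filtering, plus their RGB colors.
    xyz, rgb = sm.get_filtered_points3D(return_colors=True)
    print(len(sm.cameras), 'cameras', len(sm.images), 'images', xyz.shape[0], 'points')
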
internal/pycolmap/tools/colmap_to_nvm.py ADDED
@@ -0,0 +1,69 @@
1
+ import itertools
2
+ import sys
3
+ sys.path.append("..")
4
+
5
+ import numpy as np
6
+
7
+ from pycolmap import Quaternion, SceneManager
8
+
9
+
10
+ #-------------------------------------------------------------------------------
11
+
12
+ def main(args):
13
+ scene_manager = SceneManager(args.input_folder)
14
+ scene_manager.load()
15
+
16
+ with open(args.output_file, "w") as fid:
17
+ fid.write("NVM_V3\n \n{:d}\n".format(len(scene_manager.images)))
18
+
19
+ image_fmt_str = " {:.3f} " + 7 * "{:.7f} "
20
+ for image_id, image in scene_manager.images.items():
21
+ camera = scene_manager.cameras[image.camera_id]
22
+ f = 0.5 * (camera.fx + camera.fy)
23
+ fid.write(args.image_name_prefix + image.name)
24
+ fid.write(image_fmt_str.format(
25
+ *((f,) + tuple(image.q.q) + tuple(image.C()))))
26
+ if camera.distortion_func is None:
27
+ fid.write("0 0\n")
28
+ else:
29
+ fid.write("{:.7f} 0\n".format(-camera.k1))
30
+
31
+ image_id_to_idx = dict(
32
+ (image_id, i) for i, image_id in enumerate(scene_manager.images))
33
+
34
+ fid.write("{:d}\n".format(len(scene_manager.points3D)))
35
+ for i, point3D_id in enumerate(scene_manager.point3D_ids):
36
+ fid.write(
37
+ "{:.7f} {:.7f} {:.7f} ".format(*scene_manager.points3D[i]))
38
+ fid.write(
39
+ "{:d} {:d} {:d} ".format(*scene_manager.point3D_colors[i]))
40
+ keypoints = [
41
+ (image_id_to_idx[image_id], kp_idx) +
42
+ tuple(scene_manager.images[image_id].points2D[kp_idx])
43
+ for image_id, kp_idx in
44
+ scene_manager.point3D_id_to_images[point3D_id]]
45
+ fid.write("{:d}".format(len(keypoints)))
46
+ fid.write(
47
+ (len(keypoints) * " {:d} {:d} {:.3f} {:.3f}" + "\n").format(
48
+ *itertools.chain(*keypoints)))
49
+
50
+
51
+ #-------------------------------------------------------------------------------
52
+
53
+ if __name__ == "__main__":
54
+ import argparse
55
+
56
+ parser = argparse.ArgumentParser(
57
+ description="Save a COLMAP reconstruction in the NVM format "
58
+ "(http://ccwu.me/vsfm/doc.html#nvm).",
59
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
60
+
61
+ parser.add_argument("input_folder")
62
+ parser.add_argument("output_file")
63
+
64
+ parser.add_argument("--image_name_prefix", type=str, default="",
65
+ help="prefix image names with this string (e.g., 'images/')")
66
+
67
+ args = parser.parse_args()
68
+
69
+ main(args)
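
A typical invocation, with illustrative paths: python colmap_to_nvm.py scene/sparse/0 scene.nvm --image_name_prefix images/
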
internal/pycolmap/tools/delete_images.py ADDED
@@ -0,0 +1,36 @@
1
+ import sys
2
+ sys.path.append("..")
3
+
4
+ import numpy as np
5
+
6
+ from pycolmap import DualQuaternion, Image, SceneManager
7
+
8
+
9
+ #-------------------------------------------------------------------------------
10
+
11
+ def main(args):
12
+ scene_manager = SceneManager(args.input_folder)
13
+ scene_manager.load()
14
+
15
+ image_ids = map(scene_manager.get_image_from_name,
16
+ iter(lambda: sys.stdin.readline().strip(), ""))
17
+ scene_manager.delete_images(image_ids)
18
+
19
+ scene_manager.save(args.output_folder)
20
+
21
+
22
+ #-------------------------------------------------------------------------------
23
+
24
+ if __name__ == "__main__":
25
+ import argparse
26
+
27
+ parser = argparse.ArgumentParser(
28
+ description="Deletes images (filenames read from stdin) from a model.",
29
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
30
+
31
+ parser.add_argument("input_folder")
32
+ parser.add_argument("output_folder")
33
+
34
+ args = parser.parse_args()
35
+
36
+ main(args)
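
Usage sketch (paths illustrative): cat unwanted_images.txt | python delete_images.py scene/sparse/0 scene/sparse_pruned, where unwanted_images.txt lists one image filename per line, matching the names stored in the model.
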
internal/pycolmap/tools/impute_missing_cameras.py ADDED
@@ -0,0 +1,180 @@
1
+ import sys
2
+ sys.path.append("..")
3
+
4
+ import numpy as np
5
+
6
+ from pycolmap import DualQuaternion, Image, SceneManager
7
+
8
+
9
+ #-------------------------------------------------------------------------------
10
+
11
+ image_to_idx = lambda im: int(im.name[:im.name.rfind(".")])
12
+
13
+
14
+ #-------------------------------------------------------------------------------
15
+
16
+ def interpolate_linear(images, camera_id, file_format):
17
+ if len(images) < 2:
18
+ raise ValueError("Need at least two images for linear interpolation!")
19
+
20
+ prev_image = images[0]
21
+ prev_idx = image_to_idx(prev_image)
22
+ prev_dq = DualQuaternion.FromQT(prev_image.q, prev_image.t)
23
+ start = prev_idx
24
+
25
+ new_images = []
26
+
27
+ for image in images[1:]:
28
+ curr_idx = image_to_idx(image)
29
+ curr_dq = DualQuaternion.FromQT(image.q, image.t)
30
+ T = curr_idx - prev_idx
31
+ Tinv = 1. / T
32
+
33
+ # like quaternions, dq(x) = -dq(x), so we'll need to pick the one more
34
+ # appropriate for interpolation by taking -dq if the dot product of the
35
+ # two q-vectors is negative
36
+ if prev_dq.q0.dot(curr_dq.q0) < 0:
37
+ curr_dq = -curr_dq
38
+
39
+ for i in range(1, T):
40
+ t = i * Tinv
41
+ # blend from prev_dq (t=0) toward curr_dq (t=1), as in the Hermite variant
+ dq = ((1. - t) * prev_dq + t * curr_dq).normalize()
42
+ q, t = dq.ToQT()
43
+ new_images.append(
44
+ Image(file_format.format(prev_idx + i), camera_id, q, t))
45
+
46
+ prev_idx = curr_idx
47
+ prev_dq = curr_dq
48
+
49
+ return new_images
50
+
51
+
52
+ #-------------------------------------------------------------------------------
53
+
54
+ def interpolate_hermite(images, camera_id, file_format):
55
+ if len(images) < 4:
56
+ raise ValueError(
57
+ "Need at least four images for Hermite spline interpolation!")
58
+
59
+ new_images = []
60
+
61
+ # linear blending for the first frames
62
+ T0 = image_to_idx(images[0])
63
+ dq0 = DualQuaternion.FromQT(images[0].q, images[0].t)
64
+ T1 = image_to_idx(images[1])
65
+ dq1 = DualQuaternion.FromQT(images[1].q, images[1].t)
66
+
67
+ if dq0.q0.dot(dq1.q0) < 0:
68
+ dq1 = -dq1
69
+ dT = 1. / float(T1 - T0)
70
+ for j in range(1, T1 - T0):
71
+ t = j * dT
72
+ dq = ((1. - t) * dq0 + t * dq1).normalize()
73
+ new_images.append(
74
+ Image(file_format.format(T0 + j), camera_id, *dq.ToQT()))
75
+
76
+ T2 = image_to_idx(images[2])
77
+ dq2 = DualQuaternion.FromQT(images[2].q, images[2].t)
78
+ if dq1.q0.dot(dq2.q0) < 0:
79
+ dq2 = -dq2
80
+
81
+ # Hermite spline interpolation of dual quaternions
82
+ # pdfs.semanticscholar.org/05b1/8ede7f46c29c2722fed3376d277a1d286c55.pdf
83
+ for i in range(1, len(images) - 2):
84
+ T3 = image_to_idx(images[i + 2])
85
+ dq3 = DualQuaternion.FromQT(images[i + 2].q, images[i + 2].t)
86
+ if dq2.q0.dot(dq3.q0) < 0:
87
+ dq3 = -dq3
88
+
89
+ prev_duration = T1 - T0
90
+ current_duration = T2 - T1
91
+ next_duration = T3 - T2
92
+
93
+ # approximate the derivatives at dq1 and dq2 using weighted central
94
+ # differences
95
+ dt1 = 1. / float(T2 - T0)
96
+ dt2 = 1. / float(T3 - T1)
97
+
98
+ m1 = (current_duration * dt1) * (dq2 - dq1) + \
99
+ (prev_duration * dt1) * (dq1 - dq0)
100
+ m2 = (next_duration * dt2) * (dq3 - dq2) + \
101
+ (current_duration * dt2) * (dq2 - dq1)
102
+
103
+ dT = 1. / float(current_duration)
104
+
105
+ for j in range(1, current_duration):
106
+ t = j * dT # 0 to 1
107
+ t2 = t * t # t squared
108
+ t3 = t2 * t # t cubed
109
+
110
+ # coefficients of the Hermite spline (a=>dq and b=>m)
111
+ a1 = 2. * t3 - 3. * t2 + 1.
112
+ b1 = t3 - 2. * t2 + t
113
+ a2 = -2. * t3 + 3. * t2
114
+ b2 = t3 - t2
115
+
116
+ dq = (a1 * dq1 + b1 * m1 + a2 * dq2 + b2 * m2).normalize()
117
+
118
+ new_images.append(
119
+ Image(file_format.format(T1 + j), camera_id, *dq.ToQT()))
120
+
121
+ T0, T1, T2 = T1, T2, T3
122
+ dq0, dq1, dq2 = dq1, dq2, dq3
123
+
124
+ # linear blending for the last frames
125
+ dT = 1. / float(T2 - T1)
126
+ for j in range(1, T2 - T1):
127
+ t = j * dT # 0 to 1
128
+ dq = ((1. - t) * dq1 + t * dq2).normalize()
129
+ new_images.append(
130
+ Image(file_format.format(T1 + j), camera_id, *dq.ToQT()))
131
+
132
+ return new_images
133
+
134
+
135
+ #-------------------------------------------------------------------------------
136
+
137
+ def main(args):
138
+ scene_manager = SceneManager(args.input_folder)
139
+ scene_manager.load()
140
+
141
+ images = sorted(scene_manager.images.values(), key=image_to_idx)
142
+
143
+ if args.method.lower() == "linear":
144
+ new_images = interpolate_linear(images, args.camera_id, args.format)
145
+ else:
146
+ new_images = interpolate_hermite(images, args.camera_id, args.format)
147
+
148
+ for new_image in new_images:
+ scene_manager.add_image(new_image)
149
+
150
+ scene_manager.save(args.output_folder)
151
+
152
+
153
+ #-------------------------------------------------------------------------------
154
+
155
+ if __name__ == "__main__":
156
+ import argparse
157
+
158
+ parser = argparse.ArgumentParser(
159
+ description="Given a reconstruction with ordered images *with integer "
160
+ "filenames* like '000100.png', fill in missing camera positions for "
161
+ "intermediate frames.",
162
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
163
+
164
+ parser.add_argument("input_folder")
165
+ parser.add_argument("output_folder")
166
+
167
+ parser.add_argument("--camera_id", type=int, default=1,
168
+ help="camera id to use for the missing images")
169
+
170
+ parser.add_argument("--format", type=str, default="{:06d}.png",
171
+ help="filename format to use for added images")
172
+
173
+ parser.add_argument(
174
+ "--method", type=str.lower, choices=("linear", "hermite"),
175
+ default="hermite",
176
+ help="Pose imputation method")
177
+
178
+ args = parser.parse_args()
179
+
180
+ main(args)
internal/pycolmap/tools/save_cameras_as_ply.py ADDED
@@ -0,0 +1,92 @@
1
+ import sys
2
+ sys.path.append("..")
3
+
4
+ import numpy as np
5
+ import os
6
+
7
+ from pycolmap import SceneManager
8
+
9
+
10
+ #-------------------------------------------------------------------------------
11
+
12
+ # Saves the cameras as a mesh
13
+ #
14
+ # inputs:
15
+ # - ply_file: output file
16
+ # - images: ordered array of pycolmap Image objects
17
+ # - color: color string for the camera
18
+ # - scale: amount to shrink/grow the camera model
19
+ def save_camera_ply(ply_file, images, scale):
20
+ points3D = scale * np.array((
21
+ (0., 0., 0.),
22
+ (-1., -1., 1.),
23
+ (-1., 1., 1.),
24
+ (1., -1., 1.),
25
+ (1., 1., 1.)))
26
+
27
+ faces = np.array(((0, 2, 1),
28
+ (0, 4, 2),
29
+ (0, 3, 4),
30
+ (0, 1, 3),
31
+ (1, 2, 4),
32
+ (1, 4, 3)))
33
+
34
+ r = np.linspace(0, 255, len(images), dtype=np.uint8)
35
+ g = 255 - r
36
+ b = r - np.linspace(0, 128, len(images), dtype=np.uint8)
37
+ color = np.column_stack((r, g, b))
38
+
39
+ with open(ply_file, "w") as fid:
40
+ print>>fid, "ply"
41
+ print>>fid, "format ascii 1.0"
42
+ print>>fid, "element vertex", len(points3D) * len(images)
43
+ print>>fid, "property float x"
44
+ print>>fid, "property float y"
45
+ print>>fid, "property float z"
46
+ print>>fid, "property uchar red"
47
+ print>>fid, "property uchar green"
48
+ print>>fid, "property uchar blue"
49
+ print>>fid, "element face", len(faces) * len(images)
50
+ print>>fid, "property list uchar int vertex_index"
51
+ print>>fid, "end_header"
52
+
53
+ for image, c in zip(images, color):
54
+ for p3D in (points3D.dot(image.R()) + image.C()):
55
+ print>>fid, p3D[0], p3D[1], p3D[2], c[0], c[1], c[2]
56
+
57
+ for i in xrange(len(images)):
58
+ for f in (faces + len(points3D) * i):
59
+ print>>fid, "3 {} {} {}".format(*f)
60
+
61
+
62
+ #-------------------------------------------------------------------------------
63
+
64
+ def main(args):
65
+ scene_manager = SceneManager(args.input_folder)
66
+ scene_manager.load_images()
67
+
68
+ images = sorted(scene_manager.images.itervalues(),
69
+ key=lambda image: image.name)
70
+
71
+ save_camera_ply(args.output_file, images, args.scale)
72
+
73
+
74
+ #-------------------------------------------------------------------------------
75
+
76
+ if __name__ == "__main__":
77
+ import argparse
78
+
79
+ parser = argparse.ArgumentParser(
80
+ description="Saves camera positions to a PLY for easy viewing outside "
81
+ "of COLMAP. Currently, camera FoV is not reflected in the output.",
82
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
83
+
84
+ parser.add_argument("input_folder")
85
+ parser.add_argument("output_file")
86
+
87
+ parser.add_argument("--scale", type=float, default=1.,
88
+ help="Scaling factor for the camera mesh.")
89
+
90
+ args = parser.parse_args()
91
+
92
+ main(args)
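
Usage sketch (paths illustrative): python save_cameras_as_ply.py scene/sparse/0 cameras.ply --scale 0.25. Each camera becomes a five-vertex pyramid, colored along a gradient that encodes image order.
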
internal/pycolmap/tools/transform_model.py ADDED
@@ -0,0 +1,48 @@
1
+ import sys
2
+ sys.path.append("..")
3
+
4
+ import numpy as np
5
+
6
+ from pycolmap import Quaternion, SceneManager
7
+
8
+
9
+ #-------------------------------------------------------------------------------
10
+
11
+ def main(args):
12
+ scene_manager = SceneManager(args.input_folder)
13
+ scene_manager.load()
14
+
15
+ # expect each line of input corresponds to one row
16
+ P = np.array([
17
+ [float(x) for x in sys.stdin.readline().split()] for _ in range(3)])
18
+
19
+ scene_manager.points3D[:] = scene_manager.points3D.dot(P[:,:3].T) + P[:,3]
20
+
21
+ # get rotation without any global scaling (assuming isotropic scaling)
22
+ scale = np.cbrt(np.linalg.det(P[:,:3]))
23
+ q_old_from_new = ~Quaternion.FromR(P[:,:3] / scale)
24
+
25
+ for image in scene_manager.images.values():
26
+ image.q *= q_old_from_new
27
+ image.tvec = scale * image.tvec - image.R().dot(P[:,3])
28
+
29
+ scene_manager.save(args.output_folder)
30
+
31
+
32
+ #-------------------------------------------------------------------------------
33
+
34
+ if __name__ == "__main__":
35
+ import argparse
36
+
37
+ parser = argparse.ArgumentParser(
38
+ description="Apply a 3x4 transformation matrix to a COLMAP model and "
39
+ "save the result as a new model. Row-major input can be piped in from "
40
+ "a file or entered via the command line.",
41
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
42
+
43
+ parser.add_argument("input_folder")
44
+ parser.add_argument("output_folder")
45
+
46
+ args = parser.parse_args()
47
+
48
+ main(args)
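
For example, piping in a row-major 3x4 identity transform should leave the model unchanged (paths illustrative): printf '1 0 0 0\n0 1 0 0\n0 0 1 0\n' | python transform_model.py scene/sparse/0 scene/sparse_out
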
internal/pycolmap/tools/write_camera_track_to_bundler.py ADDED
@@ -0,0 +1,60 @@
1
+ import sys
2
+ sys.path.append("..")
3
+
4
+ import numpy as np
5
+
6
+ from pycolmap import SceneManager
7
+
8
+
9
+ #-------------------------------------------------------------------------------
10
+
11
+ def main(args):
12
+ scene_manager = SceneManager(args.input_folder)
13
+ scene_manager.load_cameras()
14
+ scene_manager.load_images()
15
+
16
+ if args.sort:
17
+ images = sorted(
18
+ scene_manager.images.itervalues(), key=lambda im: im.name)
19
+ else:
20
+ images = scene_manager.images.values()
21
+
22
+ fid = open(args.output_file, "w")
23
+ fid_filenames = open(args.output_file + ".list.txt", "w")
24
+
25
+ print>>fid, "# Bundle file v0.3"
26
+ print>>fid, len(images), 0
27
+
28
+ for image in images:
29
+ print>>fid_filenames, image.name
30
+ camera = scene_manager.cameras[image.camera_id]
31
+ print>>fid, 0.5 * (camera.fx + camera.fy), 0, 0
32
+ R, t = image.R(), image.t
33
+ print>>fid, R[0, 0], R[0, 1], R[0, 2]
34
+ print>>fid, -R[1, 0], -R[1, 1], -R[1, 2]
35
+ print>>fid, -R[2, 0], -R[2, 1], -R[2, 2]
36
+ print>>fid, t[0], -t[1], -t[2]
37
+
38
+ fid.close()
39
+ fid_filenames.close()
40
+
41
+
42
+ #-------------------------------------------------------------------------------
43
+
44
+ if __name__ == "__main__":
45
+ import argparse
46
+
47
+ parser = argparse.ArgumentParser(
48
+ description="Saves the camera positions in the Bundler format. Note "
49
+ "that 3D points are not saved.",
50
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
51
+
52
+ parser.add_argument("input_folder")
53
+ parser.add_argument("output_file")
54
+
55
+ parser.add_argument("--sort", default=False, action="store_true",
56
+ help="sort the images by their filename")
57
+
58
+ args = parser.parse_args()
59
+
60
+ main(args)
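
Usage sketch (paths illustrative): python write_camera_track_to_bundler.py scene/sparse/0 track.out --sort. The matching image filenames are written alongside to track.out.list.txt.
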
internal/pycolmap/tools/write_depthmap_to_ply.py ADDED
@@ -0,0 +1,139 @@
1
+ import sys
2
+ sys.path.append("..")
3
+
4
+ import imageio
5
+ import numpy as np
6
+ import os
7
+
8
+ from plyfile import PlyData, PlyElement
9
+ from pycolmap import SceneManager
10
+ from scipy.ndimage.interpolation import zoom
11
+
12
+
13
+ #-------------------------------------------------------------------------------
14
+
15
+ def main(args):
16
+ suffix = ".photometric.bin" if args.photometric else ".geometric.bin"
17
+
18
+ image_file = os.path.join(args.dense_folder, "images", args.image_filename)
19
+ depth_file = os.path.join(
20
+ args.dense_folder, args.stereo_folder, "depth_maps",
21
+ args.image_filename + suffix)
22
+ if args.save_normals:
23
+ normals_file = os.path.join(
24
+ args.dense_folder, args.stereo_folder, "normal_maps",
25
+ args.image_filename + suffix)
26
+
27
+ # load camera intrinsics from the COLMAP reconstruction
28
+ scene_manager = SceneManager(os.path.join(args.dense_folder, "sparse"))
29
+ scene_manager.load_cameras()
30
+ scene_manager.load_images()
31
+
32
+ image_id, image = scene_manager.get_image_from_name(args.image_filename)
33
+ camera = scene_manager.cameras[image.camera_id]
34
+ rotation_camera_from_world = image.R()
35
+ camera_center = image.C()
36
+
37
+ # load image, depth map, and normal map
38
+ image = imageio.imread(image_file)
39
+
40
+ with open(depth_file, "rb") as fid:
41
+ w = int("".join(iter(lambda: fid.read(1), "&")))
42
+ h = int("".join(iter(lambda: fid.read(1), "&")))
43
+ c = int("".join(iter(lambda: fid.read(1), "&")))
44
+ depth_map = np.fromfile(fid, np.float32).reshape(h, w)
45
+ if (h, w) != image.shape[:2]:
46
+ depth_map = zoom(
47
+ depth_map,
48
+ (float(image.shape[0]) / h, float(image.shape[1]) / w),
49
+ order=0)
50
+
51
+ if args.save_normals:
52
+ with open(normals_file, "rb") as fid:
53
+ w = int("".join(iter(lambda: fid.read(1), "&")))
54
+ h = int("".join(iter(lambda: fid.read(1), "&")))
55
+ c = int("".join(iter(lambda: fid.read(1), "&")))
56
+ normals = np.fromfile(
57
+ fid, np.float32).reshape(c, h, w).transpose([1, 2, 0])
58
+ if (h, w) != image.shape[:2]:
59
+ normals = zoom(
60
+ normals,
61
+ (float(image.shape[0]) / h, float(image.shape[1]) / w, 1.),
62
+ order=0)
63
+
64
+ if args.min_depth is not None:
65
+ depth_map[depth_map < args.min_depth] = 0.
66
+ if args.max_depth is not None:
67
+ depth_map[depth_map > args.max_depth] = 0.
68
+
69
+ # create 3D points
70
+ #depth_map = np.minimum(depth_map, 100.)
71
+ points3D = np.dstack(camera.get_image_grid() + [depth_map])
72
+ points3D[:,:,:2] *= depth_map[:,:,np.newaxis]
73
+
74
+ # save
75
+ points3D = points3D.astype(np.float32).reshape(-1, 3)
76
+ if args.save_normals:
77
+ normals = normals.astype(np.float32).reshape(-1, 3)
78
+ image = image.reshape(-1, 3)
79
+ if image.dtype != np.uint8:
80
+ if image.max() <= 1:
81
+ image = (image * 255.).astype(np.uint8)
82
+ else:
83
+ image = image.astype(np.uint8)
84
+
85
+ if args.world_space:
86
+ points3D = points3D.dot(rotation_camera_from_world) + camera_center
87
+ if args.save_normals:
88
+ normals = normals.dot(rotation_camera_from_world)
89
+
90
+ if args.save_normals:
91
+ vertices = np.rec.fromarrays(
92
+ tuple(points3D.T) + tuple(normals.T) + tuple(image.T),
93
+ names="x,y,z,nx,ny,nz,red,green,blue")
94
+ else:
95
+ vertices = np.rec.fromarrays(
96
+ tuple(points3D.T) + tuple(image.T), names="x,y,z,red,green,blue")
97
+ vertices = PlyElement.describe(vertices, "vertex")
98
+ PlyData([vertices]).write(args.output_filename)
99
+
100
+
101
+ #-------------------------------------------------------------------------------
102
+
103
+ if __name__ == "__main__":
104
+ import argparse
105
+
106
+ parser = argparse.ArgumentParser(
107
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
108
+
109
+ parser.add_argument("dense_folder", type=str)
110
+ parser.add_argument("image_filename", type=str)
111
+ parser.add_argument("output_filename", type=str)
112
+
113
+ parser.add_argument(
114
+ "--photometric", default=False, action="store_true",
115
+ help="use photometric depthmap instead of geometric")
116
+
117
+ parser.add_argument(
118
+ "--world_space", default=False, action="store_true",
119
+ help="apply the camera->world extrinsic transformation to the result")
120
+
121
+ parser.add_argument(
122
+ "--save_normals", default=False, action="store_true",
123
+ help="load the estimated normal map and save as part of the PLY")
124
+
125
+ parser.add_argument(
126
+ "--stereo_folder", type=str, default="stereo",
127
+ help="folder in the dense workspace containing depth and normal maps")
128
+
129
+ parser.add_argument(
130
+ "--min_depth", type=float, default=None,
131
+ help="set pixels with depth less than this value to zero depth")
132
+
133
+ parser.add_argument(
134
+ "--max_depth", type=float, default=None,
135
+ help="set pixels with depth greater than this value to zero depth")
136
+
137
+ args = parser.parse_args()
138
+
139
+ main(args)
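
Usage sketch (paths illustrative): python write_depthmap_to_ply.py scene/dense 000001.jpg 000001.ply --world_space --save_normals, where scene/dense is a COLMAP dense workspace containing images/, sparse/, and stereo/.
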
internal/raw_utils.py ADDED
@@ -0,0 +1,360 @@
1
+ import glob
2
+ import json
3
+ import os
4
+ from internal import image as lib_image
5
+ from internal import math
6
+ from internal import utils
7
+ import numpy as np
8
+ import rawpy
9
+
10
+
11
+ def postprocess_raw(raw, camtorgb, exposure=None):
12
+ """Converts demosaicked raw to sRGB with a minimal postprocessing pipeline.
13
+
14
+ Args:
15
+ raw: [H, W, 3], demosaicked raw camera image.
16
+ camtorgb: [3, 3], color correction transformation to apply to raw image.
17
+ exposure: color value to be scaled to pure white after color correction.
18
+ If None, "autoexposes" at the 97th percentile.
19
+
20
+ Returns:
21
+ srgb: [H, W, 3], color corrected + exposed + gamma mapped image.
22
+ """
23
+ if raw.shape[-1] != 3:
24
+ raise ValueError(f'raw.shape[-1] is {raw.shape[-1]}, expected 3')
25
+ if camtorgb.shape != (3, 3):
26
+ raise ValueError(f'camtorgb.shape is {camtorgb.shape}, expected (3, 3)')
27
+ # Convert from camera color space to standard linear RGB color space.
28
+ rgb_linear = np.matmul(raw, camtorgb.T)
29
+ if exposure is None:
30
+ exposure = np.percentile(rgb_linear, 97)
31
+ # "Expose" image by mapping the input exposure level to white and clipping.
32
+ rgb_linear_scaled = np.clip(rgb_linear / exposure, 0, 1)
33
+ # Apply sRGB gamma curve to serve as a simple tonemap.
34
+ srgb = lib_image.linear_to_srgb_np(rgb_linear_scaled)
35
+ return srgb
36
+
37
+
38
+ def pixels_to_bayer_mask(pix_x, pix_y):
39
+ """Computes binary RGB Bayer mask values from integer pixel coordinates."""
40
+ # Red is top left (0, 0).
41
+ r = (pix_x % 2 == 0) * (pix_y % 2 == 0)
42
+ # Green is top right (0, 1) and bottom left (1, 0).
43
+ g = (pix_x % 2 == 1) * (pix_y % 2 == 0) + (pix_x % 2 == 0) * (pix_y % 2 == 1)
44
+ # Blue is bottom right (1, 1).
45
+ b = (pix_x % 2 == 1) * (pix_y % 2 == 1)
46
+ return np.stack([r, g, b], -1).astype(np.float32)
47
+
48
+
49
+ def bilinear_demosaic(bayer):
50
+ """Converts Bayer data into a full RGB image using bilinear demosaicking.
51
+
52
+ Input data should be ndarray of shape [height, width] with 2x2 mosaic pattern:
53
+ -------------
54
+ |red |green|
55
+ -------------
56
+ |green|blue |
57
+ -------------
58
+ Red and blue channels are bilinearly upsampled 2x, missing green channel
59
+ elements are the average of the neighboring 4 values in a cross pattern.
60
+
61
+ Args:
62
+ bayer: [H, W] array, Bayer mosaic pattern input image.
63
+
64
+ Returns:
65
+ rgb: [H, W, 3] array, full RGB image.
66
+ """
67
+
68
+ def reshape_quads(*planes):
69
+ """Reshape pixels from four input images to make tiled 2x2 quads."""
70
+ planes = np.stack(planes, -1)
71
+ shape = planes.shape[:-1]
72
+ # Create [2, 2] arrays out of 4 channels.
73
+ zup = planes.reshape(shape + (2, 2,))
74
+ # Transpose so that x-axis dimensions come before y-axis dimensions.
75
+ zup = np.transpose(zup, (0, 2, 1, 3))
76
+ # Reshape to 2D.
77
+ zup = zup.reshape((shape[0] * 2, shape[1] * 2))
78
+ return zup
79
+
80
+ def bilinear_upsample(z):
81
+ """2x bilinear image upsample."""
82
+ # Using np.roll makes the right and bottom edges wrap around. The raw image
83
+ # data has a few garbage columns/rows at the edges that must be discarded
84
+ # anyway, so this does not matter in practice.
85
+ # Horizontally interpolated values.
86
+ zx = .5 * (z + np.roll(z, -1, axis=-1))
87
+ # Vertically interpolated values.
88
+ zy = .5 * (z + np.roll(z, -1, axis=-2))
89
+ # Diagonally interpolated values.
90
+ zxy = .5 * (zx + np.roll(zx, -1, axis=-2))
91
+ return reshape_quads(z, zx, zy, zxy)
92
+
93
+ def upsample_green(g1, g2):
94
+ """Special 2x upsample from the two green channels."""
95
+ z = np.zeros_like(g1)
96
+ z = reshape_quads(z, g1, g2, z)
97
+ alt = 0
98
+ # Grab the 4 directly adjacent neighbors in a "cross" pattern.
99
+ for i in range(4):
100
+ axis = -1 - (i // 2)
101
+ roll = -1 + 2 * (i % 2)
102
+ alt = alt + .25 * np.roll(z, roll, axis=axis)
103
+ # For observed pixels, alt = 0, and for unobserved pixels, alt = avg(cross),
104
+ # so alt + z will have every pixel filled in.
105
+ return alt + z
106
+
107
+ r, g1, g2, b = [bayer[(i // 2)::2, (i % 2)::2] for i in range(4)]
108
+ r = bilinear_upsample(r)
109
+ # Flip in x and y before and after calling upsample, as bilinear_upsample
110
+ # assumes that the samples are at the top-left corner of the 2x2 sample.
111
+ b = bilinear_upsample(b[::-1, ::-1])[::-1, ::-1]
112
+ g = upsample_green(g1, g2)
113
+ rgb = np.stack([r, g, b], -1)
114
+ return rgb
115
+
116
+
117
+ def load_raw_images(image_dir, image_names=None):
118
+ """Loads raw images and their metadata from disk.
119
+
120
+ Args:
121
+ image_dir: directory containing raw image and EXIF data.
122
+ image_names: files to load (ignores file extension), loads all DNGs if None.
123
+
124
+ Returns:
125
+ A tuple (images, exifs).
126
+ images: [N, height, width, 3] array of raw sensor data.
127
+ exifs: [N] list of dicts, one per image, containing the EXIF data.
128
+ Raises:
129
+ ValueError: The requested `image_dir` does not exist on disk.
130
+ """
131
+
132
+ if not utils.file_exists(image_dir):
133
+ raise ValueError(f'Raw image folder {image_dir} does not exist.')
134
+
135
+ # Load raw images (dng files) and exif metadata (json files).
136
+ def load_raw_exif(image_name):
137
+ base = os.path.join(image_dir, os.path.splitext(image_name)[0])
138
+ with utils.open_file(base + '.dng', 'rb') as f:
139
+ raw = rawpy.imread(f).raw_image
140
+ with utils.open_file(base + '.json', 'rb') as f:
141
+ exif = json.load(f)[0]
142
+ return raw, exif
143
+
144
+ if image_names is None:
145
+ image_names = [
146
+ os.path.basename(f)
147
+ for f in sorted(glob.glob(os.path.join(image_dir, '*.dng')))
148
+ ]
149
+
150
+ data = [load_raw_exif(x) for x in image_names]
151
+ raws, exifs = zip(*data)
152
+ raws = np.stack(raws, axis=0).astype(np.float32)
153
+
154
+ return raws, exifs
155
+
156
+
157
+ # Brightness percentiles to use for re-exposing and tonemapping raw images.
158
+ _PERCENTILE_LIST = (80, 90, 97, 99, 100)
159
+
160
+ # Relevant fields to extract from raw image EXIF metadata.
161
+ # For details regarding EXIF parameters, see:
162
+ # https://www.adobe.com/content/dam/acom/en/products/photoshop/pdfs/dng_spec_1.4.0.0.pdf.
163
+ _EXIF_KEYS = (
164
+ 'BlackLevel', # Black level offset added to sensor measurements.
165
+ 'WhiteLevel', # Maximum possible sensor measurement.
166
+ 'AsShotNeutral', # RGB white balance coefficients.
167
+ 'ColorMatrix2', # XYZ to camera color space conversion matrix.
168
+ 'NoiseProfile', # Shot and read noise levels.
169
+ )
170
+
171
+ # Color conversion matrix from linear sRGB to XYZ color space (D65).
172
+ # See http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html.
173
+ _RGB2XYZ = np.array([[0.4124564, 0.3575761, 0.1804375],
174
+ [0.2126729, 0.7151522, 0.0721750],
175
+ [0.0193339, 0.1191920, 0.9503041]])
176
+
177
+
178
+ def process_exif(exifs):
179
+ """Processes list of raw image EXIF data into useful metadata dict.
180
+
181
+ Input should be a list of dictionaries loaded from JSON files.
182
+ These JSON files are produced by running
183
+ $ exiftool -json IMAGE.dng > IMAGE.json
184
+ for each input raw file.
185
+
186
+ We extract only the parameters relevant to
187
+ 1. Rescaling the raw data to [0, 1],
188
+ 2. White balance and color correction, and
189
+ 3. Noise level estimation.
190
+
191
+ Args:
192
+ exifs: a list of dicts containing EXIF data as loaded from JSON files.
193
+
194
+ Returns:
195
+ meta: a dict of the relevant metadata for running RawNeRF.
196
+ """
197
+ meta = {}
198
+ exif = exifs[0]
199
+ # Convert from array of dicts (exifs) to dict of arrays (meta).
200
+ for key in _EXIF_KEYS:
201
+ exif_value = exif.get(key)
202
+ if exif_value is None:
203
+ continue
204
+ # Values can be a single int or float...
205
+ if isinstance(exif_value, (int, float)):
206
+ vals = [x[key] for x in exifs]
207
+ # Or a string of numbers with ' ' between.
208
+ elif isinstance(exif_value, str):
209
+ vals = [[float(z) for z in x[key].split(' ')] for x in exifs]
210
+ meta[key] = np.squeeze(np.array(vals))
211
+ # Shutter speed is a special case, a string written like 1/N.
212
+ meta['ShutterSpeed'] = np.fromiter(
213
+ (1. / float(exif['ShutterSpeed'].split('/')[1]) for exif in exifs), float)
214
+
215
+ # Create raw-to-sRGB color transform matrices. Pipeline is:
216
+ # cam space -> white balanced cam space ("camwb") -> XYZ space -> RGB space.
217
+ # 'AsShotNeutral' is an RGB triplet representing how pure white would measure
218
+ # on the sensor, so dividing by these numbers corrects the white balance.
219
+ whitebalance = meta['AsShotNeutral'].reshape(-1, 3)
220
+ cam2camwb = np.array([np.diag(1. / x) for x in whitebalance])
221
+ # ColorMatrix2 converts from XYZ color space to "reference illuminant" (white
222
+ # balanced) camera space.
223
+ xyz2camwb = meta['ColorMatrix2'].reshape(-1, 3, 3)
224
+ rgb2camwb = xyz2camwb @ _RGB2XYZ
225
+ # We normalize the rows of the full color correction matrix, as is done in
226
+ # https://github.com/AbdoKamel/simple-camera-pipeline.
227
+ rgb2camwb /= rgb2camwb.sum(axis=-1, keepdims=True)
228
+ # Combining color correction with white balance gives the entire transform.
229
+ cam2rgb = np.linalg.inv(rgb2camwb) @ cam2camwb
230
+ meta['cam2rgb'] = cam2rgb
231
+
232
+ return meta
233
+
234
+
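The raw-to-sRGB matrix construction above can be exercised standalone; a sketch with toy stand-ins for the EXIF-derived 'AsShotNeutral' and 'ColorMatrix2' values (the identity XYZ-to-camera matrix here is purely illustrative):

import numpy as np

whitebalance = np.array([[0.6, 1.0, 0.7]])   # toy 'AsShotNeutral', [N, 3]
xyz2camwb = np.stack([np.eye(3)])            # toy 'ColorMatrix2', [N, 3, 3]
rgb2xyz = np.array([[0.4124564, 0.3575761, 0.1804375],
                    [0.2126729, 0.7151522, 0.0721750],
                    [0.0193339, 0.1191920, 0.9503041]])

cam2camwb = np.stack([np.diag(1. / x) for x in whitebalance])  # white balance
rgb2camwb = xyz2camwb @ rgb2xyz
rgb2camwb /= rgb2camwb.sum(axis=-1, keepdims=True)             # row-normalize
cam2rgb = np.linalg.inv(rgb2camwb) @ cam2camwb                 # raw -> linear sRGB
# Applied to a demosaicked [H, W, 3] image `im`: im_rgb = im @ cam2rgb[0].T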
235
+ def load_raw_dataset(split, data_dir, image_names, exposure_percentile, n_downsample):
236
+ """Loads and processes a set of RawNeRF input images.
237
+
238
+ Includes logic necessary for special "test" scenes that include a noiseless
239
+ ground truth frame, produced by HDR+ merge.
240
+
241
+ Args:
242
+ split: DataSplit.TRAIN or DataSplit.TEST, only used for test scene logic.
243
+ data_dir: base directory for scene data.
244
+ image_names: which images were successfully posed by COLMAP.
245
+ exposure_percentile: what brightness percentile to expose to white.
246
+ n_downsample: returned images are downsampled by a factor of n_downsample.
247
+
248
+ Returns:
249
+ A tuple (images, meta, testscene).
250
+ images: [N, height // n_downsample, width // n_downsample, 3] array of
251
+ demosaicked raw image data.
252
+ meta: EXIF metadata and other useful processing parameters. Includes per
253
+ image exposure information that can be passed into the NeRF model with
254
+ each ray: the set of unique exposure times is determined and each image
255
+ assigned a corresponding exposure index (mapping to an exposure value).
256
+ These are keys 'unique_shutters', 'exposure_idx', and 'exposure_values' in
257
+ the `meta` dictionary.
258
+ We rescale so the maximum `exposure_value` is 1 for convenience.
259
+ testscene: True when dataset includes ground truth test image, else False.
260
+ """
261
+
262
+ image_dir = os.path.join(data_dir, 'raw')
263
+
264
+ testimg_file = os.path.join(data_dir, 'hdrplus_test/merged.dng')
265
+ testscene = utils.file_exists(testimg_file)
266
+ if testscene:
267
+ # Test scenes have train/ and test/ split subdirectories inside raw/.
268
+ image_dir = os.path.join(image_dir, split.value)
269
+ if split == utils.DataSplit.TEST:
270
+ # COLMAP image names not valid for test split of test scene.
271
+ image_names = None
272
+ else:
273
+ # Discard the first COLMAP image name as it is a copy of the test image.
274
+ image_names = image_names[1:]
275
+
276
+ raws, exifs = load_raw_images(image_dir, image_names)
277
+ meta = process_exif(exifs)
278
+
279
+ if testscene and split == utils.DataSplit.TEST:
280
+ # Test split for test scene must load the "ground truth" HDR+ merged image.
281
+ with utils.open_file(testimg_file, 'rb') as imgin:
282
+ testraw = rawpy.imread(imgin).raw_image
283
+ # HDR+ output has 2 extra bits of fixed precision, need to divide by 4.
284
+ testraw = testraw.astype(np.float32) / 4.
285
+ # Need to rescale long exposure test image by fast:slow shutter speed ratio.
286
+ fast_shutter = meta['ShutterSpeed'][0]
287
+ slow_shutter = meta['ShutterSpeed'][-1]
288
+ shutter_ratio = fast_shutter / slow_shutter
289
+ # Replace loaded raws with the "ground truth" test image.
290
+ raws = testraw[None]
291
+ # Test image shares metadata with the first loaded image (fast exposure).
292
+ meta = {k: meta[k][:1] for k in meta}
293
+ else:
294
+ shutter_ratio = 1.
295
+
296
+ # Next we determine an index for each unique shutter speed in the data.
297
+ shutter_speeds = meta['ShutterSpeed']
298
+ # Sort the shutter speeds from slowest (largest) to fastest (smallest).
299
+ # This way index 0 will always correspond to the brightest image.
300
+ unique_shutters = np.sort(np.unique(shutter_speeds))[::-1]
301
+ exposure_idx = np.zeros_like(shutter_speeds, dtype=np.int32)
302
+ for i, shutter in enumerate(unique_shutters):
303
+ # Assign index `i` to all images with shutter speed `shutter`.
304
+ exposure_idx[shutter_speeds == shutter] = i
305
+ meta['exposure_idx'] = exposure_idx
306
+ meta['unique_shutters'] = unique_shutters
307
+ # Rescale to use relative shutter speeds, where 1. is the brightest.
308
+ # This way the NeRF output with exposure=1 will always be reasonable.
309
+ meta['exposure_values'] = shutter_speeds / unique_shutters[0]
310
+
311
+ # Rescale raw sensor measurements to [0, 1] (plus noise).
312
+ blacklevel = meta['BlackLevel'].reshape(-1, 1, 1)
313
+ whitelevel = meta['WhiteLevel'].reshape(-1, 1, 1)
314
+ images = (raws - blacklevel) / (whitelevel - blacklevel) * shutter_ratio
315
+
316
+ # Calculate the exposure level used for gamma mapping (97th percentile by default).
317
+ # Always based on full resolution image 0 (for consistency).
318
+ image0_raw_demosaic = np.array(bilinear_demosaic(images[0]))
319
+ image0_rgb = image0_raw_demosaic @ meta['cam2rgb'][0].T
320
+ exposure = np.percentile(image0_rgb, exposure_percentile)
321
+ meta['exposure'] = exposure
322
+ # Sweep over various exposure percentiles to visualize in training logs.
323
+ exposure_levels = {p: np.percentile(image0_rgb, p) for p in _PERCENTILE_LIST}
324
+ meta['exposure_levels'] = exposure_levels
325
+
326
+ # Create postprocessing function mapping raw images to tonemapped sRGB space.
327
+ cam2rgb0 = meta['cam2rgb'][0]
328
+ meta['postprocess_fn'] = lambda z, x=exposure: postprocess_raw(z, cam2rgb0, x)
329
+
330
+ def processing_fn(x):
331
+ x_ = np.array(x)
332
+ x_demosaic = bilinear_demosaic(x_)
333
+ if n_downsample > 1:
334
+ x_demosaic = lib_image.downsample(x_demosaic, n_downsample)
335
+ return np.array(x_demosaic)
336
+
337
+ images = np.stack([processing_fn(im) for im in images], axis=0)
338
+
339
+ return images, meta, testscene
340
+
341
+
342
+ def best_fit_affine(x, y, axis):
343
+ """Computes best fit a, b such that a * x + b = y, in a least square sense."""
344
+ x_m = x.mean(axis=axis)
345
+ y_m = y.mean(axis=axis)
346
+ xy_m = (x * y).mean(axis=axis)
347
+ xx_m = (x * x).mean(axis=axis)
348
+ # slope a = Cov(x, y) / Cov(x, x).
349
+ a = (xy_m - x_m * y_m) / (xx_m - x_m * x_m)
350
+ b = y_m - a * x_m
351
+ return a, b
352
+
353
+
354
+ def match_images_affine(est, gt, axis=(0, 1)):
355
+ """Computes affine best fit of gt->est, then maps est back to match gt."""
356
+ # Mapping is computed gt->est to be robust since `est` may be very noisy.
357
+ a, b = best_fit_affine(gt, est, axis=axis)
358
+ # Inverse mapping back to gt ensures we use a consistent space for metrics.
359
+ est_matched = (est - b) / a
360
+ return est_matched
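`best_fit_affine` is the closed-form least-squares line fit (slope = Cov(x, y) / Var(x)); a quick synthetic check of it and of `match_images_affine` (a sketch, not part of the diff):

import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(size=1000)
y = 2.5 * x + 0.3 + 1e-3 * rng.normal(size=1000)   # y ~= a*x + b
a, b = best_fit_affine(x, y, axis=0)
print(float(a), float(b))                           # ~2.5, ~0.3
# Map the "estimate" y back into x's frame; the result should approximate x.
x_hat = match_images_affine(est=y, gt=x, axis=0)
print(np.abs(x_hat - x).max())                      # small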
internal/ref_utils.py ADDED
@@ -0,0 +1,174 @@
1
+ from internal import math
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
+ def reflect(viewdirs, normals):
7
+ """Reflect view directions about normals.
8
+
9
+ The reflection of a vector v about a unit vector n is a vector u such that
10
+ dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two
11
+ equations is u = 2 dot(n, v) n - v.
12
+
13
+ Args:
14
+ viewdirs: [..., 3] array of view directions.
15
+ normals: [..., 3] array of normal directions (assumed to be unit vectors).
16
+
17
+ Returns:
18
+ [..., 3] array of reflection directions.
19
+ """
20
+ return 2.0 * torch.sum(normals * viewdirs, dim=-1, keepdim=True) * normals - viewdirs
21
+
22
+
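A numerical check of the identity in the docstring (a sketch, assuming unit-length normals as required): reflection preserves both the dot product with n and the length of v.

import torch

n = torch.tensor([0., 0., 1.])                                   # unit normal
v = torch.nn.functional.normalize(torch.tensor([1., 0., 1.]), dim=-1)
u = reflect(v, n)
print(torch.dot(u, n).item(), torch.dot(v, n).item())            # equal
print(torch.linalg.norm(u).item(), torch.linalg.norm(v).item())  # both 1.0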
23
+ def l2_normalize(x):
24
+ """Normalize x to unit length along last axis."""
25
+ return torch.nn.functional.normalize(x, dim=-1, eps=torch.finfo(x.dtype).eps)
26
+
27
+
28
+ def l2_normalize_np(x):
29
+ """Normalize x to unit length along last axis."""
30
+ return x / np.sqrt(np.maximum(np.sum(x ** 2, axis=-1, keepdims=True), np.finfo(x.dtype).eps))
31
+
32
+
33
+ def compute_weighted_mae(weights, normals, normals_gt):
34
+ """Compute weighted mean angular error, assuming normals are unit length."""
35
+ one_eps = 1 - torch.finfo(weights.dtype).eps
36
+ return (weights * torch.arccos(torch.clip((normals * normals_gt).sum(-1),
37
+ -one_eps, one_eps))).sum() / weights.sum() * 180.0 / torch.pi
38
+
39
+
40
+ def compute_weighted_mae_np(weights, normals, normals_gt):
41
+ """Compute weighted mean angular error, assuming normals are unit length."""
42
+ one_eps = 1 - np.finfo(weights.dtype).eps
43
+ return (weights * np.arccos(np.clip((normals * normals_gt).sum(-1),
44
+ -one_eps, one_eps))).sum() / weights.sum() * 180.0 / np.pi
45
+
46
+
47
+ def generalized_binomial_coeff(a, k):
48
+ """Compute generalized binomial coefficients."""
49
+ return np.prod(a - np.arange(k)) / np.math.factorial(k)
50
+
51
+
52
+ def assoc_legendre_coeff(l, m, k):
53
+ """Compute associated Legendre polynomial coefficients.
54
+
55
+ Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the
56
+ (l, m)th associated Legendre polynomial, P_l^m(cos(theta)).
57
+
58
+ Args:
59
+ l: associated Legendre polynomial degree.
60
+ m: associated Legendre polynomial order.
61
+ k: power of cos(theta).
62
+
63
+ Returns:
64
+ A float, the coefficient of the term corresponding to the inputs.
65
+ """
66
+ return ((-1) ** m * 2 ** l * np.math.factorial(l) / np.math.factorial(k) /
67
+ np.math.factorial(l - k - m) *
68
+ generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l))
69
+
70
+
71
+ def sph_harm_coeff(l, m, k):
72
+ """Compute spherical harmonic coefficients."""
73
+ return (np.sqrt(
74
+ (2.0 * l + 1.0) * np.math.factorial(l - m) /
75
+ (4.0 * np.pi * np.math.factorial(l + m))) * assoc_legendre_coeff(l, m, k))
76
+
77
+
78
+ def get_ml_array(deg_view):
79
+ """Create a list with all pairs of (l, m) values to use in the encoding."""
80
+ ml_list = []
81
+ for i in range(deg_view):
82
+ l = 2 ** i
83
+ # Only use nonnegative m values, later splitting real and imaginary parts.
84
+ for m in range(l + 1):
85
+ ml_list.append((m, l))
86
+
87
+ # Convert list into a numpy array.
88
+ ml_array = np.array(ml_list).T
89
+ return ml_array
90
+
91
+
92
+ def generate_ide_fn(deg_view):
93
+ """Generate integrated directional encoding (IDE) function.
94
+
95
+ This function returns a function that computes the integrated directional
96
+ encoding from Equations 6-8 of arxiv.org/abs/2112.03907.
97
+
98
+ Args:
99
+ deg_view: number of spherical harmonics degrees to use.
100
+
101
+ Returns:
102
+ A function for evaluating integrated directional encoding.
103
+
104
+ Raises:
105
+ ValueError: if deg_view is larger than 5.
106
+ """
107
+ if deg_view > 5:
108
+ raise ValueError('Only deg_view of at most 5 is numerically stable.')
109
+
110
+ ml_array = get_ml_array(deg_view)
111
+ l_max = 2 ** (deg_view - 1)
112
+
113
+ # Create a matrix corresponding to ml_array holding all coefficients, which,
114
+ # when multiplied (from the right) by the z coordinate Vandermonde matrix,
115
+ # results in the z component of the encoding.
116
+ mat = np.zeros((l_max + 1, ml_array.shape[1]))
117
+ for i, (m, l) in enumerate(ml_array.T):
118
+ for k in range(l - m + 1):
119
+ mat[k, i] = sph_harm_coeff(l, m, k)
120
+ mat = torch.from_numpy(mat).float()
121
+ ml_array = torch.from_numpy(ml_array).float()
122
+
123
+ def integrated_dir_enc_fn(xyz, kappa_inv):
124
+ """Function returning integrated directional encoding (IDE).
125
+
126
+ Args:
127
+ xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at.
128
+ kappa_inv: [..., 1] reciprocal of the concentration parameter of the von
129
+ Mises-Fisher distribution.
130
+
131
+ Returns:
132
+ An array with the resulting IDE.
133
+ """
134
+ x = xyz[..., 0:1]
135
+ y = xyz[..., 1:2]
136
+ z = xyz[..., 2:3]
137
+
138
+ # Compute z Vandermonde matrix.
139
+ vmz = torch.cat([z ** i for i in range(mat.shape[0])], dim=-1)
140
+
141
+ # Compute x+iy Vandermonde matrix.
142
+ vmxy = torch.cat([(x + 1j * y) ** m for m in ml_array[0, :]], dim=-1)
143
+
144
+ # Get spherical harmonics.
145
+ sph_harms = vmxy * math.matmul(vmz, mat.to(xyz.device))
146
+
147
+ # Apply attenuation function using the von Mises-Fisher distribution
148
+ # concentration parameter, kappa.
149
+ sigma = 0.5 * ml_array[1, :] * (ml_array[1, :] + 1)
150
+ sigma = sigma.to(sph_harms.device)
151
+ ide = sph_harms * torch.exp(-sigma * kappa_inv)
152
+
153
+ # Split into real and imaginary parts and return
154
+ return torch.cat([torch.real(ide), torch.imag(ide)], dim=-1)
155
+
156
+ return integrated_dir_enc_fn
157
+
158
+
159
+ def generate_dir_enc_fn(deg_view):
160
+ """Generate directional encoding (DE) function.
161
+
162
+ Args:
163
+ deg_view: number of spherical harmonics degrees to use.
164
+
165
+ Returns:
166
+ A function for evaluating directional encoding.
167
+ """
168
+ integrated_dir_enc_fn = generate_ide_fn(deg_view)
169
+
170
+ def dir_enc_fn(xyz):
171
+ """Function returning directional encoding (DE)."""
172
+ return integrated_dir_enc_fn(xyz, torch.zeros_like(xyz[..., :1]))
173
+
174
+ return dir_enc_fn
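A usage sketch for the two factories above: the IDE output stacks the real and imaginary parts of the attenuated spherical harmonics, so its width is twice the number of (l, m) pairs from `get_ml_array`; kappa_inv = 0 disables the attenuation and reduces the IDE to the plain directional encoding.

import torch

ide_fn = generate_ide_fn(deg_view=4)
dirs = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)
kappa_inv = torch.zeros(8, 1)        # zero roughness -> no attenuation
enc = ide_fn(dirs, kappa_inv)
print(enc.shape)                     # [8, 2 * get_ml_array(4).shape[1]]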
internal/render.py ADDED
@@ -0,0 +1,242 @@
1
+ import os.path
2
+
3
+ from internal import stepfun
4
+ from internal import math
5
+ from internal import utils
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def lift_gaussian(d, t_mean, t_var, r_var, diag):
11
+ """Lift a Gaussian defined along a ray to 3D coordinates."""
12
+ mean = d[..., None, :] * t_mean[..., None]
13
+ eps = torch.finfo(d.dtype).eps
14
+ # eps = 1e-3
15
+ d_mag_sq = torch.sum(d ** 2, dim=-1, keepdim=True).clamp_min(eps)
16
+
17
+ if diag:
18
+ d_outer_diag = d ** 2
19
+ null_outer_diag = 1 - d_outer_diag / d_mag_sq
20
+ t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
21
+ xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
22
+ cov_diag = t_cov_diag + xy_cov_diag
23
+ return mean, cov_diag
24
+ else:
25
+ d_outer = d[..., :, None] * d[..., None, :]
26
+ eye = torch.eye(d.shape[-1], device=d.device)
27
+ null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
28
+ t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
29
+ xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
30
+ cov = t_cov + xy_cov
31
+ return mean, cov
32
+
33
+
34
+ def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True):
35
+ """Approximate a conical frustum as a Gaussian distribution (mean+cov).
36
+
37
+ Assumes the ray is originating from the origin, and base_radius is the
38
+ radius at dist=1. Doesn't assume `d` is normalized.
39
+
40
+ Args:
41
+ d: the axis of the cone
42
+ t0: the starting distance of the frustum.
43
+ t1: the ending distance of the frustum.
44
+ base_radius: the scale of the radius as a function of distance.
45
+ diag: whether the Gaussian will be diagonal or full-covariance.
46
+ stable: whether or not to use the stable computation described in
47
+ the paper (setting this to False will cause catastrophic failure).
48
+
49
+ Returns:
50
+ a Gaussian (mean and covariance).
51
+ """
52
+ if stable:
53
+ # Equation 7 in the paper (https://arxiv.org/abs/2103.13415).
54
+ mu = (t0 + t1) / 2 # The average of the two `t` values.
55
+ hw = (t1 - t0) / 2 # The half-width of the two `t` values.
56
+ eps = torch.finfo(d.dtype).eps
57
+ # eps = 1e-3
58
+ t_mean = mu + (2 * mu * hw ** 2) / (3 * mu ** 2 + hw ** 2).clamp_min(eps)
59
+ denom = (3 * mu ** 2 + hw ** 2).clamp_min(eps)
60
+ t_var = (hw ** 2) / 3 - (4 / 15) * hw ** 4 * (12 * mu ** 2 - hw ** 2) / denom ** 2
61
+ r_var = (mu ** 2) / 4 + (5 / 12) * hw ** 2 - (4 / 15) * (hw ** 4) / denom
62
+ else:
63
+ # Equations 37-39 in the paper.
64
+ t_mean = (3 * (t1 ** 4 - t0 ** 4)) / (4 * (t1 ** 3 - t0 ** 3))
65
+ r_var = 3 / 20 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3)
66
+ t_mosq = 3 / 5 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3)
67
+ t_var = t_mosq - t_mean ** 2
68
+ r_var *= base_radius ** 2
69
+ return lift_gaussian(d, t_mean, t_var, r_var, diag)
70
+
71
+
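The `stable=True` branch is an algebraic rewrite of Equations 37-39 (e.g. t_mean simplifies to mu + 2*mu*hw^2 / (3*mu^2 + hw^2)), so away from t -> 0 the two branches should agree numerically; a sketch:

import torch

d = torch.tensor([[0., 0., 1.]])
t0, t1 = torch.tensor([1.0]), torch.tensor([1.2])
m_s, c_s = conical_frustum_to_gaussian(d, t0, t1, 0.01, diag=True, stable=True)
m_e, c_e = conical_frustum_to_gaussian(d, t0, t1, 0.01, diag=True, stable=False)
print(torch.allclose(m_s, m_e, atol=1e-5), torch.allclose(c_s, c_e, atol=1e-7))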
72
+ def cylinder_to_gaussian(d, t0, t1, radius, diag):
73
+ """Approximate a cylinder as a Gaussian distribution (mean+cov).
74
+
75
+ Assumes the ray is originating from the origin, and radius is the
76
+ radius. Does not renormalize `d`.
77
+
78
+ Args:
79
+ d: the axis of the cylinder
80
+ t0: the starting distance of the cylinder.
81
+ t1: the ending distance of the cylinder.
82
+ radius: the radius of the cylinder
83
+ diag: whether the Gaussian will be diagonal or full-covariance.
84
+
85
+ Returns:
86
+ a Gaussian (mean and covariance).
87
+ """
88
+ t_mean = (t0 + t1) / 2
89
+ r_var = radius ** 2 / 4
90
+ t_var = (t1 - t0) ** 2 / 12
91
+ return lift_gaussian(d, t_mean, t_var, r_var, diag)
92
+
93
+
94
+ def cast_rays(tdist, origins, directions, cam_dirs, radii, rand=True, n=7, m=3, std_scale=0.5, **kwargs):
95
+ """Cast rays (cone- or cylinder-shaped) and featurize sections of it.
96
+
97
+ Args:
98
+ tdist: float array, the "fencepost" distances along the ray.
99
+ origins: float array, the ray origin coordinates.
100
+ directions: float array, the ray direction vectors.
101
+ radii: float array, the radii (base radii for cones) of the rays.
102
+ cam_dirs: float array, camera view directions, used to construct the
+ orthonormal basis for the multisample pattern.
103
+ rand: bool, if True, randomly rotate and flip the multisample pattern.
+ std_scale: float, scale applied to the per-sample Gaussian stds.
+ n, m: unused in this implementation.
104
+
105
+ Returns:
106
+ a tuple (means, stds, t) of multisample means, stds, and t values.
107
+ """
108
+ t0 = tdist[..., :-1, None]
109
+ t1 = tdist[..., 1:, None]
110
+ radii = radii[..., None]
111
+
112
+ t_m = (t0 + t1) / 2
113
+ t_d = (t1 - t0) / 2
114
+
115
+ j = torch.arange(6, device=tdist.device)
116
+ t = t0 + t_d / (t_d ** 2 + 3 * t_m ** 2) * (t1 ** 2 + 2 * t_m ** 2 + 3 / 7 ** 0.5 * (2 * j / 5 - 1) * (
117
+ (t_d ** 2 - t_m ** 2) ** 2 + 4 * t_m ** 4).sqrt())
118
+
119
+ deg = torch.pi / 3 * torch.tensor([0, 2, 4, 3, 5, 1], device=tdist.device, dtype=torch.float)
120
+ deg = torch.broadcast_to(deg, t.shape)
121
+ if rand:
122
+ # randomly rotate and flip
123
+ mask = torch.rand_like(t0[..., 0]) > 0.5
124
+ deg = deg + 2 * torch.pi * torch.rand_like(deg[..., 0])[..., None]
125
+ deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg)
126
+ else:
127
+ # rotate by 30 degrees and flip every other pattern
128
+ mask = torch.arange(t.shape[-2], device=tdist.device) % 2 == 0
129
+ mask = torch.broadcast_to(mask, t.shape[:-1])
130
+ deg = torch.where(mask[..., None], deg, deg + torch.pi / 6)
131
+ deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg)
132
+ means = torch.stack([
133
+ radii * t * torch.cos(deg) / 2 ** 0.5,
134
+ radii * t * torch.sin(deg) / 2 ** 0.5,
135
+ t
136
+ ], dim=-1)
137
+ stds = std_scale * radii * t / 2 ** 0.5
138
+
139
+ # two basis vectors parallel to the image plane
140
+ rand_vec = torch.randn_like(cam_dirs)
141
+ ortho1 = F.normalize(torch.cross(cam_dirs, rand_vec, dim=-1), dim=-1)
142
+ ortho2 = F.normalize(torch.cross(cam_dirs, ortho1, dim=-1), dim=-1)
143
+
144
+ # Use `directions` as the third vector of the orthonormal basis,
145
+ # so that the cross sections of the cone stay parallel to the image plane.
146
+ basis_matrix = torch.stack([ortho1, ortho2, directions], dim=-1)
147
+ means = math.matmul(means, basis_matrix[..., None, :, :].transpose(-1, -2))
148
+ means = means + origins[..., None, None, :]
149
+ # import trimesh
150
+ # trimesh.Trimesh(means.reshape(-1, 3).detach().cpu().numpy()).export("test.ply", "ply")
151
+
152
+ return means, stds, t
153
+
154
+
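`cast_rays` implements Zip-NeRF-style multisampling: six points per interval placed on a hexagonal/spiral pattern inside the conical frustum, each carrying an isotropic Gaussian std, instead of mip-NeRF's single frustum Gaussian. A self-contained sketch of just the pattern for one interval (same constants as the code above; toy t0, t1, radius):

import torch

t0, t1, radius = 1.0, 1.5, 0.05
t_m, t_d = (t0 + t1) / 2, (t1 - t0) / 2
j = torch.arange(6)
t = t0 + t_d / (t_d**2 + 3 * t_m**2) * (
    t1**2 + 2 * t_m**2 + 3 / 7**0.5 * (2 * j / 5 - 1)
    * ((t_d**2 - t_m**2)**2 + 4 * t_m**4)**0.5)
deg = torch.pi / 3 * torch.tensor([0., 2., 4., 3., 5., 1.])
pts = torch.stack([radii := radius * t * torch.cos(deg) / 2**0.5,
                   radius * t * torch.sin(deg) / 2**0.5,
                   t], dim=-1) if False else torch.stack(
    [radius * t * torch.cos(deg) / 2**0.5,
     radius * t * torch.sin(deg) / 2**0.5,
     t], dim=-1)
print(pts.shape)   # [6, 3]: samples in the ray's local frame, z along the ray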
155
+ def compute_alpha_weights(density, tdist, dirs, opaque_background=False):
156
+ """Helper function for computing alpha compositing weights."""
157
+ t_delta = tdist[..., 1:] - tdist[..., :-1]
158
+ delta = t_delta * torch.norm(dirs[..., None, :], dim=-1)
159
+ density_delta = density * delta
160
+
161
+ if opaque_background:
162
+ # Equivalent to making the final t-interval infinitely wide.
163
+ density_delta = torch.cat([
164
+ density_delta[..., :-1],
165
+ torch.full_like(density_delta[..., -1:], torch.inf)
166
+ ], dim=-1)
167
+
168
+ alpha = 1 - torch.exp(-density_delta)
169
+ trans = torch.exp(-torch.cat([
170
+ torch.zeros_like(density_delta[..., :1]),
171
+ torch.cumsum(density_delta[..., :-1], dim=-1)
172
+ ], dim=-1))
173
+ weights = alpha * trans
174
+ return weights, alpha, trans
175
+
176
+
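The helper above computes T_i = exp(-sum_{j<i} sigma_j * delta_j) and w_i = T_i * (1 - exp(-sigma_i * delta_i)); these telescope, so the weights along a ray must sum to 1 - exp(-sum_i sigma_i * delta_i). A sketch verifying that:

import torch

density = torch.rand(2, 8) * 5.0
tdist = torch.cumsum(torch.rand(2, 9) * 0.1, dim=-1)
dirs = torch.nn.functional.normalize(torch.randn(2, 3), dim=-1)
weights, alpha, trans = compute_alpha_weights(density, tdist, dirs)
delta = (tdist[..., 1:] - tdist[..., :-1]) * torch.norm(dirs[..., None, :], dim=-1)
acc_expected = 1 - torch.exp(-(density * delta).sum(-1))
print(torch.allclose(weights.sum(-1), acc_expected, atol=1e-6))   # True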
177
+ def volumetric_rendering(rgbs,
178
+ weights,
179
+ tdist,
180
+ bg_rgbs,
181
+ t_far,
182
+ compute_extras,
183
+ extras=None):
184
+ """Volumetric Rendering Function.
185
+
186
+ Args:
187
+ rgbs: color, [batch_size, num_samples, 3]
188
+ weights: weights, [batch_size, num_samples].
189
+ tdist: [batch_size, num_samples].
190
+ bg_rgbs: the color(s) to use for the background.
191
+ t_far: [batch_size, 1], the distance of the far plane.
192
+ compute_extras: bool, if True, compute extra quantities besides color.
193
+ extras: dict, a set of values along rays to render by alpha compositing.
194
+
195
+ Returns:
196
+ rendering: a dict containing an rgb image of size [batch_size, 3], and other
197
+ visualizations if compute_extras=True.
198
+ """
199
+ eps = torch.finfo(rgbs.dtype).eps
200
+ # eps = 1e-3
201
+ rendering = {}
202
+
203
+ acc = weights.sum(dim=-1)
204
+ bg_w = (1 - acc[..., None]).clamp_min(0.) # The weight of the background.
205
+ rgb = (weights[..., None] * rgbs).sum(dim=-2) + bg_w * bg_rgbs
206
+ t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])
207
+ depth = (
208
+ torch.clip(
209
+ torch.nan_to_num((weights * t_mids).sum(dim=-1) / acc.clamp_min(eps), torch.inf),
210
+ tdist[..., 0], tdist[..., -1]))
211
+
212
+ rendering['rgb'] = rgb
213
+ rendering['depth'] = depth
214
+ rendering['acc'] = acc
215
+
216
+ if compute_extras:
217
+ if extras is not None:
218
+ for k, v in extras.items():
219
+ if v is not None:
220
+ rendering[k] = (weights[..., None] * v).sum(dim=-2)
221
+
222
+ expectation = lambda x: (weights * x).sum(dim=-1) / acc.clamp_min(eps)
223
+ # For numerical stability this expectation is computing using log-distance.
224
+ rendering['distance_mean'] = (
225
+ torch.clip(
226
+ torch.nan_to_num(torch.exp(expectation(torch.log(t_mids))), torch.inf),
227
+ tdist[..., 0], tdist[..., -1]))
228
+
229
+ # Add an extra fencepost with the far distance at the end of each ray, with
230
+ # whatever weight is needed to make the new weight vector sum to exactly 1
231
+ # (`weights` is only guaranteed to sum to <= 1, not == 1).
232
+ t_aug = torch.cat([tdist, t_far], dim=-1)
233
+ weights_aug = torch.cat([weights, bg_w], dim=-1)
234
+
235
+ ps = [5, 50, 95]
236
+ distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)
237
+
238
+ for i, p in enumerate(ps):
239
+ s = 'median' if p == 50 else 'percentile_' + str(p)
240
+ rendering['distance_' + s] = distance_percentiles[..., i]
241
+
242
+ return rendering
internal/stepfun.py ADDED
@@ -0,0 +1,403 @@
1
+ from internal import math
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def searchsorted(a, v):
7
+ """Find indices where v should be inserted into a to maintain order.
8
+
9
+ Args:
10
+ a: tensor, the sorted reference points that we are scanning to see where v
11
+ should lie.
12
+ v: tensor, the query points that we are pretending to insert into a. Does
13
+ not need to be sorted. All but the last dimensions should match or expand
14
+ to those of a, the last dimension can differ.
15
+
16
+ Returns:
17
+ (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the
18
+ range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or
19
+ last index of a.
20
+ """
21
+ i = torch.arange(a.shape[-1], device=a.device)
22
+ v_ge_a = v[..., None, :] >= a[..., :, None]
23
+ idx_lo = torch.max(torch.where(v_ge_a, i[..., :, None], i[..., :1, None]), -2).values
24
+ idx_hi = torch.min(torch.where(~v_ge_a, i[..., :, None], i[..., -1:, None]), -2).values
25
+ return idx_lo, idx_hi
26
+
27
+
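A worked example of the bracketing semantics from the docstring (tiny tensors; out-of-range queries clamp to the end indices):

import torch

a = torch.tensor([0., 1., 2., 3.])
v = torch.tensor([-0.5, 0.5, 2.5, 9.0])
idx_lo, idx_hi = searchsorted(a, v)
print(idx_lo.tolist())   # [0, 0, 2, 3]
print(idx_hi.tolist())   # [0, 1, 3, 3]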
28
+ def query(tq, t, y, outside_value=0):
29
+ """Look up the values of the step function (t, y) at locations tq."""
30
+ idx_lo, idx_hi = searchsorted(t, tq)
31
+ yq = torch.where(idx_lo == idx_hi, torch.full_like(idx_hi, outside_value),
32
+ torch.take_along_dim(y, idx_lo, dim=-1))
33
+ return yq
34
+
35
+
36
+ def inner_outer(t0, t1, y1):
37
+ """Construct inner and outer measures on (t1, y1) for t0."""
38
+ cy1 = torch.cat([torch.zeros_like(y1[..., :1]),
39
+ torch.cumsum(y1, dim=-1)],
40
+ dim=-1)
41
+ idx_lo, idx_hi = searchsorted(t1, t0)
42
+
43
+ cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1)
44
+ cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1)
45
+
46
+ y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]
47
+ y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:],
48
+ cy1_lo[..., 1:] - cy1_hi[..., :-1], torch.zeros_like(idx_lo[..., 1:]))
49
+ return y0_inner, y0_outer
50
+
51
+
52
+ def lossfun_outer(t, w, t_env, w_env):
53
+ """The proposal weight should be an upper envelope on the nerf weight."""
54
+ eps = torch.finfo(t.dtype).eps
55
+ # eps = 1e-3
56
+
57
+ _, w_outer = inner_outer(t, t_env, w_env)
58
+ # We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's
59
+ # more effective to pull w_outer up than it is to push w_inner down.
60
+ # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0.
61
+ return (w - w_outer).clamp_min(0) ** 2 / (w + eps)
62
+
63
+
64
+ def weight_to_pdf(t, w):
65
+ """Turn a vector of weights that sums to 1 into a PDF that integrates to 1."""
66
+ eps = torch.finfo(t.dtype).eps
67
+ return w / (t[..., 1:] - t[..., :-1]).clamp_min(eps)
68
+
69
+
70
+ def pdf_to_weight(t, p):
71
+ """Turn a PDF that integrates to 1 into a vector of weights that sums to 1."""
72
+ return p * (t[..., 1:] - t[..., :-1])
73
+
74
+
75
+ def max_dilate(t, w, dilation, domain=(-torch.inf, torch.inf)):
76
+ """Dilate (via max-pooling) a non-negative step function."""
77
+ t0 = t[..., :-1] - dilation
78
+ t1 = t[..., 1:] + dilation
79
+ t_dilate, _ = torch.sort(torch.cat([t, t0, t1], dim=-1), dim=-1)
80
+ t_dilate = torch.clip(t_dilate, *domain)
81
+ w_dilate = torch.max(
82
+ torch.where(
83
+ (t0[..., None, :] <= t_dilate[..., None])
84
+ & (t1[..., None, :] > t_dilate[..., None]),
85
+ w[..., None, :],
86
+ torch.zeros_like(w[..., None, :]),
87
+ ), dim=-1).values[..., :-1]
88
+ return t_dilate, w_dilate
89
+
90
+
91
+ def max_dilate_weights(t,
92
+ w,
93
+ dilation,
94
+ domain=(-torch.inf, torch.inf),
95
+ renormalize=False):
96
+ """Dilate (via max-pooling) a set of weights."""
97
+ eps = torch.finfo(w.dtype).eps
98
+ # eps = 1e-3
99
+
100
+ p = weight_to_pdf(t, w)
101
+ t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain)
102
+ w_dilate = pdf_to_weight(t_dilate, p_dilate)
103
+ if renormalize:
104
+ w_dilate /= torch.sum(w_dilate, dim=-1, keepdim=True).clamp_min(eps)
105
+ return t_dilate, w_dilate
106
+
107
+
108
+ def integrate_weights(w):
109
+ """Compute the cumulative sum of w, assuming all weight vectors sum to 1.
110
+
111
+ The output's size on the last dimension is one greater than that of the input,
112
+ because we're computing the integral corresponding to the endpoints of a step
113
+ function, not the integral of the interior/bin values.
114
+
115
+ Args:
116
+ w: Tensor, which will be integrated along the last axis. This is assumed to
117
+ sum to 1 along the last axis, and this function will (silently) break if
118
+ that is not the case.
119
+
120
+ Returns:
121
+ cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
122
+ """
123
+ cw = torch.cumsum(w[..., :-1], dim=-1).clamp_max(1)
124
+ shape = cw.shape[:-1] + (1,)
125
+ # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
126
+ cw0 = torch.cat([torch.zeros(shape, device=cw.device), cw,
127
+ torch.ones(shape, device=cw.device)], dim=-1)
128
+ return cw0
129
+
130
+
131
+ def integrate_weights_np(w):
132
+ """Compute the cumulative sum of w, assuming all weight vectors sum to 1.
133
+
134
+ The output's size on the last dimension is one greater than that of the input,
135
+ because we're computing the integral corresponding to the endpoints of a step
136
+ function, not the integral of the interior/bin values.
137
+
138
+ Args:
139
+ w: Tensor, which will be integrated along the last axis. This is assumed to
140
+ sum to 1 along the last axis, and this function will (silently) break if
141
+ that is not the case.
142
+
143
+ Returns:
144
+ cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
145
+ """
146
+ cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1))
147
+ shape = cw.shape[:-1] + (1,)
148
+ # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
149
+ cw0 = np.concatenate([np.zeros(shape), cw,
150
+ np.ones(shape)], axis=-1)
151
+ return cw0
152
+
153
+
154
+ def invert_cdf(u, t, w_logits):
155
+ """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
156
+ # Compute the PDF and CDF for each weight vector.
157
+ w = torch.softmax(w_logits, dim=-1)
158
+ cw = integrate_weights(w)
159
+ # Interpolate into the inverse CDF.
160
+ t_new = math.sorted_interp(u, cw, t)
161
+ return t_new
162
+
163
+
164
+ def invert_cdf_np(u, t, w_logits):
165
+ """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
166
+ # Compute the PDF and CDF for each weight vector.
167
+ w = np.exp(w_logits) / np.exp(w_logits).sum(axis=-1, keepdims=True)
168
+ cw = integrate_weights_np(w)
169
+ # Interpolate into the inverse CDF.
170
+ interp_fn = np.interp
171
+ t_new = interp_fn(u, cw, t)
172
+ return t_new
173
+
174
+
175
+ def sample(rand,
176
+ t,
177
+ w_logits,
178
+ num_samples,
179
+ single_jitter=False,
180
+ deterministic_center=False):
181
+ """Piecewise-Constant PDF sampling from a step function.
182
+
183
+ Args:
184
+ rand: random number generator (or None for `linspace` sampling).
185
+ t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
186
+ w_logits: [..., num_bins], logits corresponding to bin weights
187
+ num_samples: int, the number of samples.
188
+ single_jitter: bool, if True, jitter every sample along each ray by the same
189
+ amount in the inverse CDF. Otherwise, jitter each sample independently.
190
+ deterministic_center: bool, if False, when `rand` is None return samples that
191
+ linspace the entire PDF. If True, skip the front and back of the linspace
192
+ so that the centers of each PDF interval are returned.
193
+
194
+ Returns:
195
+ t_samples: [batch_size, num_samples].
196
+ """
197
+ eps = torch.finfo(t.dtype).eps
198
+ # eps = 1e-3
199
+
200
+ device = t.device
201
+
202
+ # Draw uniform samples.
203
+ if not rand:
204
+ if deterministic_center:
205
+ pad = 1 / (2 * num_samples)
206
+ u = torch.linspace(pad, 1. - pad - eps, num_samples, device=device)
207
+ else:
208
+ u = torch.linspace(0, 1. - eps, num_samples, device=device)
209
+ u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
210
+ else:
211
+ # `u` is in [0, 1) --- it can be zero, but it can never be 1.
212
+ u_max = eps + (1 - eps) / num_samples
213
+ max_jitter = (1 - u_max) / (num_samples - 1) - eps
214
+ d = 1 if single_jitter else num_samples
215
+ u = torch.linspace(0, 1 - u_max, num_samples, device=device) + \
216
+ torch.rand(t.shape[:-1] + (d,), device=device) * max_jitter
217
+
218
+ return invert_cdf(u, t, w_logits)
219
+
220
+
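`sample` inverts the CDF of softmax(w_logits) over the bins of `t`, so a strongly peaked logit should capture essentially all samples; a sketch (rand=None takes the deterministic linspace path):

import torch

t = torch.tensor([0., 1., 2., 3.])        # three bins
w_logits = torch.tensor([0., 10., 0.])    # nearly all mass in bin [1, 2)
s = sample(None, t, w_logits, num_samples=16, deterministic_center=True)
print(((s >= 1) & (s < 2)).float().mean())  # 1.0: every sample lands in the peak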
221
+ def sample_np(rand,
222
+ t,
223
+ w_logits,
224
+ num_samples,
225
+ single_jitter=False,
226
+ deterministic_center=False):
227
+ """
228
+ numpy version of sample()
229
+ """
230
+ eps = np.finfo(np.float32).eps
231
+
232
+ # Draw uniform samples.
233
+ if not rand:
234
+ if deterministic_center:
235
+ pad = 1 / (2 * num_samples)
236
+ u = np.linspace(pad, 1. - pad - eps, num_samples)
237
+ else:
238
+ u = np.linspace(0, 1. - eps, num_samples)
239
+ u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
240
+ else:
241
+ # `u` is in [0, 1) --- it can be zero, but it can never be 1.
242
+ u_max = eps + (1 - eps) / num_samples
243
+ max_jitter = (1 - u_max) / (num_samples - 1) - eps
244
+ d = 1 if single_jitter else num_samples
245
+ u = np.linspace(0, 1 - u_max, num_samples) + \
246
+ np.random.rand(*t.shape[:-1], d) * max_jitter
247
+
248
+ return invert_cdf_np(u, t, w_logits)
249
+
250
+
251
+ def sample_intervals(rand,
252
+ t,
253
+ w_logits,
254
+ num_samples,
255
+ single_jitter=False,
256
+ domain=(-torch.inf, torch.inf)):
257
+ """Sample *intervals* (rather than points) from a step function.
258
+
259
+ Args:
260
+ rand: random number generator (or None for `linspace` sampling).
261
+ t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
262
+ w_logits: [..., num_bins], logits corresponding to bin weights
263
+ num_samples: int, the number of intervals to sample.
264
+ single_jitter: bool, if True, jitter every sample along each ray by the same
265
+ amount in the inverse CDF. Otherwise, jitter each sample independently.
266
+ domain: (minval, maxval), the range of valid values for `t`.
267
+
268
+ Returns:
269
+ t_samples: [batch_size, num_samples].
270
+ """
271
+ if num_samples <= 1:
272
+ raise ValueError(f'num_samples must be > 1, is {num_samples}.')
273
+
274
+ # Sample a set of points from the step function.
275
+ centers = sample(
276
+ rand,
277
+ t,
278
+ w_logits,
279
+ num_samples,
280
+ single_jitter,
281
+ deterministic_center=True)
282
+
283
+ # The intervals we return will span the midpoints of each adjacent sample.
284
+ mid = (centers[..., 1:] + centers[..., :-1]) / 2
285
+
286
+ # Each first/last fencepost is the reflection of the first/last midpoint
287
+ # around the first/last sampled center. We clamp to the limits of the input
288
+ # domain, provided by the caller.
289
+ minval, maxval = domain
290
+ first = (2 * centers[..., :1] - mid[..., :1]).clamp_min(minval)
291
+ last = (2 * centers[..., -1:] - mid[..., -1:]).clamp_max(maxval)
292
+
293
+ t_samples = torch.cat([first, mid, last], dim=-1)
294
+ return t_samples
295
+
296
+
297
+ def lossfun_distortion(t, w):
298
+ """Compute iint w[i] w[j] |t[i] - t[j]| di dj."""
299
+ # The loss incurred between all pairs of intervals.
300
+ ut = (t[..., 1:] + t[..., :-1]) / 2
301
+ dut = torch.abs(ut[..., :, None] - ut[..., None, :])
302
+ loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1)
303
+
304
+ # The loss incurred within each individual interval with itself.
305
+ loss_intra = torch.sum(w ** 2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3
306
+
307
+ return loss_inter + loss_intra
308
+
309
+
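For a single interval holding all the weight, the pairwise (inter) term is zero and only the self-interaction term w^2 * (t1 - t0) / 3 survives; a sketch:

import torch

t = torch.tensor([0.0, 0.3])      # one interval of width 0.3
w = torch.tensor([1.0])           # all weight in that interval
print(lossfun_distortion(t, w))   # tensor(0.1000) == 0.3 / 3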
310
+ def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):
311
+ """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi])."""
312
+ # Distortion when the intervals do not overlap.
313
+ d_disjoint = torch.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2)
314
+
315
+ # Distortion when the intervals overlap.
316
+ d_overlap = (2 *
317
+ (torch.minimum(t0_hi, t1_hi) ** 3 - torch.maximum(t0_lo, t1_lo) ** 3) +
318
+ 3 * (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) +
319
+ t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) + t1_hi * t0_lo *
320
+ (t0_lo - t1_hi) + t1_lo * t0_hi *
321
+ (t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo))
322
+
323
+ # Are the two intervals not overlapping?
324
+ are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)
325
+
326
+ return torch.where(are_disjoint, d_disjoint, d_overlap)
327
+
328
+
329
+ def weighted_percentile(t, w, ps):
330
+ """Compute the weighted percentiles of a step function. w's must sum to 1."""
331
+ cw = integrate_weights(w)
332
+ # We want to interpolate into the integrated weights according to `ps`.
333
+ fn = lambda cw_i, t_i: math.sorted_interp(torch.tensor(ps, device=t.device) / 100, cw_i, t_i)
334
+ # Vmap fn to an arbitrary number of leading dimensions.
335
+ cw_mat = cw.reshape([-1, cw.shape[-1]])
336
+ t_mat = t.reshape([-1, t.shape[-1]])
337
+ wprctile_mat = fn(cw_mat, t_mat) # TODO
338
+ wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),))
339
+ return wprctile
340
+
341
+
342
+ def resample(t, tp, vp, use_avg=False):
343
+ """Resample a step function defined by (tp, vp) into intervals t.
344
+
345
+ Args:
346
+ t: tensor with shape (..., n+1), the endpoints to resample into.
347
+ tp: tensor with shape (..., m+1), the endpoints of the step function being
348
+ resampled.
349
+ vp: tensor with shape (..., m), the values of the step function being
350
+ resampled.
351
+ use_avg: bool, if False, return the sum of the step function for each
352
+ interval in `t`. If True, return the average, weighted by the width of
353
+ each interval in `t`.
354
+ eps: float, a small value to prevent division by zero when use_avg=True.
355
+
356
+ Returns:
357
+ v: tensor with shape (..., n), the values of the resampled step function.
358
+ """
359
+ eps = torch.finfo(t.dtype).eps
360
+ # eps = 1e-3
361
+
362
+ if use_avg:
363
+ wp = torch.diff(tp, dim=-1)
364
+ v_numer = resample(t, tp, vp * wp, use_avg=False)
365
+ v_denom = resample(t, tp, wp, use_avg=False)
366
+ v = v_numer / v_denom.clamp_min(eps)
367
+ return v
368
+
369
+ acc = torch.cumsum(vp, dim=-1)
370
+ acc0 = torch.cat([torch.zeros(acc.shape[:-1] + (1,), device=acc.device), acc], dim=-1)
371
+ acc0_resampled = math.sorted_interp(t, tp, acc0) # TODO
372
+ v = torch.diff(acc0_resampled, dim=-1)
373
+ return v
374
+
375
+
376
+ def resample_np(t, tp, vp, use_avg=False):
377
+ """
378
+ numpy version of resample
379
+ """
380
+ eps = np.finfo(t.dtype).eps
381
+ if use_avg:
382
+ wp = np.diff(tp, axis=-1)
383
+ v_numer = resample_np(t, tp, vp * wp, use_avg=False)
384
+ v_denom = resample_np(t, tp, wp, use_avg=False)
385
+ v = v_numer / np.maximum(eps, v_denom)
386
+ return v
387
+
388
+ acc = np.cumsum(vp, axis=-1)
389
+ acc0 = np.concatenate([np.zeros(acc.shape[:-1] + (1,)), acc], axis=-1)
390
+ acc0_resampled = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0)
391
+ v = np.diff(acc0_resampled, axis=-1)
392
+ return v
393
+
394
+
395
+ def blur_stepfun(x, y, r):
396
+ xr, xr_idx = torch.sort(torch.cat([x - r, x + r], dim=-1))
397
+ y1 = (torch.cat([y, torch.zeros_like(y[..., :1])], dim=-1) -
398
+ torch.cat([torch.zeros_like(y[..., :1]), y], dim=-1)) / (2 * r)
399
+ y2 = torch.cat([y1, -y1], dim=-1).take_along_dim(xr_idx[..., :-1], dim=-1)
400
+ yr = torch.cumsum((xr[..., 1:] - xr[..., :-1]) *
401
+ torch.cumsum(y2, dim=-1), dim=-1).clamp_min(0)
402
+ yr = torch.cat([torch.zeros_like(yr[..., :1]), yr], dim=-1)
403
+ return xr, yr
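`blur_stepfun` convolves a step function with a unit-mass box of half-width r, producing a piecewise-linear function on the doubled knot set; for nonnegative inputs (so the final clamp is inactive) the total mass should be preserved. A sketch:

import torch

x = torch.tensor([0.0, 1.0, 2.0])   # 3 knots -> 2 intervals
y = torch.tensor([1.0, 0.5])        # step heights; mass = 1.0 + 0.5 = 1.5
xr, yr = blur_stepfun(x, y, r=0.25)
print(xr.shape, yr.shape)           # both torch.Size([6])
print(torch.trapz(yr, xr))          # tensor(1.5000): mass preserved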
internal/train_utils.py ADDED
@@ -0,0 +1,263 @@
1
+ import collections
2
+ import functools
3
+
4
+ import torch.optim
5
+ from internal import camera_utils
6
+ from internal import configs
7
+ from internal import datasets
8
+ from internal import image
9
+ from internal import math
10
+ from internal import models
11
+ from internal import ref_utils
12
+ from internal import stepfun
13
+ from internal import utils
14
+ import numpy as np
15
+ from torch.utils._pytree import tree_map, tree_flatten
16
+ from torch_scatter import segment_coo
17
+
18
+
19
+ class GradientScaler(torch.autograd.Function):
20
+ @staticmethod
21
+ def forward(ctx, colors, sigmas, ray_dist):
22
+ ctx.save_for_backward(ray_dist)
23
+ return colors, sigmas
24
+
25
+ @staticmethod
26
+ def backward(ctx, grad_output_colors, grad_output_sigmas):
27
+ (ray_dist,) = ctx.saved_tensors
28
+ scaling = torch.square(ray_dist).clamp(0, 1)
29
+ return grad_output_colors * scaling[..., None], grad_output_sigmas * scaling, None
30
+
31
+
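`GradientScaler` is the identity in the forward pass and, in the backward pass, multiplies incoming gradients by clamp(ray_dist^2, 0, 1), damping updates from samples near the camera (a floater-suppression trick). A usage sketch:

import torch

colors = torch.rand(4, 8, 3, requires_grad=True)
sigmas = torch.rand(4, 8, requires_grad=True)
ray_dist = torch.linspace(0.1, 2.0, 8).expand(4, 8)

sc, ss = GradientScaler.apply(colors, sigmas, ray_dist)
(sc.sum() + ss.sum()).backward()
# Samples with ray_dist >= 1 keep unit gradients; nearer ones are scaled
# down by ray_dist**2.
print(sigmas.grad[0])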
32
+ def tree_reduce(fn, tree, initializer=0):
33
+ return functools.reduce(fn, tree_flatten(tree)[0], initializer)
34
+
35
+
36
+ def tree_sum(tree):
37
+ return tree_reduce(lambda x, y: x + y, tree, initializer=0)
38
+
39
+
40
+ def tree_norm_sq(tree):
41
+ return tree_sum(tree_map(lambda x: torch.sum(x ** 2), tree))
42
+
43
+
44
+ def tree_norm(tree):
45
+ return torch.sqrt(tree_norm_sq(tree))
46
+
47
+
48
+ def tree_abs_max(tree):
49
+ return tree_reduce(
50
+ lambda x, y: max(x, torch.abs(y).max().item()), tree, initializer=0)
51
+
52
+
53
+ def tree_len(tree):
54
+ return tree_sum(tree_map(lambda z: np.prod(z.shape), tree))
55
+
56
+
57
+ def summarize_tree(tree, fn, ancestry=(), max_depth=3):
58
+ """Flatten 'tree' while 'fn'-ing values and formatting keys like/this."""
59
+ stats = {}
60
+ for k, v in tree.items():
61
+ name = ancestry + (k,)
62
+ stats['/'.join(name)] = fn(v)
63
+ if hasattr(v, 'items') and len(ancestry) < (max_depth - 1):
64
+ stats.update(summarize_tree(v, fn, ancestry=name, max_depth=max_depth))
65
+ return stats
66
+
67
+
68
+ def compute_data_loss(batch, renderings, config):
69
+ """Computes data loss terms for RGB, normal, and depth outputs."""
70
+ data_losses = []
71
+ stats = collections.defaultdict(lambda: [])
72
+
73
+ # lossmult can be used to apply a weight to each ray in the batch.
74
+ # For example: masking out rays, applying the Bayer mosaic mask, upweighting
75
+ # rays from lower resolution images and so on.
76
+ lossmult = batch['lossmult']
77
+ lossmult = torch.broadcast_to(lossmult, batch['rgb'][..., :3].shape)
78
+ if config.disable_multiscale_loss:
79
+ lossmult = torch.ones_like(lossmult)
80
+
81
+ for rendering in renderings:
82
+ resid_sq = (rendering['rgb'] - batch['rgb'][..., :3]) ** 2
83
+ denom = lossmult.sum()
84
+ stats['mses'].append(((lossmult * resid_sq).sum() / denom).item())
85
+
86
+ if config.data_loss_type == 'mse':
87
+ # Mean-squared error (L2) loss.
88
+ data_loss = resid_sq
89
+ elif config.data_loss_type == 'charb':
90
+ # Charbonnier loss.
91
+ data_loss = torch.sqrt(resid_sq + config.charb_padding ** 2)
92
+ elif config.data_loss_type == 'rawnerf':
93
+ # Clip raw values against 1 to match sensor overexposure behavior.
94
+ rgb_render_clip = rendering['rgb'].clamp_max(1)
95
+ resid_sq_clip = (rgb_render_clip - batch['rgb'][..., :3]) ** 2
96
+ # Scale by gradient of log tonemapping curve.
97
+ scaling_grad = 1. / (1e-3 + rgb_render_clip.detach())
98
+ # Reweighted L2 loss.
99
+ data_loss = resid_sq_clip * scaling_grad ** 2
100
+ else:
101
+ assert False
102
+ data_losses.append((lossmult * data_loss).sum() / denom)
103
+
104
+ if config.compute_disp_metrics:
105
+ # Using mean to compute disparity, but other distance statistics can
106
+ # be used instead.
107
+ disp = 1 / (1 + rendering['distance_mean'])
108
+ stats['disparity_mses'].append(((disp - batch['disps']) ** 2).mean().item())
109
+
110
+ if config.compute_normal_metrics:
111
+ if 'normals' in rendering:
112
+ weights = rendering['acc'] * batch['alphas']
113
+ normalized_normals_gt = ref_utils.l2_normalize(batch['normals'])
114
+ normalized_normals = ref_utils.l2_normalize(rendering['normals'])
115
+ normal_mae = ref_utils.compute_weighted_mae(weights, normalized_normals,
116
+ normalized_normals_gt)
117
+ else:
118
+ # If normals are not computed, set MAE to NaN.
119
+ normal_mae = torch.nan
120
+ stats['normal_maes'].append(float(normal_mae))
121
+
122
+ loss = (
123
+ config.data_coarse_loss_mult * sum(data_losses[:-1]) +
124
+ config.data_loss_mult * data_losses[-1])
125
+
126
+ stats = {k: np.array(stats[k]) for k in stats}
127
+ return loss, stats
128
+
129
+
130
+ def interlevel_loss(ray_history, config):
131
+ """Computes the interlevel loss defined in mip-NeRF 360."""
132
+ # Stop the gradient from the interlevel loss onto the NeRF MLP.
133
+ last_ray_results = ray_history[-1]
134
+ c = last_ray_results['sdist'].detach()
135
+ w = last_ray_results['weights'].detach()
136
+ loss_interlevel = 0.
137
+ for ray_results in ray_history[:-1]:
138
+ cp = ray_results['sdist']
139
+ wp = ray_results['weights']
140
+ loss_interlevel += stepfun.lossfun_outer(c, w, cp, wp).mean()
141
+ return config.interlevel_loss_mult * loss_interlevel
142
+
143
+
144
+ def anti_interlevel_loss(ray_history, config):
145
+ """Computes the interlevel loss defined in mip-NeRF 360."""
146
+ last_ray_results = ray_history[-1]
147
+ c = last_ray_results['sdist'].detach()
148
+ w = last_ray_results['weights'].detach()
149
+ w_normalize = w / (c[..., 1:] - c[..., :-1])
150
+ loss_anti_interlevel = 0.
151
+ for i, ray_results in enumerate(ray_history[:-1]):
152
+ cp = ray_results['sdist']
153
+ wp = ray_results['weights']
154
+ c_, w_ = stepfun.blur_stepfun(c, w_normalize, config.pulse_width[i])
155
+
156
+ # piecewise linear pdf to piecewise quadratic cdf
157
+ area = 0.5 * (w_[..., 1:] + w_[..., :-1]) * (c_[..., 1:] - c_[..., :-1])
158
+
159
+ cdf = torch.cat([torch.zeros_like(area[..., :1]), torch.cumsum(area, dim=-1)], dim=-1)
160
+
161
+ # query piecewise quadratic interpolation
162
+ cdf_interp = math.sorted_interp_quad(cp, c_, w_, cdf)
163
+ # difference between adjacent interpolated values
164
+ w_s = torch.diff(cdf_interp, dim=-1)
165
+
166
+ loss_anti_interlevel += ((w_s - wp).clamp_min(0) ** 2 / (wp + 1e-5)).mean()
167
+ return config.anti_interlevel_loss_mult * loss_anti_interlevel
168
+
169
+
170
+ def distortion_loss(ray_history, config):
171
+ """Computes the distortion loss regularizer defined in mip-NeRF 360."""
172
+ last_ray_results = ray_history[-1]
173
+ c = last_ray_results['sdist']
174
+ w = last_ray_results['weights']
175
+ loss = stepfun.lossfun_distortion(c, w).mean()
176
+ return config.distortion_loss_mult * loss
177
+
178
+
179
+ def orientation_loss(batch, model, ray_history, config):
180
+ """Computes the orientation loss regularizer defined in ref-NeRF."""
181
+ total_loss = 0.
182
+ for i, ray_results in enumerate(ray_history):
183
+ w = ray_results['weights']
184
+ n = ray_results[config.orientation_loss_target]
185
+ if n is None:
186
+ raise ValueError('Normals cannot be None if orientation loss is on.')
187
+ # Negate viewdirs to represent normalized vectors from point to camera.
188
+ v = -1. * batch['viewdirs']
189
+ n_dot_v = (n * v[..., None, :]).sum(dim=-1)
190
+ loss = (w * n_dot_v.clamp_min(0) ** 2).sum(dim=-1).mean()
191
+ if i < model.num_levels - 1:
192
+ total_loss += config.orientation_coarse_loss_mult * loss
193
+ else:
194
+ total_loss += config.orientation_loss_mult * loss
195
+ return total_loss
196
+
197
+
198
+ def hash_decay_loss(ray_history, config):
199
+ total_loss = 0.
200
+ for i, ray_results in enumerate(ray_history):
201
+ total_loss += config.hash_decay_mults * ray_results['loss_hash_decay']
202
+ return total_loss
203
+
204
+
205
+ def opacity_loss(renderings, config):
206
+ total_loss = 0.
207
+ for i, rendering in enumerate(renderings):
208
+ o = rendering['acc']
209
+ total_loss += config.opacity_loss_mult * (-o * torch.log(o + 1e-5)).mean()
210
+ return total_loss
211
+
212
+
213
+ def predicted_normal_loss(model, ray_history, config):
214
+ """Computes the predicted normal supervision loss defined in ref-NeRF."""
215
+ total_loss = 0.
216
+ for i, ray_results in enumerate(ray_history):
217
+ w = ray_results['weights']
218
+ n = ray_results['normals']
219
+ n_pred = ray_results['normals_pred']
220
+ if n is None or n_pred is None:
221
+ raise ValueError(
222
+ 'Predicted normals and gradient normals cannot be None if '
223
+ 'predicted normal loss is on.')
224
+ loss = torch.mean((w * (1.0 - torch.sum(n * n_pred, dim=-1))).sum(dim=-1))
225
+ if i < model.num_levels - 1:
226
+ total_loss += config.predicted_normal_coarse_loss_mult * loss
227
+ else:
228
+ total_loss += config.predicted_normal_loss_mult * loss
229
+ return total_loss
230
+
231
+
232
+ def clip_gradients(model, accelerator, config):
233
+ """Clips gradients of MLP based on norm and max value."""
234
+ if config.grad_max_norm > 0 and accelerator.sync_gradients:
235
+ accelerator.clip_grad_norm_(model.parameters(), config.grad_max_norm)
236
+
237
+ if config.grad_max_val > 0 and accelerator.sync_gradients:
238
+ accelerator.clip_grad_value_(model.parameters(), config.grad_max_val)
239
+
240
+ for param in model.parameters():
241
+ if param.grad is not None:
+ param.grad.nan_to_num_()
242
+
243
+
244
+ def create_optimizer(config: configs.Config, model):
245
+ """Creates optax optimizer for model training."""
246
+ adam_kwargs = {
247
+ 'betas': [config.adam_beta1, config.adam_beta2],
248
+ 'eps': config.adam_eps,
249
+ }
250
+ lr_kwargs = {
251
+ 'max_steps': config.max_steps,
252
+ 'lr_delay_steps': config.lr_delay_steps,
253
+ 'lr_delay_mult': config.lr_delay_mult,
254
+ }
255
+
256
+ lr_fn_main = lambda step: math.learning_rate_decay(
257
+ step,
258
+ lr_init=config.lr_init,
259
+ lr_final=config.lr_final,
260
+ **lr_kwargs)
261
+ optimizer = torch.optim.Adam(model.parameters(), lr=config.lr_init, **adam_kwargs)
262
+
263
+ return optimizer, lr_fn_main
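The schedule is not wired into the optimizer; the training loop is expected to call `lr_fn` each step and write the result into the param groups. A sketch with toy stand-ins for `config` and `model` (the real Config comes from internal.configs):

import torch
from types import SimpleNamespace

config = SimpleNamespace(adam_beta1=0.9, adam_beta2=0.99, adam_eps=1e-15,
                         max_steps=1000, lr_delay_steps=100,
                         lr_delay_mult=0.01, lr_init=1e-2, lr_final=1e-3)
model = torch.nn.Linear(3, 3)
optimizer, lr_fn = create_optimizer(config, model)

for step in range(config.max_steps):
    for g in optimizer.param_groups:
        g['lr'] = lr_fn(step)        # apply the decay schedule manually
    optimizer.zero_grad()
    loss = model(torch.randn(8, 3)).square().mean()
    loss.backward()
    optimizer.step()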