copy files from SuLvXiangXin
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +15 -0
- LICENSE +202 -0
- README.md +252 -3
- configs/360.gin +15 -0
- configs/360_glo.gin +15 -0
- configs/blender.gin +15 -0
- configs/blender_refnerf.gin +41 -0
- configs/llff_256.gin +19 -0
- configs/llff_512.gin +19 -0
- configs/llff_raw.gin +73 -0
- configs/multi360.gin +5 -0
- eval.py +307 -0
- extract.py +638 -0
- gridencoder/__init__.py +1 -0
- gridencoder/backend.py +40 -0
- gridencoder/grid.py +198 -0
- gridencoder/setup.py +50 -0
- gridencoder/src/bindings.cpp +9 -0
- gridencoder/src/gridencoder.cu +645 -0
- gridencoder/src/gridencoder.h +17 -0
- internal/camera_utils.py +673 -0
- internal/checkpoints.py +38 -0
- internal/configs.py +177 -0
- internal/coord.py +225 -0
- internal/datasets.py +1016 -0
- internal/geopoly.py +108 -0
- internal/image.py +126 -0
- internal/math.py +133 -0
- internal/models.py +740 -0
- internal/pycolmap/.gitignore +2 -0
- internal/pycolmap/LICENSE.txt +21 -0
- internal/pycolmap/README.md +4 -0
- internal/pycolmap/pycolmap/__init__.py +5 -0
- internal/pycolmap/pycolmap/camera.py +259 -0
- internal/pycolmap/pycolmap/database.py +340 -0
- internal/pycolmap/pycolmap/image.py +35 -0
- internal/pycolmap/pycolmap/rotation.py +324 -0
- internal/pycolmap/pycolmap/scene_manager.py +670 -0
- internal/pycolmap/tools/colmap_to_nvm.py +69 -0
- internal/pycolmap/tools/delete_images.py +36 -0
- internal/pycolmap/tools/impute_missing_cameras.py +180 -0
- internal/pycolmap/tools/save_cameras_as_ply.py +92 -0
- internal/pycolmap/tools/transform_model.py +48 -0
- internal/pycolmap/tools/write_camera_track_to_bundler.py +60 -0
- internal/pycolmap/tools/write_depthmap_to_ply.py +139 -0
- internal/raw_utils.py +360 -0
- internal/ref_utils.py +174 -0
- internal/render.py +242 -0
- internal/stepfun.py +403 -0
- internal/train_utils.py +263 -0
.gitignore
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
interal/__pycache__/
|
3 |
+
tests/__pycache__/
|
4 |
+
.DS_Store
|
5 |
+
.vscode/
|
6 |
+
.idea/
|
7 |
+
__MACOSX/
|
8 |
+
exp/
|
9 |
+
data/
|
10 |
+
assets/
|
11 |
+
test.py
|
12 |
+
test2.py
|
13 |
+
*.mp4
|
14 |
+
*.ply
|
15 |
+
scripts/train_360_debug.sh
|
LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,252 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ZipNeRF
|
2 |
+
|
3 |
+
An unofficial pytorch implementation of
|
4 |
+
"Zip-NeRF: Anti-Aliased Grid-Based Neural Radiance Fields"
|
5 |
+
[https://arxiv.org/abs/2304.06706](https://arxiv.org/abs/2304.06706).
|
6 |
+
This work is based on [multinerf](https://github.com/google-research/multinerf), so features in refnerf,rawnerf,mipnerf360 are also available.
|
7 |
+
|
8 |
+
## News
|
9 |
+
- (6.22) Add extracting mesh through tsdf; add [gradient scaling](https://gradient-scaling.github.io/) for near plane floaters.
|
10 |
+
- (5.26) Implement the latest version of ZipNeRF [https://arxiv.org/abs/2304.06706](https://arxiv.org/abs/2304.06706).
|
11 |
+
- (5.22) Add extracting mesh; add logging,checkpointing system
|
12 |
+
|
13 |
+
## Results
|
14 |
+
New results(5.27):
|
15 |
+
|
16 |
+
360_v2:
|
17 |
+
|
18 |
+
https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/2b276e48-2dc4-4508-8441-e90ec963f7d9
|
19 |
+
|
20 |
+
|
21 |
+
360_v2_glo:(fewer floaters, but worse metric)
|
22 |
+
|
23 |
+
|
24 |
+
https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/bddb5610-2a4f-4981-8e17-71326a24d291
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
mesh results(5.27):
|
32 |
+
|
33 |
+
![mesh](https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/35866fa7-fe6a-44fe-9590-05d594bdb8cd)
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
Mipnerf360(PSNR):
|
38 |
+
|
39 |
+
| | bicycle | garden | stump | room | counter | kitchen | bonsai |
|
40 |
+
|:---------:|:-------:|:------:|:-----:|:-----:|:-------:|:-------:|:------:|
|
41 |
+
| Paper | 25.80 | 28.20 | 27.55 | 32.65 | 29.38 | 32.50 | 34.46 |
|
42 |
+
| This repo | 25.44 | 27.98 | 26.75 | 32.13 | 29.10 | 32.63 | 34.20 |
|
43 |
+
|
44 |
+
|
45 |
+
Blender(PSNR):
|
46 |
+
|
47 |
+
| | chair | drums | ficus | hotdog | lego | materials | mic | ship |
|
48 |
+
|:---------:|:-----:|:-----:|:-----:|:------:|:-----:|:---------:|:-----:|:-----:|
|
49 |
+
| Paper | 34.84 | 25.84 | 33.90 | 37.14 | 34.84 | 31.66 | 35.15 | 31.38 |
|
50 |
+
| This repo | 35.26 | 25.51 | 32.66 | 36.56 | 35.04 | 29.43 | 34.93 | 31.38 |
|
51 |
+
|
52 |
+
For Mipnerf360 dataset, the model is trained with a downsample factor of 4 for outdoor scene and 2 for indoor scene(same as in paper).
|
53 |
+
Training speed is about 1.5x slower than paper(1.5 hours on 8 A6000).
|
54 |
+
|
55 |
+
The hash decay loss seems to have little effect(?), as many floaters can be found in the final results in both experiments (especially in Blender).
|
56 |
+
|
57 |
+
## Install
|
58 |
+
|
59 |
+
```
|
60 |
+
# Clone the repo.
|
61 |
+
git clone https://github.com/SuLvXiangXin/zipnerf-pytorch.git
|
62 |
+
cd zipnerf-pytorch
|
63 |
+
|
64 |
+
# Make a conda environment.
|
65 |
+
conda create --name zipnerf python=3.9
|
66 |
+
conda activate zipnerf
|
67 |
+
|
68 |
+
# Install requirements.
|
69 |
+
pip install -r requirements.txt
|
70 |
+
|
71 |
+
# Install other extensions
|
72 |
+
pip install ./gridencoder
|
73 |
+
|
74 |
+
# Install nvdiffrast (optional, for textured mesh)
|
75 |
+
git clone https://github.com/NVlabs/nvdiffrast
|
76 |
+
pip install ./nvdiffrast
|
77 |
+
|
78 |
+
# Install a specific cuda version of torch_scatter
|
79 |
+
# see more detail at https://github.com/rusty1s/pytorch_scatter
|
80 |
+
CUDA=cu117
|
81 |
+
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
|
82 |
+
```
|
83 |
+
|
84 |
+
## Dataset
|
85 |
+
[mipnerf360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip)
|
86 |
+
|
87 |
+
[refnerf](https://storage.googleapis.com/gresearch/refraw360/ref.zip)
|
88 |
+
|
89 |
+
[nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
|
90 |
+
|
91 |
+
[nerf_llff_data](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
|
92 |
+
|
93 |
+
```
|
94 |
+
mkdir data
|
95 |
+
cd data
|
96 |
+
|
97 |
+
# e.g. mipnerf360 data
|
98 |
+
wget http://storage.googleapis.com/gresearch/refraw360/360_v2.zip
|
99 |
+
unzip 360_v2.zip
|
100 |
+
```
|
101 |
+
|
102 |
+
## Train
|
103 |
+
```
|
104 |
+
# Configure your training (DDP? fp16? ...)
|
105 |
+
# see https://huggingface.co/docs/accelerate/index for details
|
106 |
+
accelerate config
|
107 |
+
|
108 |
+
# Where your data is
|
109 |
+
DATA_DIR=data/360_v2/bicycle
|
110 |
+
EXP_NAME=360_v2/bicycle
|
111 |
+
|
112 |
+
# Experiment will be conducted under "exp/${EXP_NAME}" folder
|
113 |
+
# "--gin_configs=configs/360.gin" can be seen as a default config
|
114 |
+
# and you can add specific config useing --gin_bindings="..."
|
115 |
+
accelerate launch train.py \
|
116 |
+
--gin_configs=configs/360.gin \
|
117 |
+
--gin_bindings="Config.data_dir = '${DATA_DIR}'" \
|
118 |
+
--gin_bindings="Config.exp_name = '${EXP_NAME}'" \
|
119 |
+
--gin_bindings="Config.factor = 4"
|
120 |
+
|
121 |
+
# or you can also run without accelerate (without DDP)
|
122 |
+
CUDA_VISIBLE_DEVICES=0 python train.py \
|
123 |
+
--gin_configs=configs/360.gin \
|
124 |
+
--gin_bindings="Config.data_dir = '${DATA_DIR}'" \
|
125 |
+
--gin_bindings="Config.exp_name = '${EXP_NAME}'" \
|
126 |
+
--gin_bindings="Config.factor = 4"
|
127 |
+
|
128 |
+
# alternatively you can use an example training script
|
129 |
+
bash scripts/train_360.sh
|
130 |
+
|
131 |
+
# blender dataset
|
132 |
+
bash scripts/train_blender.sh
|
133 |
+
|
134 |
+
# metric, render image, etc can be viewed through tensorboard
|
135 |
+
tensorboard --logdir "exp/${EXP_NAME}"
|
136 |
+
|
137 |
+
```
|
138 |
+
|
139 |
+
### Render
|
140 |
+
Rendering results can be found in the directory `exp/${EXP_NAME}/render`
|
141 |
+
```
|
142 |
+
accelerate launch render.py \
|
143 |
+
--gin_configs=configs/360.gin \
|
144 |
+
--gin_bindings="Config.data_dir = '${DATA_DIR}'" \
|
145 |
+
--gin_bindings="Config.exp_name = '${EXP_NAME}'" \
|
146 |
+
--gin_bindings="Config.render_path = True" \
|
147 |
+
--gin_bindings="Config.render_path_frames = 480" \
|
148 |
+
--gin_bindings="Config.render_video_fps = 60" \
|
149 |
+
--gin_bindings="Config.factor = 4"
|
150 |
+
|
151 |
+
# alternatively you can use an example rendering script
|
152 |
+
bash scripts/render_360.sh
|
153 |
+
```
|
154 |
+
## Evaluate
|
155 |
+
Evaluating results can be found in the directory `exp/${EXP_NAME}/test_preds`
|
156 |
+
```
|
157 |
+
# using the same exp_name as in training
|
158 |
+
accelerate launch eval.py \
|
159 |
+
--gin_configs=configs/360.gin \
|
160 |
+
--gin_bindings="Config.data_dir = '${DATA_DIR}'" \
|
161 |
+
--gin_bindings="Config.exp_name = '${EXP_NAME}'" \
|
162 |
+
--gin_bindings="Config.factor = 4"
|
163 |
+
|
164 |
+
|
165 |
+
# alternatively you can use an example evaluating script
|
166 |
+
bash scripts/eval_360.sh
|
167 |
+
```
|
168 |
+
|
169 |
+
## Extract mesh
|
170 |
+
Mesh results can be found in the directory `exp/${EXP_NAME}/mesh`
|
171 |
+
```
|
172 |
+
# more configuration can be found in internal/configs.py
|
173 |
+
accelerate launch extract.py \
|
174 |
+
--gin_configs=configs/360.gin \
|
175 |
+
--gin_bindings="Config.data_dir = '${DATA_DIR}'" \
|
176 |
+
--gin_bindings="Config.exp_name = '${EXP_NAME}'" \
|
177 |
+
--gin_bindings="Config.factor = 4"
|
178 |
+
# --gin_bindings="Config.mesh_radius = 1" # (optional) smaller for more details e.g. 0.2 in bicycle scene
|
179 |
+
# --gin_bindings="Config.isosurface_threshold = 20" # (optional) empirical value
|
180 |
+
# --gin_bindings="Config.mesh_voxels=134217728" # (optional) number of voxels used to extract mesh, e.g. 134217728 equals to 512**3 . Smaller values may solve OutoFMemoryError
|
181 |
+
# --gin_bindings="Config.vertex_color = True" # (optional) saving mesh with vertex color instead of atlas which is much slower but with more details.
|
182 |
+
# --gin_bindings="Config.vertex_projection = True" # (optional) use projection for vertex color
|
183 |
+
|
184 |
+
# or extracting mesh using tsdf method
|
185 |
+
accelerate launch extract.py \
|
186 |
+
--gin_configs=configs/360.gin \
|
187 |
+
--gin_bindings="Config.data_dir = '${DATA_DIR}'" \
|
188 |
+
--gin_bindings="Config.exp_name = '${EXP_NAME}'" \
|
189 |
+
--gin_bindings="Config.factor = 4"
|
190 |
+
|
191 |
+
# alternatively you can use an example script
|
192 |
+
bash scripts/extract_360.sh
|
193 |
+
```
|
194 |
+
|
195 |
+
## OutOfMemory
|
196 |
+
you can decrease the total batch size by
|
197 |
+
adding e.g. `--gin_bindings="Config.batch_size = 8192" `,
|
198 |
+
or decrease the test chunk size by adding e.g. `--gin_bindings="Config.render_chunk_size = 8192" `,
|
199 |
+
or use more GPU by configure `accelerate config` .
|
200 |
+
|
201 |
+
|
202 |
+
## Preparing custom data
|
203 |
+
More details can be found at https://github.com/google-research/multinerf
|
204 |
+
```
|
205 |
+
DATA_DIR=my_dataset_dir
|
206 |
+
bash scripts/local_colmap_and_resize.sh ${DATA_DIR}
|
207 |
+
```
|
208 |
+
|
209 |
+
## TODO
|
210 |
+
- [x] Add MultiScale training and testing
|
211 |
+
|
212 |
+
## Citation
|
213 |
+
```
|
214 |
+
@misc{barron2023zipnerf,
|
215 |
+
title={Zip-NeRF: Anti-Aliased Grid-Based Neural Radiance Fields},
|
216 |
+
author={Jonathan T. Barron and Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman},
|
217 |
+
year={2023},
|
218 |
+
eprint={2304.06706},
|
219 |
+
archivePrefix={arXiv},
|
220 |
+
primaryClass={cs.CV}
|
221 |
+
}
|
222 |
+
|
223 |
+
@misc{multinerf2022,
|
224 |
+
title={{MultiNeRF}: {A} {Code} {Release} for {Mip-NeRF} 360, {Ref-NeRF}, and {RawNeRF}},
|
225 |
+
author={Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman and Ricardo Martin-Brualla and Jonathan T. Barron},
|
226 |
+
year={2022},
|
227 |
+
url={https://github.com/google-research/multinerf},
|
228 |
+
}
|
229 |
+
|
230 |
+
@Misc{accelerate,
|
231 |
+
title = {Accelerate: Training and inference at scale made simple, efficient and adaptable.},
|
232 |
+
author = {Sylvain Gugger, Lysandre Debut, Thomas Wolf, Philipp Schmid, Zachary Mueller, Sourab Mangrulkar},
|
233 |
+
howpublished = {\url{https://github.com/huggingface/accelerate}},
|
234 |
+
year = {2022}
|
235 |
+
}
|
236 |
+
|
237 |
+
@misc{torch-ngp,
|
238 |
+
Author = {Jiaxiang Tang},
|
239 |
+
Year = {2022},
|
240 |
+
Note = {https://github.com/ashawkey/torch-ngp},
|
241 |
+
Title = {Torch-ngp: a PyTorch implementation of instant-ngp}
|
242 |
+
}
|
243 |
+
```
|
244 |
+
|
245 |
+
## Acknowledgements
|
246 |
+
This work is based on my another repo https://github.com/SuLvXiangXin/multinerf-pytorch,
|
247 |
+
which is basically a pytorch translation from [multinerf](https://github.com/google-research/multinerf)
|
248 |
+
|
249 |
+
- Thanks to [multinerf](https://github.com/google-research/multinerf) for amazing multinerf(MipNeRF360,RefNeRF,RawNeRF) implementation
|
250 |
+
- Thanks to [accelerate](https://github.com/huggingface/accelerate) for distributed training
|
251 |
+
- Thanks to [torch-ngp](https://github.com/ashawkey/torch-ngp) for super useful hashencoder
|
252 |
+
- Thanks to [Yurui Chen](https://github.com/519401113) for discussing the details of the paper.
|
configs/360.gin
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Config.exp_name = 'test'
|
2 |
+
Config.dataset_loader = 'llff'
|
3 |
+
Config.near = 0.2
|
4 |
+
Config.far = 1e6
|
5 |
+
Config.factor = 4
|
6 |
+
|
7 |
+
Model.raydist_fn = 'power_transformation'
|
8 |
+
Model.opaque_background = True
|
9 |
+
|
10 |
+
PropMLP.disable_density_normals = True
|
11 |
+
PropMLP.disable_rgb = True
|
12 |
+
PropMLP.grid_level_dim = 1
|
13 |
+
|
14 |
+
NerfMLP.disable_density_normals = True
|
15 |
+
|
configs/360_glo.gin
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Config.dataset_loader = 'llff'
|
2 |
+
Config.near = 0.2
|
3 |
+
Config.far = 1e6
|
4 |
+
Config.factor = 4
|
5 |
+
|
6 |
+
Model.raydist_fn = 'power_transformation'
|
7 |
+
Model.num_glo_features = 128
|
8 |
+
Model.opaque_background = True
|
9 |
+
|
10 |
+
PropMLP.disable_density_normals = True
|
11 |
+
PropMLP.disable_rgb = True
|
12 |
+
PropMLP.grid_level_dim = 1
|
13 |
+
|
14 |
+
|
15 |
+
NerfMLP.disable_density_normals = True
|
configs/blender.gin
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Config.exp_name = 'test'
|
2 |
+
Config.dataset_loader = 'blender'
|
3 |
+
Config.near = 2
|
4 |
+
Config.far = 6
|
5 |
+
Config.factor = 0
|
6 |
+
Config.hash_decay_mults = 10
|
7 |
+
|
8 |
+
Model.raydist_fn = None
|
9 |
+
|
10 |
+
PropMLP.disable_density_normals = True
|
11 |
+
PropMLP.disable_rgb = True
|
12 |
+
PropMLP.grid_level_dim = 1
|
13 |
+
|
14 |
+
NerfMLP.disable_density_normals = True
|
15 |
+
|
configs/blender_refnerf.gin
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Config.dataset_loader = 'blender'
|
2 |
+
Config.batching = 'single_image'
|
3 |
+
Config.near = 2
|
4 |
+
Config.far = 6
|
5 |
+
|
6 |
+
Config.eval_render_interval = 5
|
7 |
+
Config.compute_normal_metrics = True
|
8 |
+
Config.data_loss_type = 'mse'
|
9 |
+
Config.distortion_loss_mult = 0.0
|
10 |
+
Config.orientation_loss_mult = 0.1
|
11 |
+
Config.orientation_loss_target = 'normals_pred'
|
12 |
+
Config.predicted_normal_loss_mult = 3e-4
|
13 |
+
Config.orientation_coarse_loss_mult = 0.01
|
14 |
+
Config.predicted_normal_coarse_loss_mult = 3e-5
|
15 |
+
Config.interlevel_loss_mult = 0.0
|
16 |
+
Config.data_coarse_loss_mult = 0.1
|
17 |
+
Config.adam_eps = 1e-8
|
18 |
+
|
19 |
+
Model.num_levels = 2
|
20 |
+
Model.single_mlp = True
|
21 |
+
Model.num_prop_samples = 128 # This needs to be set despite single_mlp = True.
|
22 |
+
Model.num_nerf_samples = 128
|
23 |
+
Model.anneal_slope = 0.
|
24 |
+
Model.dilation_multiplier = 0.
|
25 |
+
Model.dilation_bias = 0.
|
26 |
+
Model.single_jitter = False
|
27 |
+
Model.resample_padding = 0.01
|
28 |
+
Model.distinct_prop = False
|
29 |
+
|
30 |
+
NerfMLP.disable_density_normals = False
|
31 |
+
NerfMLP.enable_pred_normals = True
|
32 |
+
NerfMLP.use_directional_enc = True
|
33 |
+
NerfMLP.use_reflections = True
|
34 |
+
NerfMLP.deg_view = 5
|
35 |
+
NerfMLP.enable_pred_roughness = True
|
36 |
+
NerfMLP.use_diffuse_color = True
|
37 |
+
NerfMLP.use_specular_tint = True
|
38 |
+
NerfMLP.use_n_dot_v = True
|
39 |
+
NerfMLP.bottleneck_width = 128
|
40 |
+
NerfMLP.density_bias = 0.5
|
41 |
+
NerfMLP.max_deg_point = 16
|
configs/llff_256.gin
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Config.dataset_loader = 'llff'
|
2 |
+
Config.near = 0.
|
3 |
+
Config.far = 1.
|
4 |
+
Config.factor = 4
|
5 |
+
Config.forward_facing = True
|
6 |
+
Config.adam_eps = 1e-8
|
7 |
+
|
8 |
+
Model.opaque_background = True
|
9 |
+
Model.num_levels = 2
|
10 |
+
Model.num_prop_samples = 128
|
11 |
+
Model.num_nerf_samples = 32
|
12 |
+
|
13 |
+
PropMLP.disable_density_normals = True
|
14 |
+
PropMLP.disable_rgb = True
|
15 |
+
|
16 |
+
NerfMLP.disable_density_normals = True
|
17 |
+
|
18 |
+
NerfMLP.max_deg_point = 16
|
19 |
+
PropMLP.max_deg_point = 16
|
configs/llff_512.gin
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Config.dataset_loader = 'llff'
|
2 |
+
Config.near = 0.
|
3 |
+
Config.far = 1.
|
4 |
+
Config.factor = 4
|
5 |
+
Config.forward_facing = True
|
6 |
+
Config.adam_eps = 1e-8
|
7 |
+
|
8 |
+
Model.opaque_background = True
|
9 |
+
Model.num_levels = 2
|
10 |
+
Model.num_prop_samples = 128
|
11 |
+
Model.num_nerf_samples = 32
|
12 |
+
|
13 |
+
PropMLP.disable_density_normals = True
|
14 |
+
PropMLP.disable_rgb = True
|
15 |
+
|
16 |
+
NerfMLP.disable_density_normals = True
|
17 |
+
|
18 |
+
NerfMLP.max_deg_point = 16
|
19 |
+
PropMLP.max_deg_point = 16
|
configs/llff_raw.gin
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# General LLFF settings
|
2 |
+
|
3 |
+
Config.dataset_loader = 'llff'
|
4 |
+
Config.near = 0.
|
5 |
+
Config.far = 1.
|
6 |
+
Config.factor = 4
|
7 |
+
Config.forward_facing = True
|
8 |
+
|
9 |
+
PropMLP.disable_density_normals = True # Turn this off if using orientation loss.
|
10 |
+
PropMLP.disable_rgb = True
|
11 |
+
|
12 |
+
NerfMLP.disable_density_normals = True # Turn this off if using orientation loss.
|
13 |
+
|
14 |
+
NerfMLP.max_deg_point = 16
|
15 |
+
PropMLP.max_deg_point = 16
|
16 |
+
|
17 |
+
Config.train_render_every = 5000
|
18 |
+
|
19 |
+
|
20 |
+
########################## RawNeRF specific settings ##########################
|
21 |
+
|
22 |
+
Config.rawnerf_mode = True
|
23 |
+
Config.data_loss_type = 'rawnerf'
|
24 |
+
Config.apply_bayer_mask = True
|
25 |
+
Model.learned_exposure_scaling = True
|
26 |
+
|
27 |
+
Model.num_levels = 2
|
28 |
+
Model.num_prop_samples = 128 # Using extra samples for now because of noise instability.
|
29 |
+
Model.num_nerf_samples = 128
|
30 |
+
Model.opaque_background = True
|
31 |
+
Model.distinct_prop = False
|
32 |
+
|
33 |
+
# RGB activation we use for linear color outputs is exp(x - 5).
|
34 |
+
NerfMLP.rgb_padding = 0.
|
35 |
+
NerfMLP.rgb_activation = @math.safe_exp
|
36 |
+
NerfMLP.rgb_bias = -5.
|
37 |
+
PropMLP.rgb_padding = 0.
|
38 |
+
PropMLP.rgb_activation = @math.safe_exp
|
39 |
+
PropMLP.rgb_bias = -5.
|
40 |
+
|
41 |
+
## Experimenting with the various regularizers and losses:
|
42 |
+
Config.interlevel_loss_mult = .0 # Turning off interlevel for now (default = 1.).
|
43 |
+
Config.distortion_loss_mult = .01 # Distortion loss helps with floaters (default = .01).
|
44 |
+
Config.orientation_loss_mult = 0. # Orientation loss also not great (try .01).
|
45 |
+
Config.data_coarse_loss_mult = 0.1 # Setting this to match old MipNeRF.
|
46 |
+
|
47 |
+
## Density noise used in original NeRF:
|
48 |
+
NerfMLP.density_noise = 1.
|
49 |
+
PropMLP.density_noise = 1.
|
50 |
+
|
51 |
+
## Use a single MLP for all rounds of sampling:
|
52 |
+
Model.single_mlp = True
|
53 |
+
|
54 |
+
## Some algorithmic settings to match the paper:
|
55 |
+
Model.anneal_slope = 0.
|
56 |
+
Model.dilation_multiplier = 0.
|
57 |
+
Model.dilation_bias = 0.
|
58 |
+
Model.single_jitter = False
|
59 |
+
NerfMLP.weight_init = 'glorot_uniform'
|
60 |
+
PropMLP.weight_init = 'glorot_uniform'
|
61 |
+
|
62 |
+
## Training hyperparameters used in the paper:
|
63 |
+
Config.batch_size = 16384
|
64 |
+
Config.render_chunk_size = 16384
|
65 |
+
Config.lr_init = 1e-3
|
66 |
+
Config.lr_final = 1e-5
|
67 |
+
Config.max_steps = 500000
|
68 |
+
Config.checkpoint_every = 25000
|
69 |
+
Config.lr_delay_steps = 2500
|
70 |
+
Config.lr_delay_mult = 0.01
|
71 |
+
Config.grad_max_norm = 0.1
|
72 |
+
Config.grad_max_val = 0.1
|
73 |
+
Config.adam_eps = 1e-8
|
configs/multi360.gin
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
include 'configs/360.gin'
|
2 |
+
Config.multiscale = True
|
3 |
+
Config.multiscale_levels = 4
|
4 |
+
|
5 |
+
|
eval.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import accelerate
|
6 |
+
from absl import app
|
7 |
+
import gin
|
8 |
+
from internal import configs
|
9 |
+
from internal import datasets
|
10 |
+
from internal import image
|
11 |
+
from internal import models
|
12 |
+
from internal import raw_utils
|
13 |
+
from internal import ref_utils
|
14 |
+
from internal import train_utils
|
15 |
+
from internal import checkpoints
|
16 |
+
from internal import utils
|
17 |
+
from internal import vis
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import tensorboardX
|
21 |
+
from torch.utils._pytree import tree_map
|
22 |
+
|
23 |
+
configs.define_common_flags()
|
24 |
+
|
25 |
+
|
26 |
+
def summarize_results(folder, scene_names, num_buckets):
|
27 |
+
metric_names = ['psnrs', 'ssims', 'lpips']
|
28 |
+
num_iters = 1000000
|
29 |
+
precisions = [3, 4, 4, 4]
|
30 |
+
|
31 |
+
results = []
|
32 |
+
for scene_name in scene_names:
|
33 |
+
test_preds_folder = os.path.join(folder, scene_name, 'test_preds')
|
34 |
+
values = []
|
35 |
+
for metric_name in metric_names:
|
36 |
+
filename = os.path.join(folder, scene_name, 'test_preds', f'{metric_name}_{num_iters}.txt')
|
37 |
+
with utils.open_file(filename) as f:
|
38 |
+
v = np.array([float(s) for s in f.readline().split(' ')])
|
39 |
+
values.append(np.mean(np.reshape(v, [-1, num_buckets]), 0))
|
40 |
+
results.append(np.concatenate(values))
|
41 |
+
avg_results = np.mean(np.array(results), 0)
|
42 |
+
|
43 |
+
psnr, ssim, lpips = np.mean(np.reshape(avg_results, [-1, num_buckets]), 1)
|
44 |
+
|
45 |
+
mse = np.exp(-0.1 * np.log(10.) * psnr)
|
46 |
+
dssim = np.sqrt(1 - ssim)
|
47 |
+
avg_avg = np.exp(np.mean(np.log(np.array([mse, dssim, lpips]))))
|
48 |
+
|
49 |
+
s = []
|
50 |
+
for i, v in enumerate(np.reshape(avg_results, [-1, num_buckets])):
|
51 |
+
s.append(' '.join([f'{s:0.{precisions[i]}f}' for s in v]))
|
52 |
+
s.append(f'{avg_avg:0.{precisions[-1]}f}')
|
53 |
+
return ' | '.join(s)
|
54 |
+
|
55 |
+
|
56 |
+
def main(unused_argv):
|
57 |
+
config = configs.load_config()
|
58 |
+
config.exp_path = os.path.join('exp', config.exp_name)
|
59 |
+
config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints')
|
60 |
+
config.render_dir = os.path.join(config.exp_path, 'render')
|
61 |
+
|
62 |
+
accelerator = accelerate.Accelerator()
|
63 |
+
|
64 |
+
# setup logger
|
65 |
+
logging.basicConfig(
|
66 |
+
format="%(asctime)s: %(message)s",
|
67 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
68 |
+
force=True,
|
69 |
+
handlers=[logging.StreamHandler(sys.stdout),
|
70 |
+
logging.FileHandler(os.path.join(config.exp_path, 'log_eval.txt'))],
|
71 |
+
level=logging.INFO,
|
72 |
+
)
|
73 |
+
sys.excepthook = utils.handle_exception
|
74 |
+
logger = accelerate.logging.get_logger(__name__)
|
75 |
+
logger.info(config)
|
76 |
+
logger.info(accelerator.state, main_process_only=False)
|
77 |
+
|
78 |
+
config.world_size = accelerator.num_processes
|
79 |
+
config.global_rank = accelerator.process_index
|
80 |
+
accelerate.utils.set_seed(config.seed, device_specific=True)
|
81 |
+
model = models.Model(config=config)
|
82 |
+
model.eval()
|
83 |
+
model.to(accelerator.device)
|
84 |
+
|
85 |
+
dataset = datasets.load_dataset('test', config.data_dir, config)
|
86 |
+
dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)),
|
87 |
+
shuffle=False,
|
88 |
+
batch_size=1,
|
89 |
+
collate_fn=dataset.collate_fn,
|
90 |
+
)
|
91 |
+
tb_process_fn = lambda x: x.transpose(2, 0, 1) if len(x.shape) == 3 else x[None]
|
92 |
+
if config.rawnerf_mode:
|
93 |
+
postprocess_fn = dataset.metadata['postprocess_fn']
|
94 |
+
else:
|
95 |
+
postprocess_fn = lambda z: z
|
96 |
+
|
97 |
+
if config.eval_raw_affine_cc:
|
98 |
+
cc_fun = raw_utils.match_images_affine
|
99 |
+
else:
|
100 |
+
cc_fun = image.color_correct
|
101 |
+
|
102 |
+
model = accelerator.prepare(model)
|
103 |
+
|
104 |
+
metric_harness = image.MetricHarness()
|
105 |
+
|
106 |
+
last_step = 0
|
107 |
+
out_dir = os.path.join(config.exp_path,
|
108 |
+
'path_renders' if config.render_path else 'test_preds')
|
109 |
+
path_fn = lambda x: os.path.join(out_dir, x)
|
110 |
+
|
111 |
+
if not config.eval_only_once:
|
112 |
+
summary_writer = tensorboardX.SummaryWriter(
|
113 |
+
os.path.join(config.exp_path, 'eval'))
|
114 |
+
while True:
|
115 |
+
step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger)
|
116 |
+
if step <= last_step:
|
117 |
+
logger.info(f'Checkpoint step {step} <= last step {last_step}, sleeping.')
|
118 |
+
time.sleep(10)
|
119 |
+
continue
|
120 |
+
logger.info(f'Evaluating checkpoint at step {step}.')
|
121 |
+
if config.eval_save_output and (not utils.isdir(out_dir)):
|
122 |
+
utils.makedirs(out_dir)
|
123 |
+
|
124 |
+
num_eval = min(dataset.size, config.eval_dataset_limit)
|
125 |
+
perm = np.random.permutation(num_eval)
|
126 |
+
showcase_indices = np.sort(perm[:config.num_showcase_images])
|
127 |
+
metrics = []
|
128 |
+
metrics_cc = []
|
129 |
+
showcases = []
|
130 |
+
render_times = []
|
131 |
+
for idx, batch in enumerate(dataloader):
|
132 |
+
batch = accelerate.utils.send_to_device(batch, accelerator.device)
|
133 |
+
eval_start_time = time.time()
|
134 |
+
if idx >= num_eval:
|
135 |
+
logger.info(f'Skipping image {idx + 1}/{dataset.size}')
|
136 |
+
continue
|
137 |
+
logger.info(f'Evaluating image {idx + 1}/{dataset.size}')
|
138 |
+
rendering = models.render_image(model, accelerator,
|
139 |
+
batch, False, 1, config)
|
140 |
+
|
141 |
+
if not accelerator.is_main_process: # Only record via host 0.
|
142 |
+
continue
|
143 |
+
|
144 |
+
render_times.append((time.time() - eval_start_time))
|
145 |
+
logger.info(f'Rendered in {render_times[-1]:0.3f}s')
|
146 |
+
|
147 |
+
cc_start_time = time.time()
|
148 |
+
rendering['rgb_cc'] = cc_fun(rendering['rgb'], batch['rgb'])
|
149 |
+
|
150 |
+
rendering = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, rendering)
|
151 |
+
batch = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, batch)
|
152 |
+
|
153 |
+
gt_rgb = batch['rgb']
|
154 |
+
logger.info(f'Color corrected in {(time.time() - cc_start_time):0.3f}s')
|
155 |
+
|
156 |
+
if not config.eval_only_once and idx in showcase_indices:
|
157 |
+
showcase_idx = idx if config.deterministic_showcase else len(showcases)
|
158 |
+
showcases.append((showcase_idx, rendering, batch))
|
159 |
+
if not config.render_path:
|
160 |
+
rgb = postprocess_fn(rendering['rgb'])
|
161 |
+
rgb_cc = postprocess_fn(rendering['rgb_cc'])
|
162 |
+
rgb_gt = postprocess_fn(gt_rgb)
|
163 |
+
|
164 |
+
if config.eval_quantize_metrics:
|
165 |
+
# Ensures that the images written to disk reproduce the metrics.
|
166 |
+
rgb = np.round(rgb * 255) / 255
|
167 |
+
rgb_cc = np.round(rgb_cc * 255) / 255
|
168 |
+
|
169 |
+
if config.eval_crop_borders > 0:
|
170 |
+
crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c]
|
171 |
+
rgb = crop_fn(rgb)
|
172 |
+
rgb_cc = crop_fn(rgb_cc)
|
173 |
+
rgb_gt = crop_fn(rgb_gt)
|
174 |
+
|
175 |
+
metric = metric_harness(rgb, rgb_gt)
|
176 |
+
metric_cc = metric_harness(rgb_cc, rgb_gt)
|
177 |
+
|
178 |
+
if config.compute_disp_metrics:
|
179 |
+
for tag in ['mean', 'median']:
|
180 |
+
key = f'distance_{tag}'
|
181 |
+
if key in rendering:
|
182 |
+
disparity = 1 / (1 + rendering[key])
|
183 |
+
metric[f'disparity_{tag}_mse'] = float(
|
184 |
+
((disparity - batch['disps']) ** 2).mean())
|
185 |
+
|
186 |
+
if config.compute_normal_metrics:
|
187 |
+
weights = rendering['acc'] * batch['alphas']
|
188 |
+
normalized_normals_gt = ref_utils.l2_normalize_np(batch['normals'])
|
189 |
+
for key, val in rendering.items():
|
190 |
+
if key.startswith('normals') and val is not None:
|
191 |
+
normalized_normals = ref_utils.l2_normalize_np(val)
|
192 |
+
metric[key + '_mae'] = ref_utils.compute_weighted_mae_np(
|
193 |
+
weights, normalized_normals, normalized_normals_gt)
|
194 |
+
|
195 |
+
for m, v in metric.items():
|
196 |
+
logger.info(f'{m:30s} = {v:.4f}')
|
197 |
+
|
198 |
+
metrics.append(metric)
|
199 |
+
metrics_cc.append(metric_cc)
|
200 |
+
|
201 |
+
if config.eval_save_output and (config.eval_render_interval > 0):
|
202 |
+
if (idx % config.eval_render_interval) == 0:
|
203 |
+
utils.save_img_u8(postprocess_fn(rendering['rgb']),
|
204 |
+
path_fn(f'color_{idx:03d}.png'))
|
205 |
+
utils.save_img_u8(postprocess_fn(rendering['rgb_cc']),
|
206 |
+
path_fn(f'color_cc_{idx:03d}.png'))
|
207 |
+
|
208 |
+
for key in ['distance_mean', 'distance_median']:
|
209 |
+
if key in rendering:
|
210 |
+
utils.save_img_f32(rendering[key],
|
211 |
+
path_fn(f'{key}_{idx:03d}.tiff'))
|
212 |
+
|
213 |
+
for key in ['normals']:
|
214 |
+
if key in rendering:
|
215 |
+
utils.save_img_u8(rendering[key] / 2. + 0.5,
|
216 |
+
path_fn(f'{key}_{idx:03d}.png'))
|
217 |
+
|
218 |
+
utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx:03d}.tiff'))
|
219 |
+
|
220 |
+
if (not config.eval_only_once) and accelerator.is_main_process:
|
221 |
+
summary_writer.add_scalar('eval_median_render_time', np.median(render_times),
|
222 |
+
step)
|
223 |
+
for name in metrics[0]:
|
224 |
+
scores = [m[name] for m in metrics]
|
225 |
+
summary_writer.add_scalar('eval_metrics/' + name, np.mean(scores), step)
|
226 |
+
summary_writer.add_histogram('eval_metrics/' + 'perimage_' + name, scores,
|
227 |
+
step)
|
228 |
+
for name in metrics_cc[0]:
|
229 |
+
scores = [m[name] for m in metrics_cc]
|
230 |
+
summary_writer.add_scalar('eval_metrics_cc/' + name, np.mean(scores), step)
|
231 |
+
summary_writer.add_histogram('eval_metrics_cc/' + 'perimage_' + name,
|
232 |
+
scores, step)
|
233 |
+
|
234 |
+
for i, r, b in showcases:
|
235 |
+
if config.vis_decimate > 1:
|
236 |
+
d = config.vis_decimate
|
237 |
+
decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]
|
238 |
+
else:
|
239 |
+
decimate_fn = lambda x: x
|
240 |
+
r = tree_map(decimate_fn, r)
|
241 |
+
b = tree_map(decimate_fn, b)
|
242 |
+
visualizations = vis.visualize_suite(r, b)
|
243 |
+
for k, v in visualizations.items():
|
244 |
+
if k == 'color':
|
245 |
+
v = postprocess_fn(v)
|
246 |
+
summary_writer.add_image(f'output_{k}_{i}', tb_process_fn(v), step)
|
247 |
+
if not config.render_path:
|
248 |
+
target = postprocess_fn(b['rgb'])
|
249 |
+
summary_writer.add_image(f'true_color_{i}', tb_process_fn(target), step)
|
250 |
+
pred = postprocess_fn(visualizations['color'])
|
251 |
+
residual = np.clip(pred - target + 0.5, 0, 1)
|
252 |
+
summary_writer.add_image(f'true_residual_{i}', tb_process_fn(residual), step)
|
253 |
+
if config.compute_normal_metrics:
|
254 |
+
summary_writer.add_image(f'true_normals_{i}', tb_process_fn(b['normals']) / 2. + 0.5,
|
255 |
+
step)
|
256 |
+
|
257 |
+
if (config.eval_save_output and (not config.render_path) and
|
258 |
+
accelerator.is_main_process):
|
259 |
+
with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f:
|
260 |
+
f.write(' '.join([str(r) for r in render_times]))
|
261 |
+
logger.info(f'metrics:')
|
262 |
+
results = {}
|
263 |
+
num_buckets = config.multiscale_levels if config.multiscale else 1
|
264 |
+
for name in metrics[0]:
|
265 |
+
with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
|
266 |
+
ms = [m[name] for m in metrics]
|
267 |
+
f.write(' '.join([str(m) for m in ms]))
|
268 |
+
results[name] = ' | '.join(
|
269 |
+
list(map(str, np.mean(np.array(ms).reshape([-1, num_buckets]), 0).tolist())))
|
270 |
+
with utils.open_file(path_fn(f'metric_avg_{step}.txt'), 'w') as f:
|
271 |
+
for name in metrics[0]:
|
272 |
+
f.write(f'{name}: {results[name]}\n')
|
273 |
+
logger.info(f'{name}: {results[name]}')
|
274 |
+
logger.info(f'metrics_cc:')
|
275 |
+
results_cc = {}
|
276 |
+
for name in metrics_cc[0]:
|
277 |
+
with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'), 'w') as f:
|
278 |
+
ms = [m[name] for m in metrics_cc]
|
279 |
+
f.write(' '.join([str(m) for m in ms]))
|
280 |
+
results_cc[name] = ' | '.join(
|
281 |
+
list(map(str, np.mean(np.array(ms).reshape([-1, num_buckets]), 0).tolist())))
|
282 |
+
with utils.open_file(path_fn(f'metric_cc_avg_{step}.txt'), 'w') as f:
|
283 |
+
for name in metrics[0]:
|
284 |
+
f.write(f'{name}: {results_cc[name]}\n')
|
285 |
+
logger.info(f'{name}: {results_cc[name]}')
|
286 |
+
if config.eval_save_ray_data:
|
287 |
+
for i, r, b in showcases:
|
288 |
+
rays = {k: v for k, v in r.items() if 'ray_' in k}
|
289 |
+
np.set_printoptions(threshold=sys.maxsize)
|
290 |
+
with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f:
|
291 |
+
f.write(repr(rays))
|
292 |
+
|
293 |
+
if config.eval_only_once:
|
294 |
+
break
|
295 |
+
if config.early_exit_steps is not None:
|
296 |
+
num_steps = config.early_exit_steps
|
297 |
+
else:
|
298 |
+
num_steps = config.max_steps
|
299 |
+
if int(step) >= num_steps:
|
300 |
+
break
|
301 |
+
last_step = step
|
302 |
+
logger.info('Finish evaluation.')
|
303 |
+
|
304 |
+
|
305 |
+
if __name__ == '__main__':
|
306 |
+
with gin.config_scope('eval'):
|
307 |
+
app.run(main)
|
extract.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from absl import app
|
8 |
+
import gin
|
9 |
+
from internal import configs
|
10 |
+
from internal import datasets
|
11 |
+
from internal import models
|
12 |
+
from internal import utils
|
13 |
+
from internal import coord
|
14 |
+
from internal import checkpoints
|
15 |
+
import torch
|
16 |
+
import accelerate
|
17 |
+
from tqdm import tqdm
|
18 |
+
from torch.utils._pytree import tree_map
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from skimage import measure
|
21 |
+
import trimesh
|
22 |
+
import pymeshlab as pml
|
23 |
+
|
24 |
+
configs.define_common_flags()
|
25 |
+
|
26 |
+
|
27 |
+
@torch.no_grad()
|
28 |
+
def evaluate_density(model, accelerator: accelerate.Accelerator,
|
29 |
+
                      points, config: configs.Config, std_value=0.0):
    """
    Evaluate the model's density for a batch of points.

    Args:
        model: model whose nerf_mlp is queried for density.
        points: A torch tensor containing 3D points.

    Returns:
        A torch tensor with the density values evaluated at the given points.
    """
    z = []
    for _, pnts in enumerate(tqdm(torch.split(points, config.render_chunk_size, dim=0),
                                  desc="Evaluating density", leave=False,
                                  disable=not accelerator.is_main_process)):
        rays_remaining = pnts.shape[0] % accelerator.num_processes
        if rays_remaining != 0:
            padding = accelerator.num_processes - rays_remaining
            pnts = torch.cat([pnts, torch.zeros_like(pnts[-padding:])], dim=0)
        else:
            padding = 0
        rays_per_host = pnts.shape[0] // accelerator.num_processes
        start, stop = accelerator.process_index * rays_per_host, \
            (accelerator.process_index + 1) * rays_per_host
        chunk_means = pnts[start:stop]
        chunk_stds = torch.full_like(chunk_means[..., 0], std_value)
        raw_density = model.nerf_mlp.predict_density(chunk_means[:, None], chunk_stds[:, None], no_warp=True)[0]
        density = F.softplus(raw_density + model.nerf_mlp.density_bias)
        density = accelerator.gather(density)
        if padding > 0:
            density = density[: -padding]
        z.append(density)
    z = torch.cat(z, dim=0)
    return z


@torch.no_grad()
def evaluate_color(model, accelerator: accelerate.Accelerator,
                   points, config: configs.Config, std_value=0.0):
    """
    Evaluate the model's color for a batch of points.

    Args:
        model: model whose nerf_mlp is queried for rgb.
        points: A torch tensor containing 3D points.

    Returns:
        A torch tensor with the RGB values evaluated at the given points.
    """
    z = []
    for _, pnts in enumerate(tqdm(torch.split(points, config.render_chunk_size, dim=0),
                                  desc="Evaluating color",
                                  disable=not accelerator.is_main_process)):
        rays_remaining = pnts.shape[0] % accelerator.num_processes
        if rays_remaining != 0:
            padding = accelerator.num_processes - rays_remaining
            pnts = torch.cat([pnts, torch.zeros_like(pnts[-padding:])], dim=0)
        else:
            padding = 0
        rays_per_host = pnts.shape[0] // accelerator.num_processes
        start, stop = accelerator.process_index * rays_per_host, \
            (accelerator.process_index + 1) * rays_per_host
        chunk_means = pnts[start:stop]
        chunk_stds = torch.full_like(chunk_means[..., 0], std_value)
        chunk_viewdirs = torch.zeros_like(chunk_means)
        ray_results = model.nerf_mlp(False, chunk_means[:, None, None], chunk_stds[:, None, None],
                                     chunk_viewdirs)
        rgb = ray_results['rgb'][:, 0]
        rgb = accelerator.gather(rgb)
        if padding > 0:
            rgb = rgb[: -padding]
        z.append(rgb)
    z = torch.cat(z, dim=0)
    return z


@torch.no_grad()
def evaluate_color_projection(model, accelerator: accelerate.Accelerator, vertices, faces, config: configs.Config):
    normals = auto_normals(vertices, faces.long())
    viewdirs = -normals
    origins = vertices - 0.005 * viewdirs
    vc = []
    chunk = config.render_chunk_size
    model.num_levels = 1
    model.opaque_background = True
    for i in tqdm(range(0, origins.shape[0], chunk),
                  desc="Evaluating color projection",
                  disable=not accelerator.is_main_process):
        cur_chunk = min(chunk, origins.shape[0] - i)
        rays_remaining = cur_chunk % accelerator.num_processes
        rays_per_host = cur_chunk // accelerator.num_processes
        if rays_remaining != 0:
            padding = accelerator.num_processes - rays_remaining
            rays_per_host += 1
        else:
            padding = 0
        start = i + accelerator.process_index * rays_per_host
        stop = start + rays_per_host

        batch = {
            'origins': origins[start:stop],
            'directions': viewdirs[start:stop],
            'viewdirs': viewdirs[start:stop],
            'cam_dirs': viewdirs[start:stop],
            'radii': torch.full_like(origins[start:stop, ..., :1], 0.000723),
            'near': torch.full_like(origins[start:stop, ..., :1], 0),
            'far': torch.full_like(origins[start:stop, ..., :1], 0.01),
        }
        batch = accelerator.pad_across_processes(batch)
        with accelerator.autocast():
            renderings, ray_history = model(
                False,
                batch,
                compute_extras=False,
                train_frac=1)
        rgb = renderings[-1]['rgb']
        acc = renderings[-1]['acc']

        rgb /= acc.clamp_min(1e-5)[..., None]
        rgb = rgb.clamp(0, 1)

        rgb = accelerator.gather(rgb)
        rgb[torch.isnan(rgb) | torch.isinf(rgb)] = 1
        if padding > 0:
            rgb = rgb[: -padding]
        vc.append(rgb)
    vc = torch.cat(vc, dim=0)
    return vc


def auto_normals(verts, faces):
    i0 = faces[:, 0]
    i1 = faces[:, 1]
    i2 = faces[:, 2]

    v0 = verts[i0, :]
    v1 = verts[i1, :]
    v2 = verts[i2, :]

    face_normals = torch.cross(v1 - v0, v2 - v0)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(verts)
    v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
    v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
    v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

    # Normalize, replace zero (degenerate) normals with a default value
    v_nrm = torch.where((v_nrm ** 2).sum(dim=-1, keepdims=True) > 1e-20, v_nrm,
                        torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=verts.device))
    v_nrm = F.normalize(v_nrm, dim=-1)
    return v_nrm


def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01, logger=None, main_process=True):
    # verts: [N, 3]
    # faces: [N, 3]
    tbar = tqdm(total=9, desc='Clean mesh', leave=False, disable=not main_process)
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    m = pml.Mesh(verts, faces)
    ms = pml.MeshSet()
    ms.add_mesh(m, 'mesh')  # will copy!

    # filters
    tbar.set_description('Remove unreferenced vertices')
    ms.meshing_remove_unreferenced_vertices()  # verts not referenced by any faces
    tbar.update()

    if v_pct > 0:
        tbar.set_description('Merge close vertices')
        ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct))  # 1/10000 of bounding box diagonal
        tbar.update()

    tbar.set_description('Remove duplicate faces')
    ms.meshing_remove_duplicate_faces()  # faces defined by the same verts
    tbar.update()

    tbar.set_description('Remove null faces')
    ms.meshing_remove_null_faces()  # faces with area == 0
    tbar.update()

    if min_d > 0:
        tbar.set_description('Remove connected component by diameter')
        ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))
        tbar.update()

    if min_f > 0:
        tbar.set_description('Remove connected component by face number')
        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
        tbar.update()

    if repair:
        # tbar.set_description('Remove t vertices')
        # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
        tbar.set_description('Repair non manifold edges')
        ms.meshing_repair_non_manifold_edges(method=0)
        tbar.update()
        tbar.set_description('Repair non manifold vertices')
        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
        tbar.update()
    else:
        tbar.update(2)
    if remesh:
        # tbar.set_description('Coord taubin smoothing')
        # ms.apply_coord_taubin_smoothing()
        tbar.set_description('Isotropic explicit remeshing')
        ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size))
        tbar.update()

    # extract mesh
    m = ms.current_mesh()
    verts = m.vertex_matrix()
    faces = m.face_matrix()

    if logger is not None:
        logger.info(f'Mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces


def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True, logger=None):
    # optimalplacement: default is True, but for a flat mesh it must be False to prevent spike artifacts.

    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    if backend == 'pyfqmr':
        import pyfqmr
        solver = pyfqmr.Simplify()
        solver.setMesh(verts, faces)
        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
        verts, faces, normals = solver.getMesh()
    else:

        m = pml.Mesh(verts, faces)
        ms = pml.MeshSet()
        ms.add_mesh(m, 'mesh')  # will copy!

        # filters
        # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
        ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)

        if remesh:
            # ms.apply_coord_taubin_smoothing()
            ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))

        # extract mesh
        m = ms.current_mesh()
        verts = m.vertex_matrix()
        faces = m.face_matrix()

    if logger is not None:
        logger.info(f'Mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces


def main(unused_argv):
    config = configs.load_config()
    config.compute_visibility = True

    config.exp_path = os.path.join("exp", config.exp_name)
    config.mesh_path = os.path.join("exp", config.exp_name, "mesh")
    config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints')
    os.makedirs(config.mesh_path, exist_ok=True)

    # accelerator for DDP
    accelerator = accelerate.Accelerator()
    device = accelerator.device

    # setup logger
    logging.basicConfig(
        format="%(asctime)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
        handlers=[logging.StreamHandler(sys.stdout),
                  logging.FileHandler(os.path.join(config.exp_path, 'log_extract.txt'))],
        level=logging.INFO,
    )
    sys.excepthook = utils.handle_exception
    logger = accelerate.logging.get_logger(__name__)
    logger.info(config)
    logger.info(accelerator.state, main_process_only=False)

    config.world_size = accelerator.num_processes
    config.global_rank = accelerator.process_index
    accelerate.utils.set_seed(config.seed, device_specific=True)

    # setup model and optimizer
    model = models.Model(config=config)
    model = accelerator.prepare(model)
    step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger)
    model.eval()
    module = accelerator.unwrap_model(model)

    visibility_path = os.path.join(config.mesh_path, 'visibility_mask_{:.1f}.pt'.format(config.mesh_radius))
    visibility_resolution = config.visibility_resolution
    if not os.path.exists(visibility_path):
        logger.info('Generate visibility mask...')
        # load dataset
        dataset = datasets.load_dataset('train', config.data_dir, config)
        dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)),
                                                 num_workers=4,
                                                 shuffle=True,
                                                 batch_size=1,
                                                 collate_fn=dataset.collate_fn,
                                                 persistent_workers=True,
                                                 )

        visibility_mask = torch.ones(
            (1, 1, visibility_resolution, visibility_resolution, visibility_resolution), requires_grad=True
        ).to(device)
        visibility_mask.retain_grad()
        tbar = tqdm(dataloader, desc='Generating visibility grid', disable=not accelerator.is_main_process)
        for index, batch in enumerate(tbar):
            batch = accelerate.utils.send_to_device(batch, accelerator.device)

            rendering = models.render_image(model, accelerator,
                                            batch, False, 1, config,
                                            verbose=False, return_weights=True)

            coords = rendering['coord'].reshape(-1, 3)
            weights = rendering['weights'].reshape(-1)

            valid_points = coords[weights > config.valid_weight_thresh]
            valid_points /= config.mesh_radius
            # update mask based on ray samples
            with torch.enable_grad():
                out = torch.nn.functional.grid_sample(visibility_mask,
                                                      valid_points[None, None, None],
                                                      align_corners=True)
                out.sum().backward()
            tbar.set_postfix({"visibility_mask": (visibility_mask.grad > 0.0001).float().mean().item()})
            # if index == 10:
            #     break
        visibility_mask = (visibility_mask.grad > 0.0001).float()
        if accelerator.is_main_process:
            torch.save(visibility_mask.detach().cpu(), visibility_path)
    else:
        logger.info('Load visibility mask from {}'.format(visibility_path))
        visibility_mask = torch.load(visibility_path, map_location=device)

    space = config.mesh_radius * 2 / (config.visibility_resolution - 1)

    logger.info("Extract mesh from visibility mask...")
    visibility_mask_np = visibility_mask[0, 0].permute(2, 1, 0).detach().cpu().numpy()
    verts, faces, normals, values = measure.marching_cubes(
        volume=-visibility_mask_np,
        level=-0.5,
        spacing=(space, space, space))
    verts -= config.mesh_radius
    if config.extract_visibility:
        meshexport = trimesh.Trimesh(verts, faces)
        meshexport.export(os.path.join(config.mesh_path, "visibility_mask_{}.ply".format(config.mesh_radius)), "ply")
        logger.info("Extract visibility mask done.")

    # Initialize variables
    crop_n = 512
    grid_min = verts.min(axis=0)
    grid_max = verts.max(axis=0)
    space = ((grid_max - grid_min).prod() / config.mesh_voxels) ** (1 / 3)
    world_size = ((grid_max - grid_min) / space).astype(np.int32)
    Nx, Ny, Nz = np.maximum(1, world_size // crop_n)
    crop_n_x, crop_n_y, crop_n_z = world_size // [Nx, Ny, Nz]
    xs = np.linspace(grid_min[0], grid_max[0], Nx + 1)
    ys = np.linspace(grid_min[1], grid_max[1], Ny + 1)
    zs = np.linspace(grid_min[2], grid_max[2], Nz + 1)
    # Initialize meshes list
    meshes = []

    # Iterate over the grid
    for i in range(Nx):
        for j in range(Ny):
            for k in range(Nz):
                logger.info(f"Process grid cell ({i + 1}/{Nx}, {j + 1}/{Ny}, {k + 1}/{Nz})...")
                # Calculate grid cell boundaries
                x_min, x_max = xs[i], xs[i + 1]
                y_min, y_max = ys[j], ys[j + 1]
                z_min, z_max = zs[k], zs[k + 1]

                # Create point grid
                x = np.linspace(x_min, x_max, crop_n_x)
                y = np.linspace(y_min, y_max, crop_n_y)
                z = np.linspace(z_min, z_max, crop_n_z)
                xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
                points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T,
                                      dtype=torch.float,
                                      device=device)
                # Construct point pyramids
                points_tmp = points.reshape(crop_n_x, crop_n_y, crop_n_z, 3)[None]
                points_tmp /= config.mesh_radius
                current_mask = torch.nn.functional.grid_sample(visibility_mask, points_tmp, align_corners=True)
                current_mask = (current_mask > 0.0).cpu().numpy()[0, 0]

                pts_density = evaluate_density(module, accelerator, points,
                                               config, std_value=config.std_value)

                # bound the vertices
                points_world = coord.inv_contract(2 * points)
                pts_density[points_world.norm(dim=-1) > config.mesh_max_radius] = 0.0

                z = pts_density.detach().cpu().numpy()

                # Skip if no surface found
                valid_z = z.reshape(crop_n_x, crop_n_y, crop_n_z)[current_mask]
                if valid_z.shape[0] <= 0 or (
                        np.min(valid_z) > config.isosurface_threshold
                        or np.max(valid_z) < config.isosurface_threshold):
                    continue

                if not (np.min(z) > config.isosurface_threshold or np.max(z) < config.isosurface_threshold):
                    # Extract mesh
                    logger.info('Extract mesh...')
                    z = z.astype(np.float32)
                    verts, faces, _, _ = measure.marching_cubes(
                        volume=-z.reshape(crop_n_x, crop_n_y, crop_n_z),
                        level=-config.isosurface_threshold,
                        spacing=(
                            (x_max - x_min) / (crop_n_x - 1),
                            (y_max - y_min) / (crop_n_y - 1),
                            (z_max - z_min) / (crop_n_z - 1),
                        ),
                        mask=current_mask,
                    )
                    verts = verts + np.array([x_min, y_min, z_min])

                    meshcrop = trimesh.Trimesh(verts, faces)
                    logger.info('Extract vertices: {}, faces: {}'.format(meshcrop.vertices.shape[0],
                                                                         meshcrop.faces.shape[0]))
                    meshes.append(meshcrop)
    # Save mesh
    logger.info('Concatenate mesh...')
    combined_mesh = trimesh.util.concatenate(meshes)

    # from https://github.com/ashawkey/stable-dreamfusion/blob/main/nerf/renderer.py
    # clean
    logger.info('Clean mesh...')
    vertices = combined_mesh.vertices.astype(np.float32)
    faces = combined_mesh.faces.astype(np.int32)

    vertices, faces = clean_mesh(vertices, faces,
                                 remesh=False, remesh_size=0.01,
                                 logger=logger, main_process=accelerator.is_main_process)

    v = torch.from_numpy(vertices).contiguous().float().to(device)
    v = coord.inv_contract(2 * v)
    vertices = v.detach().cpu().numpy()
    f = torch.from_numpy(faces).contiguous().int().to(device)

    # decimation
    if config.decimate_target > 0 and faces.shape[0] > config.decimate_target:
        logger.info('Decimate mesh...')
        vertices, triangles = decimate_mesh(vertices, faces, config.decimate_target, logger=logger)
    # import ipdb; ipdb.set_trace()
    if config.vertex_color:
        # batched inference to avoid OOM
        logger.info('Evaluate mesh vertex color...')
        if config.vertex_projection:
            rgbs = evaluate_color_projection(module, accelerator, v, f, config)
        else:
            rgbs = evaluate_color(module, accelerator, v,
                                  config, std_value=config.std_value)
        rgbs = (rgbs * 255).detach().cpu().numpy().astype(np.uint8)
        if accelerator.is_main_process:
            logger.info('Export mesh (vertex color)...')
            mesh = trimesh.Trimesh(vertices, faces,
                                   vertex_colors=rgbs,
                                   process=False)  # important, process=True leads to seg fault...
            mesh.export(os.path.join(config.mesh_path, 'mesh_{}.ply'.format(config.mesh_radius)))
        logger.info('Finish extracting mesh.')
        return

    def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
        logger.info('Export mesh (atlas)...')
        # v, f: torch Tensor
        device = v.device
        v_np = v.cpu().numpy()  # [N, 3]
        f_np = f.cpu().numpy()  # [M, 3]

        # unwrap uvs
        import xatlas
        import nvdiffrast.torch as dr
        from sklearn.neighbors import NearestNeighbors
        from scipy.ndimage import binary_dilation, binary_erosion

        logger.info(f'Running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
        atlas = xatlas.Atlas()
        atlas.add_mesh(v_np, f_np)
        chart_options = xatlas.ChartOptions()
        chart_options.max_iterations = 4  # for faster unwrap...
        atlas.generate(chart_options=chart_options)
        vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

        # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np)  # [N], [M, 3], [N, 2]

        vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
        ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)

        # render uv maps
        uv = vt * 2.0 - 1.0  # uvs to range [-1, 1]
        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)  # [N, 4]

        if ssaa > 1:
            h = int(h0 * ssaa)
            w = int(w0 * ssaa)
        else:
            h, w = h0, w0

        if h <= 2048 and w <= 2048:
            glctx = dr.RasterizeCudaContext()
        else:
            glctx = dr.RasterizeGLContext()

        rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w))  # [1, h, w, 4]
        xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, h, w, 3]
        mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f)  # [1, h, w, 1]

        # masked query
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)

        feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)

        if mask.any():
            xyzs = xyzs[mask]  # [M, 3]

            # batched inference to avoid OOM
            all_feats = evaluate_color(module, accelerator, xyzs,
                                       config, std_value=config.std_value)
            feats[mask] = all_feats

        feats = feats.view(h, w, -1)
        mask = mask.view(h, w)

        # quantize [0.0, 1.0] to [0, 255]
        feats = feats.cpu().numpy()
        feats = (feats * 255).astype(np.uint8)

        ### NN search as an antialiasing ...
        mask = mask.cpu().numpy()

        inpaint_region = binary_dilation(mask, iterations=3)
        inpaint_region[mask] = 0

        search_region = mask.copy()
        not_search_region = binary_erosion(search_region, iterations=2)
        search_region[not_search_region] = 0

        search_coords = np.stack(np.nonzero(search_region), axis=-1)
        inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

        knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
        _, indices = knn.kneighbors(inpaint_coords)

        feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]

        feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)

        # do ssaa after the NN search, in numpy
        if ssaa > 1:
            feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)

        cv2.imwrite(os.path.join(config.mesh_path, f'{name}albedo.png'), feats)

        # save obj (v, vt, f /)
        obj_file = os.path.join(config.mesh_path, f'{name}mesh.obj')
        mtl_file = os.path.join(config.mesh_path, f'{name}mesh.mtl')

        logger.info(f'writing obj mesh to {obj_file}')
        with open(obj_file, "w") as fp:
            fp.write(f'mtllib {name}mesh.mtl \n')

            logger.info(f'writing vertices {v_np.shape}')
            for v in v_np:
                fp.write(f'v {v[0]} {v[1]} {v[2]} \n')

            logger.info(f'writing vertices texture coords {vt_np.shape}')
            for v in vt_np:
                fp.write(f'vt {v[0]} {1 - v[1]} \n')

            logger.info(f'writing faces {f_np.shape}')
            fp.write(f'usemtl mat0 \n')
            for i in range(len(f_np)):
                fp.write(
                    f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")

        with open(mtl_file, "w") as fp:
            fp.write(f'newmtl mat0 \n')
            fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
            fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
            fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
            fp.write(f'Tr 1.000000 \n')
            fp.write(f'illum 1 \n')
            fp.write(f'Ns 0.000000 \n')
            fp.write(f'map_Kd {name}albedo.png \n')

    # could be extremely slow
    _export(v, f)

    logger.info('Finish extracting mesh.')


if __name__ == '__main__':
    with gin.config_scope('bake'):
        app.run(main)
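The per-cell extraction above reduces to a simple pattern: sample a density grid, negate the volume, run marching cubes at the negated threshold, shift the vertices back into world coordinates, and export with trimesh. Below is a minimal, self-contained sketch of that pattern. `toy_density` is an invented stand-in for illustration only (not the repo's `evaluate_density`), and the grid bounds and threshold are arbitrary assumptions; only numpy, scikit-image, and trimesh are required.

import numpy as np
import trimesh
from skimage import measure

def toy_density(pts):
    # Stand-in density: ~1 inside a unit sphere, decaying smoothly outside it.
    return np.exp(-4.0 * (np.linalg.norm(pts, axis=-1) - 1.0).clip(min=0.0))

n = 64
lin = np.linspace(-1.5, 1.5, n)
xx, yy, zz = np.meshgrid(lin, lin, lin, indexing="ij")
pts = np.stack([xx, yy, zz], axis=-1).reshape(-1, 3)
density = toy_density(pts).reshape(n, n, n)

# Same trick as extract.py: negate the volume so the surface sits at the negated threshold.
threshold = 0.5
verts, faces, _, _ = measure.marching_cubes(
    volume=-density, level=-threshold,
    spacing=((3.0 / (n - 1),) * 3))
verts -= 1.5  # shift from grid coordinates back to the [-1.5, 1.5] cube
trimesh.Trimesh(verts, faces).export("toy_cell.ply")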
gridencoder/__init__.py
ADDED
@@ -0,0 +1 @@
from .grid import GridEncoder
gridencoder/backend.py
ADDED
@@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_grid_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'gridencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
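A quick sanity check of the JIT build path above (the first import triggers compilation, so it can take a while); this snippet only assumes the three ops bound in bindings.cpp are exposed on the loaded module.

# Verify that the JIT-compiled extension loaded and exposes the bound CUDA ops.
from gridencoder.backend import _backend

for fn in ('grid_encode_forward', 'grid_encode_backward', 'grad_total_variation'):
    assert hasattr(_backend, fn), fn
print('grid encoder backend OK')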
gridencoder/grid.py
ADDED
@@ -0,0 +1,198 @@
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    'hash': 0,
    'tiled': 1,
}

_interp_to_id = {
    'linear': 0,
    'smoothstep': 1,
}


class _grid_encode(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.contiguous()

        B, D = inputs.shape  # batch size, coord dim
        L = offsets.shape[0] - 1  # level
        C = embeddings.shape[1]  # embedding dim for each level
        S = np.log2(per_level_scale)  # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution  # base resolution

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
        else:
            dy_dx = None

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation)

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype, interpolation]
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    # @once_differentiable
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype, interpolation = ctx.dims
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        if dy_dx is not None:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = None

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)

        if dy_dx is not None:
            grad_inputs = grad_inputs.to(inputs.dtype)

        return grad_inputs, grad_embeddings, None, None, None, None, None, None, None


grid_encode = _grid_encode.apply


class GridEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2,
                 per_level_scale=2, base_resolution=16,
                 log2_hashmap_size=19, desired_resolution=None,
                 gridtype='hash', align_corners=False,
                 interpolation='linear', init_std=1e-4):
        super().__init__()

        # the finest resolution desired at the last level; if provided, overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim  # coord dims, 2 or 3
        self.num_levels = num_levels  # num levels, each level multiply resolution by 2
        self.level_dim = level_dim  # encode channels per level
        self.per_level_scale = per_level_scale  # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype]  # "tiled" or "hash"
        self.interpolation = interpolation
        self.interp_id = _interp_to_id[interpolation]  # "linear" or "smoothstep"
        self.align_corners = align_corners
        self.init_std = init_std

        # allocate parameters
        resolutions = []
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            resolution = (resolution if align_corners else resolution + 1)
            params_in_level = min(self.max_params, resolution ** input_dim)  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible
            resolutions.append(resolution)
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)
        idx = torch.empty(offset, dtype=torch.long)
        for i in range(self.num_levels):
            idx[offsets[i]:offsets[i+1]] = i
        self.register_buffer('idx', idx)
        self.register_buffer('grid_sizes', torch.from_numpy(np.array(resolutions, dtype=np.int32)))

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = self.init_std
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"

    def forward(self, inputs, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
        # inputs = inputs.clamp(0, 1)
        # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs

    # always run in float precision!
    @torch.cuda.amp.autocast(enabled=False)
    def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
        # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.

        D = self.input_dim
        C = self.embeddings.shape[1]  # embedding dim for each level
        L = self.offsets.shape[0] - 1  # level
        S = np.log2(self.per_level_scale)  # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = self.base_resolution  # base resolution

        if inputs is None:
            # randomized in [0, 1]
            inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
        else:
            inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
            inputs = inputs.view(-1, self.input_dim)
            B = inputs.shape[0]

        if self.embeddings.grad is None:
            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')

        _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)
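For reference, a minimal usage sketch of the encoder defined above. It assumes the CUDA extension is available (either JIT-built via backend.py or installed via setup.py) and that a CUDA device is present; the numbers are illustrative, not the repo's training configuration.

import torch
from gridencoder import GridEncoder

# 16 levels x 2 channels -> 32-dim feature per 3D point.
encoder = GridEncoder(input_dim=3, num_levels=16, level_dim=2,
                      base_resolution=16, desired_resolution=2048).cuda()

x = torch.rand(4096, 3, device='cuda') * 2 - 1   # positions in [-1, 1]
feats = encoder(x, bound=1)                      # -> [4096, 32]
print(feats.shape)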
gridencoder/setup.py
ADDED
@@ -0,0 +1,50 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='gridencoder',  # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_gridencoder',  # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'gridencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
gridencoder/src/bindings.cpp
ADDED
@@ -0,0 +1,9 @@
#include <torch/extension.h>

#include "gridencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
    m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
    m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
}
gridencoder/src/gridencoder.cu
ADDED
@@ -0,0 +1,645 @@
1 |
+
#include <cuda.h>
|
2 |
+
#include <cuda_fp16.h>
|
3 |
+
#include <cuda_runtime.h>
|
4 |
+
|
5 |
+
#include <ATen/cuda/CUDAContext.h>
|
6 |
+
#include <torch/torch.h>
|
7 |
+
|
8 |
+
#include <algorithm>
|
9 |
+
#include <stdexcept>
|
10 |
+
|
11 |
+
#include <stdint.h>
|
12 |
+
#include <cstdio>
|
13 |
+
|
14 |
+
|
15 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
16 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
17 |
+
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
18 |
+
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
19 |
+
|
20 |
+
|
21 |
+
// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
|
22 |
+
__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
|
23 |
+
// requires CUDA >= 10 and ARCH >= 70
|
24 |
+
// this is very slow compared to float or __half2, never use it.
|
25 |
+
//return atomicAdd(reinterpret_cast<__half*>(address), val);
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
template <typename T>
|
30 |
+
__host__ __device__ inline T div_round_up(T val, T divisor) {
|
31 |
+
return (val + divisor - 1) / divisor;
|
32 |
+
}
|
33 |
+
|
34 |
+
template <typename T, typename T2>
|
35 |
+
__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) {
|
36 |
+
return min(max(v, lo), hi);
|
37 |
+
}
|
38 |
+
|
39 |
+
template <typename T>
|
40 |
+
__device__ inline T smoothstep(T val) {
|
41 |
+
return val*val*(3.0f - 2.0f * val);
|
42 |
+
}
|
43 |
+
|
44 |
+
template <typename T>
|
45 |
+
__device__ inline T smoothstep_derivative(T val) {
|
46 |
+
return 6*val*(1.0f - val);
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
template <uint32_t D>
|
51 |
+
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
|
52 |
+
|
53 |
+
// coherent type of hashing
|
54 |
+
constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };
|
55 |
+
|
56 |
+
uint32_t result = 0;
|
57 |
+
#pragma unroll
|
58 |
+
for (uint32_t i = 0; i < D; ++i) {
|
59 |
+
result ^= pos_grid[i] * primes[i];
|
60 |
+
}
|
61 |
+
|
62 |
+
return result;
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
template <uint32_t D, uint32_t C>
|
67 |
+
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
|
68 |
+
uint32_t stride = 1;
|
69 |
+
uint32_t index = 0;
|
70 |
+
|
71 |
+
#pragma unroll
|
72 |
+
for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
|
73 |
+
index += pos_grid[d] * stride;
|
74 |
+
stride *= align_corners ? resolution: (resolution + 1);
|
75 |
+
}
|
76 |
+
|
77 |
+
// NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
|
78 |
+
// gridtype: 0 == hash, 1 == tiled
|
79 |
+
if (gridtype == 0 && stride > hashmap_size) {
|
80 |
+
index = fast_hash<D>(pos_grid);
|
81 |
+
}
|
82 |
+
|
83 |
+
return (index % hashmap_size) * C + ch;
|
84 |
+
}
|
85 |
+
|
86 |
+
|
87 |
+
template <typename scalar_t, uint32_t D, uint32_t C>
|
88 |
+
__global__ void kernel_grid(
|
89 |
+
const float * __restrict__ inputs,
|
90 |
+
const scalar_t * __restrict__ grid,
|
91 |
+
const int * __restrict__ offsets,
|
92 |
+
scalar_t * __restrict__ outputs,
|
93 |
+
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
94 |
+
scalar_t * __restrict__ dy_dx,
|
95 |
+
const uint32_t gridtype,
|
96 |
+
const bool align_corners,
|
97 |
+
const uint32_t interp
|
98 |
+
) {
|
99 |
+
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
100 |
+
|
101 |
+
if (b >= B) return;
|
102 |
+
|
103 |
+
const uint32_t level = blockIdx.y;
|
104 |
+
|
105 |
+
// locate
|
106 |
+
grid += (uint32_t)offsets[level] * C;
|
107 |
+
inputs += b * D;
|
108 |
+
outputs += level * B * C + b * C;
|
109 |
+
|
110 |
+
// check input range (should be in [0, 1])
|
111 |
+
bool flag_oob = false;
|
112 |
+
#pragma unroll
|
113 |
+
for (uint32_t d = 0; d < D; d++) {
|
114 |
+
if (inputs[d] < 0 || inputs[d] > 1) {
|
115 |
+
flag_oob = true;
|
116 |
+
}
|
117 |
+
}
|
118 |
+
// if input out of bound, just set output to 0
|
119 |
+
if (flag_oob) {
|
120 |
+
#pragma unroll
|
121 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
122 |
+
outputs[ch] = 0;
|
123 |
+
}
|
124 |
+
if (dy_dx) {
|
125 |
+
dy_dx += b * D * L * C + level * D * C; // B L D C
|
126 |
+
#pragma unroll
|
127 |
+
for (uint32_t d = 0; d < D; d++) {
|
128 |
+
#pragma unroll
|
129 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
130 |
+
dy_dx[d * C + ch] = 0;
|
131 |
+
}
|
132 |
+
}
|
133 |
+
}
|
134 |
+
return;
|
135 |
+
}
|
136 |
+
|
137 |
+
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
138 |
+
const float scale = exp2f(level * S) * H - 1.0f;
|
139 |
+
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
140 |
+
|
141 |
+
// calculate coordinate (always use float for precision!)
|
142 |
+
float pos[D];
|
143 |
+
float pos_deriv[D];
|
144 |
+
uint32_t pos_grid[D];
|
145 |
+
|
146 |
+
#pragma unroll
|
147 |
+
for (uint32_t d = 0; d < D; d++) {
|
148 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
149 |
+
pos_grid[d] = floorf(pos[d]);
|
150 |
+
pos[d] -= (float)pos_grid[d];
|
151 |
+
// smoothstep instead of linear
|
152 |
+
if (interp == 1) {
|
153 |
+
pos_deriv[d] = smoothstep_derivative(pos[d]);
|
154 |
+
pos[d] = smoothstep(pos[d]);
|
155 |
+
} else {
|
156 |
+
pos_deriv[d] = 1.0f; // linear deriv is default to 1
|
157 |
+
}
|
158 |
+
|
159 |
+
}
|
160 |
+
|
161 |
+
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
|
162 |
+
|
163 |
+
// interpolate
|
164 |
+
scalar_t results[C] = {0}; // temp results in register
|
165 |
+
|
166 |
+
#pragma unroll
|
167 |
+
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
168 |
+
float w = 1;
|
169 |
+
uint32_t pos_grid_local[D];
|
170 |
+
|
171 |
+
#pragma unroll
|
172 |
+
for (uint32_t d = 0; d < D; d++) {
|
173 |
+
if ((idx & (1 << d)) == 0) {
|
174 |
+
w *= 1 - pos[d];
|
175 |
+
pos_grid_local[d] = pos_grid[d];
|
176 |
+
} else {
|
177 |
+
w *= pos[d];
|
178 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
179 |
+
}
|
180 |
+
}
|
181 |
+
|
182 |
+
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
183 |
+
|
184 |
+
// writing to register (fast)
|
185 |
+
#pragma unroll
|
186 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
187 |
+
results[ch] += w * grid[index + ch];
|
188 |
+
}
|
189 |
+
|
190 |
+
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
|
191 |
+
}
|
192 |
+
|
193 |
+
// writing to global memory (slow)
|
194 |
+
#pragma unroll
|
195 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
196 |
+
outputs[ch] = results[ch];
|
197 |
+
}
|
198 |
+
|
199 |
+
// prepare dy_dx
|
200 |
+
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
|
201 |
+
if (dy_dx) {
|
202 |
+
|
203 |
+
dy_dx += b * D * L * C + level * D * C; // B L D C
|
204 |
+
|
205 |
+
#pragma unroll
|
206 |
+
for (uint32_t gd = 0; gd < D; gd++) {
|
207 |
+
|
208 |
+
scalar_t results_grad[C] = {0};
|
209 |
+
|
210 |
+
#pragma unroll
|
211 |
+
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
|
212 |
+
float w = scale;
|
213 |
+
uint32_t pos_grid_local[D];
|
214 |
+
|
215 |
+
#pragma unroll
|
216 |
+
for (uint32_t nd = 0; nd < D - 1; nd++) {
|
217 |
+
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
|
218 |
+
|
219 |
+
if ((idx & (1 << nd)) == 0) {
|
220 |
+
w *= 1 - pos[d];
|
221 |
+
pos_grid_local[d] = pos_grid[d];
|
222 |
+
} else {
|
223 |
+
w *= pos[d];
|
224 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
225 |
+
}
|
226 |
+
}
|
227 |
+
|
228 |
+
pos_grid_local[gd] = pos_grid[gd];
|
229 |
+
uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
230 |
+
pos_grid_local[gd] = pos_grid[gd] + 1;
|
231 |
+
uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
232 |
+
|
233 |
+
#pragma unroll
|
234 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
235 |
+
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
|
236 |
+
}
|
237 |
+
}
|
238 |
+
|
239 |
+
#pragma unroll
|
240 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
241 |
+
dy_dx[gd * C + ch] = results_grad[ch];
|
242 |
+
}
|
243 |
+
}
|
244 |
+
}
|
245 |
+
}
|
246 |
+
|
247 |
+
|
248 |
+
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
|
249 |
+
__global__ void kernel_grid_backward(
|
250 |
+
const scalar_t * __restrict__ grad,
|
251 |
+
const float * __restrict__ inputs,
|
252 |
+
const scalar_t * __restrict__ grid,
|
253 |
+
const int * __restrict__ offsets,
|
254 |
+
scalar_t * __restrict__ grad_grid,
|
255 |
+
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
256 |
+
const uint32_t gridtype,
|
257 |
+
const bool align_corners,
|
258 |
+
const uint32_t interp
|
259 |
+
) {
|
260 |
+
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
|
261 |
+
if (b >= B) return;
|
262 |
+
|
263 |
+
const uint32_t level = blockIdx.y;
|
264 |
+
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
|
265 |
+
|
266 |
+
// locate
|
267 |
+
grad_grid += offsets[level] * C;
|
268 |
+
inputs += b * D;
|
269 |
+
grad += level * B * C + b * C + ch; // L, B, C
|
270 |
+
|
271 |
+
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
272 |
+
const float scale = exp2f(level * S) * H - 1.0f;
|
273 |
+
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
274 |
+
|
275 |
+
// check input range (should be in [0, 1])
|
276 |
+
#pragma unroll
|
277 |
+
for (uint32_t d = 0; d < D; d++) {
|
278 |
+
if (inputs[d] < 0 || inputs[d] > 1) {
|
279 |
+
return; // grad is init as 0, so we simply return.
|
280 |
+
}
|
281 |
+
}
|
282 |
+
|
283 |
+
// calculate coordinate
|
284 |
+
float pos[D];
|
285 |
+
uint32_t pos_grid[D];
|
286 |
+
|
287 |
+
#pragma unroll
|
288 |
+
for (uint32_t d = 0; d < D; d++) {
|
289 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
290 |
+
pos_grid[d] = floorf(pos[d]);
|
291 |
+
pos[d] -= (float)pos_grid[d];
|
292 |
+
// smoothstep instead of linear
|
293 |
+
if (interp == 1) {
|
294 |
+
pos[d] = smoothstep(pos[d]);
|
295 |
+
}
|
296 |
+
}
|
297 |
+
|
298 |
+
scalar_t grad_cur[N_C] = {0}; // fetch to register
|
299 |
+
#pragma unroll
|
300 |
+
for (uint32_t c = 0; c < N_C; c++) {
|
301 |
+
grad_cur[c] = grad[c];
|
302 |
+
}
|
303 |
+
|
304 |
+
// interpolate
|
305 |
+
#pragma unroll
|
306 |
+
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
307 |
+
float w = 1;
|
308 |
+
uint32_t pos_grid_local[D];
|
309 |
+
|
310 |
+
#pragma unroll
|
311 |
+
for (uint32_t d = 0; d < D; d++) {
|
312 |
+
if ((idx & (1 << d)) == 0) {
|
313 |
+
w *= 1 - pos[d];
|
314 |
+
pos_grid_local[d] = pos_grid[d];
|
315 |
+
} else {
|
316 |
+
w *= pos[d];
|
317 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
318 |
+
}
|
319 |
+
}
|
320 |
+
|
321 |
+
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
|
322 |
+
|
323 |
+
// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
|
324 |
+
// TODO: use float which is better than __half, if N_C % 2 != 0
|
325 |
+
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
|
326 |
+
#pragma unroll
|
327 |
+
for (uint32_t c = 0; c < N_C; c += 2) {
|
328 |
+
// process two __half at once (by interpreting as a __half2)
|
329 |
+
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
|
330 |
+
atomicAdd((__half2*)&grad_grid[index + c], v);
|
331 |
+
}
|
332 |
+
// float, or __half when N_C % 2 != 0 (which means C == 1)
|
333 |
+
} else {
|
334 |
+
#pragma unroll
|
335 |
+
for (uint32_t c = 0; c < N_C; c++) {
|
336 |
+
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
|
337 |
+
}
|
338 |
+
}
|
339 |
+
}
|
340 |
+
}
|
341 |
+
|
342 |
+
|
343 |
+
template <typename scalar_t, uint32_t D, uint32_t C>
|
344 |
+
__global__ void kernel_input_backward(
|
345 |
+
const scalar_t * __restrict__ grad,
|
346 |
+
const scalar_t * __restrict__ dy_dx,
|
347 |
+
scalar_t * __restrict__ grad_inputs,
|
348 |
+
uint32_t B, uint32_t L
|
349 |
+
) {
|
350 |
+
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
351 |
+
if (t >= B * D) return;
|
352 |
+
|
353 |
+
const uint32_t b = t / D;
|
354 |
+
const uint32_t d = t - b * D;
|
355 |
+
|
356 |
+
dy_dx += b * L * D * C;
|
357 |
+
|
358 |
+
scalar_t result = 0;
|
359 |
+
|
360 |
+
# pragma unroll
|
361 |
+
for (int l = 0; l < L; l++) {
|
362 |
+
# pragma unroll
|
363 |
+
for (int ch = 0; ch < C; ch++) {
|
364 |
+
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
|
365 |
+
}
|
366 |
+
}
|
367 |
+
|
368 |
+
grad_inputs[t] = result;
|
369 |
+
}
|
370 |
+
|
371 |
+
|
372 |
+
template <typename scalar_t, uint32_t D>
|
373 |
+
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
374 |
+
static constexpr uint32_t N_THREAD = 512;
|
375 |
+
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
|
376 |
+
switch (C) {
|
377 |
+
case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
378 |
+
case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
379 |
+
case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
380 |
+
case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
381 |
+
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
382 |
+
}
|
383 |
+
}
|
384 |
+
|
385 |
+
// inputs: [B, D], float, in [0, 1]
|
386 |
+
// embeddings: [sO, C], float
|
387 |
+
// offsets: [L + 1], uint32_t
|
388 |
+
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
|
389 |
+
// H: base resolution
|
390 |
+
// dy_dx: [B, L * D * C]
|
391 |
+
template <typename scalar_t>
|
392 |
+
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
393 |
+
switch (D) {
|
394 |
+
case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
395 |
+
case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
396 |
+
case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
397 |
+
case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
398 |
+
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
|
399 |
+
}
|
400 |
+
}
|
401 |
+
|
402 |
+
template <typename scalar_t, uint32_t D>
|
403 |
+
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
404 |
+
static constexpr uint32_t N_THREAD = 256;
|
405 |
+
const uint32_t N_C = std::min(2u, C); // n_features_per_thread
|
406 |
+
const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
|
407 |
+
switch (C) {
|
408 |
+
case 1:
|
409 |
+
kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
410 |
+
if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
411 |
+
break;
|
412 |
+
case 2:
|
413 |
+
kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
414 |
+
if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
415 |
+
break;
|
416 |
+
case 4:
|
417 |
+
kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
418 |
+
if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
419 |
+
break;
|
420 |
+
case 8:
|
421 |
+
kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
422 |
+
if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
423 |
+
break;
|
424 |
+
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
425 |
+
}
|
426 |
+
}
|
427 |
+
|
428 |
+
|
429 |
+
// grad: [L, B, C], float
|
430 |
+
// inputs: [B, D], float, in [0, 1]
|
431 |
+
// embeddings: [sO, C], float
|
432 |
+
// offsets: [L + 1], uint32_t
|
433 |
+
// grad_embeddings: [sO, C]
|
434 |
+
// H: base resolution
|
435 |
+
template <typename scalar_t>
|
436 |
+
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
437 |
+
switch (D) {
|
438 |
+
case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
|
439 |
+
case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
|
440 |
+
case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
|
441 |
+
case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
|
442 |
+
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
|
443 |
+
}
|
444 |
+
}
|
445 |
+
|
446 |
+
|
447 |
+
|
448 |
+
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
449 |
+
CHECK_CUDA(inputs);
|
450 |
+
CHECK_CUDA(embeddings);
|
451 |
+
CHECK_CUDA(offsets);
|
452 |
+
CHECK_CUDA(outputs);
|
453 |
+
// CHECK_CUDA(dy_dx);
|
454 |
+
|
455 |
+
CHECK_CONTIGUOUS(inputs);
|
456 |
+
CHECK_CONTIGUOUS(embeddings);
|
457 |
+
CHECK_CONTIGUOUS(offsets);
|
458 |
+
CHECK_CONTIGUOUS(outputs);
|
459 |
+
// CHECK_CONTIGUOUS(dy_dx);
|
460 |
+
|
461 |
+
CHECK_IS_FLOATING(inputs);
|
462 |
+
CHECK_IS_FLOATING(embeddings);
|
463 |
+
CHECK_IS_INT(offsets);
|
464 |
+
CHECK_IS_FLOATING(outputs);
|
465 |
+
// CHECK_IS_FLOATING(dy_dx);
|
466 |
+
|
467 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
468 |
+
embeddings.scalar_type(), "grid_encode_forward", ([&] {
|
469 |
+
grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
|
470 |
+
}));
|
471 |
+
}
|
472 |
+
|
473 |
+
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
474 |
+
CHECK_CUDA(grad);
|
475 |
+
CHECK_CUDA(inputs);
|
476 |
+
CHECK_CUDA(embeddings);
|
477 |
+
CHECK_CUDA(offsets);
|
478 |
+
CHECK_CUDA(grad_embeddings);
|
479 |
+
// CHECK_CUDA(dy_dx);
|
480 |
+
// CHECK_CUDA(grad_inputs);
|
481 |
+
|
482 |
+
CHECK_CONTIGUOUS(grad);
|
483 |
+
CHECK_CONTIGUOUS(inputs);
|
484 |
+
CHECK_CONTIGUOUS(embeddings);
|
485 |
+
CHECK_CONTIGUOUS(offsets);
|
486 |
+
CHECK_CONTIGUOUS(grad_embeddings);
|
487 |
+
// CHECK_CONTIGUOUS(dy_dx);
|
488 |
+
// CHECK_CONTIGUOUS(grad_inputs);
|
489 |
+
|
490 |
+
CHECK_IS_FLOATING(grad);
|
491 |
+
CHECK_IS_FLOATING(inputs);
|
492 |
+
CHECK_IS_FLOATING(embeddings);
|
493 |
+
CHECK_IS_INT(offsets);
|
494 |
+
CHECK_IS_FLOATING(grad_embeddings);
|
495 |
+
// CHECK_IS_FLOATING(dy_dx);
|
496 |
+
// CHECK_IS_FLOATING(grad_inputs);
|
497 |
+
|
498 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
499 |
+
grad.scalar_type(), "grid_encode_backward", ([&] {
|
500 |
+
grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
|
501 |
+
}));
|
502 |
+
|
503 |
+
}
|
504 |
+
|
505 |
+
|
506 |
+
template <typename scalar_t, uint32_t D, uint32_t C>
|
507 |
+
__global__ void kernel_grad_tv(
|
508 |
+
const scalar_t * __restrict__ inputs,
|
509 |
+
const scalar_t * __restrict__ grid,
|
510 |
+
scalar_t * __restrict__ grad,
|
511 |
+
const int * __restrict__ offsets,
|
512 |
+
const float weight,
|
513 |
+
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
514 |
+
const uint32_t gridtype,
|
515 |
+
const bool align_corners
|
516 |
+
) {
|
517 |
+
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
518 |
+
|
519 |
+
if (b >= B) return;
|
520 |
+
|
521 |
+
const uint32_t level = blockIdx.y;
|
522 |
+
|
523 |
+
// locate
|
524 |
+
inputs += b * D;
|
525 |
+
grid += (uint32_t)offsets[level] * C;
|
526 |
+
grad += (uint32_t)offsets[level] * C;
|
527 |
+
|
528 |
+
// check input range (should be in [0, 1])
|
529 |
+
bool flag_oob = false;
|
530 |
+
#pragma unroll
|
531 |
+
for (uint32_t d = 0; d < D; d++) {
|
532 |
+
if (inputs[d] < 0 || inputs[d] > 1) {
|
533 |
+
flag_oob = true;
|
534 |
+
}
|
535 |
+
}
|
536 |
+
|
537 |
+
// if input out of bound, do nothing
|
538 |
+
if (flag_oob) return;
|
539 |
+
|
540 |
+
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
541 |
+
const float scale = exp2f(level * S) * H - 1.0f;
|
542 |
+
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
543 |
+
|
544 |
+
// calculate coordinate
|
545 |
+
float pos[D];
|
546 |
+
uint32_t pos_grid[D]; // [0, resolution]
|
547 |
+
|
548 |
+
#pragma unroll
|
549 |
+
for (uint32_t d = 0; d < D; d++) {
|
550 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
551 |
+
pos_grid[d] = floorf(pos[d]);
|
552 |
+
// pos[d] -= (float)pos_grid[d]; // not used
|
553 |
+
}
|
554 |
+
|
555 |
+
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
|
556 |
+
|
557 |
+
// total variation on pos_grid
|
558 |
+
scalar_t results[C] = {0}; // temp results in register
|
559 |
+
scalar_t idelta[C] = {0};
|
560 |
+
|
561 |
+
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
|
562 |
+
|
563 |
+
scalar_t w = weight / (2 * D);
|
564 |
+
|
565 |
+
#pragma unroll
|
566 |
+
for (uint32_t d = 0; d < D; d++) {
|
567 |
+
|
568 |
+
uint32_t cur_d = pos_grid[d];
|
569 |
+
scalar_t grad_val;
|
570 |
+
|
571 |
+
// right side
|
572 |
+
if (cur_d < resolution) {
|
573 |
+
pos_grid[d] = cur_d + 1;
|
574 |
+
uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
|
575 |
+
|
576 |
+
#pragma unroll
|
577 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
578 |
+
// results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f);
|
579 |
+
grad_val = (grid[index + ch] - grid[index_right + ch]);
|
580 |
+
results[ch] += grad_val;
|
581 |
+
idelta[ch] += grad_val * grad_val;
|
582 |
+
}
|
583 |
+
}
|
584 |
+
|
585 |
+
// left side
|
586 |
+
if (cur_d > 0) {
|
587 |
+
pos_grid[d] = cur_d - 1;
|
588 |
+
uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);
|
589 |
+
|
590 |
+
#pragma unroll
|
591 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
592 |
+
// results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f);
|
593 |
+
grad_val = (grid[index + ch] - grid[index_left + ch]);
|
594 |
+
results[ch] += grad_val;
|
595 |
+
idelta[ch] += grad_val * grad_val;
|
596 |
+
}
|
597 |
+
}
|
598 |
+
|
599 |
+
// reset
|
600 |
+
pos_grid[d] = cur_d;
|
601 |
+
}
|
602 |
+
|
603 |
+
// writing to global memory (slow)
|
604 |
+
#pragma unroll
|
605 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
606 |
+
// index may collide, so use atomic!
|
607 |
+
atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
|
608 |
+
}
|
609 |
+
|
610 |
+
}
|
611 |
+
|
612 |
+
|
613 |
+
template <typename scalar_t, uint32_t D>
|
614 |
+
void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
|
615 |
+
static constexpr uint32_t N_THREAD = 512;
|
616 |
+
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
|
617 |
+
switch (C) {
|
618 |
+
case 1: kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
619 |
+
case 2: kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
620 |
+
case 4: kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
621 |
+
case 8: kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
622 |
+
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
623 |
+
}
|
624 |
+
}
|
625 |
+
|
626 |
+
|
627 |
+
template <typename scalar_t>
|
628 |
+
void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
|
629 |
+
switch (D) {
|
630 |
+
case 2: kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
|
631 |
+
case 3: kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
|
632 |
+
case 4: kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
|
633 |
+
case 5: kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
|
634 |
+
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
|
635 |
+
}
|
636 |
+
}
|
637 |
+
|
638 |
+
|
639 |
+
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
|
640 |
+
|
641 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
642 |
+
embeddings.scalar_type(), "grad_total_variation", ([&] {
|
643 |
+
grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);
|
644 |
+
}));
|
645 |
+
}
|
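A minimal Python-side sketch of how the three entry points above are typically wrapped in a torch.autograd.Function. The `_gridencoder` module name and the argument plumbing are assumptions for illustration, not the repository's actual gridencoder/grid.py; shapes follow the comments in this file ([B, D] float32 inputs in [0, 1], [sO, C] embeddings, [L + 1] int32 offsets, [L, B, C] forward outputs).

import math
import torch
import _gridencoder  # assumed name of the compiled CUDA extension

class _GridEncode(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution,
                gridtype=0, align_corners=False, interpolation=0):
        B, D = inputs.shape
        L = offsets.shape[0] - 1
        C = embeddings.shape[1]
        S = math.log2(per_level_scale)  # the kernels expect log2 of the per-level scale
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
        # dy_dx caches d(output)/d(input); only needed when input gradients are required.
        dy_dx = (torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
                 if inputs.requires_grad else None)
        _gridencoder.grid_encode_forward(inputs, embeddings, offsets, outputs,
                                         B, D, C, L, S, base_resolution, dy_dx,
                                         gridtype, align_corners, interpolation)
        ctx.save_for_backward(inputs, embeddings, offsets)
        ctx.dims = (B, D, C, L, S, base_resolution, gridtype, align_corners, interpolation)
        ctx.dy_dx = dy_dx
        # Level-major [L, B, C] -> [B, L * C] feature vector for the downstream MLP.
        return outputs.permute(1, 0, 2).reshape(B, L * C)

    @staticmethod
    def backward(ctx, grad):
        inputs, embeddings, offsets = ctx.saved_tensors
        B, D, C, L, S, H, gridtype, align_corners, interpolation = ctx.dims
        grad = grad.reshape(B, L, C).permute(1, 0, 2).contiguous()  # back to [L, B, C]
        grad_embeddings = torch.zeros_like(embeddings)
        grad_inputs = (torch.zeros_like(inputs, dtype=embeddings.dtype)
                       if ctx.dy_dx is not None else None)
        _gridencoder.grid_encode_backward(grad, inputs, embeddings, offsets,
                                          grad_embeddings, B, D, C, L, S, H,
                                          ctx.dy_dx, grad_inputs, gridtype,
                                          align_corners, interpolation)
        return grad_inputs, grad_embeddings, None, None, None, None, None, None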
gridencoder/src/gridencoder.h
ADDED
@@ -0,0 +1,17 @@
1 |
+
#ifndef _HASH_ENCODE_H
|
2 |
+
#define _HASH_ENCODE_H
|
3 |
+
|
4 |
+
#include <stdint.h>
|
5 |
+
#include <torch/torch.h>
|
6 |
+
|
7 |
+
// inputs: [B, D], float, in [0, 1]
|
8 |
+
// embeddings: [sO, C], float
|
9 |
+
// offsets: [L + 1], uint32_t
|
10 |
+
// outputs: [L, B, C], float (level-major, matching gridencoder.cu)
|
11 |
+
// H: base resolution
|
12 |
+
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
|
13 |
+
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
|
14 |
+
|
15 |
+
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
|
16 |
+
|
17 |
+
#endif
|
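The header fixes the layout the Python side has to build before calling these functions: offsets is an [L + 1] int32 table giving each level's starting row in the flat [sO, C] embedding matrix. A small illustrative sketch, assuming Instant-NGP-style defaults (the resolution formula mirrors the kernels in gridencoder.cu; real setup code usually also rounds each level's row count up for alignment):

import math
import torch

def build_offsets(D=3, L=16, base_resolution=16, per_level_scale=1.447,
                  log2_hashmap_size=19):
    """Rows per level: a dense grid while it fits, hash-capped afterwards."""
    offsets = [0]
    for level in range(L):
        scale = base_resolution * per_level_scale ** level - 1.0
        resolution = int(math.ceil(scale)) + 1          # same formula as the kernels
        rows = min(resolution ** D, 2 ** log2_hashmap_size)
        offsets.append(offsets[-1] + rows)
    return torch.tensor(offsets, dtype=torch.int32)

offsets = build_offsets()
sO = int(offsets[-1])                      # total rows of the embedding table
embeddings = torch.zeros(sO, 2)            # C = 2 features per level
print(offsets.shape, embeddings.shape)     # torch.Size([17]) and torch.Size([sO, 2])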
internal/camera_utils.py
ADDED
@@ -0,0 +1,673 @@
1 |
+
import enum
|
2 |
+
from internal import configs
|
3 |
+
from internal import stepfun
|
4 |
+
from internal import utils
|
5 |
+
import numpy as np
|
6 |
+
import scipy
|
7 |
+
|
8 |
+
|
9 |
+
def convert_to_ndc(origins,
|
10 |
+
directions,
|
11 |
+
pixtocam,
|
12 |
+
near: float = 1.):
|
13 |
+
"""Converts a set of rays to normalized device coordinates (NDC).
|
14 |
+
|
15 |
+
Args:
|
16 |
+
origins: ndarray(float32), [..., 3], world space ray origins.
|
17 |
+
directions: ndarray(float32), [..., 3], world space ray directions.
|
18 |
+
pixtocam: ndarray(float32), [3, 3], inverse intrinsic matrix.
|
19 |
+
near: float, near plane along the negative z axis.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
origins_ndc: ndarray(float32), [..., 3].
|
23 |
+
directions_ndc: ndarray(float32), [..., 3].
|
24 |
+
|
25 |
+
This function assumes input rays should be mapped into the NDC space for a
|
26 |
+
perspective projection pinhole camera, with identity extrinsic matrix (pose)
|
27 |
+
and intrinsic parameters defined by inputs focal, width, and height.
|
28 |
+
|
29 |
+
The near value specifies the near plane of the frustum, and the far plane is
|
30 |
+
assumed to be infinity.
|
31 |
+
|
32 |
+
The ray bundle for the identity pose camera will be remapped to parallel rays
|
33 |
+
within the (-1, -1, -1) to (1, 1, 1) cube. Any other ray in the original
|
34 |
+
world space can be remapped as long as it has dz < 0 (ray direction has a
|
35 |
+
negative z-coord); this allows us to share a common NDC space for "forward
|
36 |
+
facing" scenes.
|
37 |
+
|
38 |
+
Note that
|
39 |
+
projection(origins + t * directions)
|
40 |
+
will NOT be equal to
|
41 |
+
origins_ndc + t * directions_ndc
|
42 |
+
and that the directions_ndc are not unit length. Rather, directions_ndc is
|
43 |
+
defined such that the valid near and far planes in NDC will be 0 and 1.
|
44 |
+
|
45 |
+
See Appendix C in https://arxiv.org/abs/2003.08934 for additional details.
|
46 |
+
"""
|
47 |
+
|
48 |
+
# Shift ray origins to near plane, such that oz = -near.
|
49 |
+
# This makes the new near bound equal to 0.
|
50 |
+
t = -(near + origins[..., 2]) / directions[..., 2]
|
51 |
+
origins = origins + t[..., None] * directions
|
52 |
+
|
53 |
+
dx, dy, dz = np.moveaxis(directions, -1, 0)
|
54 |
+
ox, oy, oz = np.moveaxis(origins, -1, 0)
|
55 |
+
|
56 |
+
xmult = 1. / pixtocam[0, 2] # Equal to -2. * focal / width
|
57 |
+
ymult = 1. / pixtocam[1, 2] # Equal to -2. * focal / cy
|
58 |
+
|
59 |
+
# Perspective projection into NDC for the t = 0 near points
|
60 |
+
# origins + 0 * directions
|
61 |
+
origins_ndc = np.stack([xmult * ox / oz, ymult * oy / oz,
|
62 |
+
-np.ones_like(oz)], axis=-1)
|
63 |
+
|
64 |
+
# Perspective projection into NDC for the t = infinity far points
|
65 |
+
# origins + infinity * directions
|
66 |
+
infinity_ndc = np.stack([xmult * dx / dz, ymult * dy / dz,
|
67 |
+
np.ones_like(oz)],
|
68 |
+
axis=-1)
|
69 |
+
|
70 |
+
# directions_ndc points from origins_ndc to infinity_ndc
|
71 |
+
directions_ndc = infinity_ndc - origins_ndc
|
72 |
+
|
73 |
+
return origins_ndc, directions_ndc
|
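A small illustrative check of the behaviour described in the docstring (get_pixtocam is defined further down in this file; the numbers are made up): rays with dz < 0 from an identity-pose camera land on the z = -1 near plane in NDC, and adding the full direction reaches the z = +1 far plane.

import numpy as np

focal, width, height = 500., 800, 600
pixtocam = get_pixtocam(focal, width, height)       # inverse intrinsics
origins = np.zeros((3, 3))
directions = np.array([[0.1, 0.0, -1.0],
                       [0.0, 0.2, -1.0],
                       [-0.1, -0.1, -1.0]])          # dz < 0, as required
o_ndc, d_ndc = convert_to_ndc(origins, directions, pixtocam, near=1.)
print(o_ndc[..., 2])            # all -1: shifted origins sit on the NDC near plane
print((o_ndc + d_ndc)[..., 2])  # all +1: t = infinity maps to the NDC far plane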
74 |
+
|
75 |
+
|
76 |
+
def pad_poses(p):
|
77 |
+
"""Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
|
78 |
+
bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
|
79 |
+
return np.concatenate([p[..., :3, :4], bottom], axis=-2)
|
80 |
+
|
81 |
+
|
82 |
+
def unpad_poses(p):
|
83 |
+
"""Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
|
84 |
+
return p[..., :3, :4]
|
85 |
+
|
86 |
+
|
87 |
+
def recenter_poses(poses):
|
88 |
+
"""Recenter poses around the origin."""
|
89 |
+
cam2world = average_pose(poses)
|
90 |
+
transform = np.linalg.inv(pad_poses(cam2world))
|
91 |
+
poses = transform @ pad_poses(poses)
|
92 |
+
return unpad_poses(poses), transform
|
93 |
+
|
94 |
+
|
95 |
+
def average_pose(poses):
|
96 |
+
"""New pose using average position, z-axis, and up vector of input poses."""
|
97 |
+
position = poses[:, :3, 3].mean(0)
|
98 |
+
z_axis = poses[:, :3, 2].mean(0)
|
99 |
+
up = poses[:, :3, 1].mean(0)
|
100 |
+
cam2world = viewmatrix(z_axis, up, position)
|
101 |
+
return cam2world
|
102 |
+
|
103 |
+
|
104 |
+
def viewmatrix(lookdir, up, position):
|
105 |
+
"""Construct lookat view matrix."""
|
106 |
+
vec2 = normalize(lookdir)
|
107 |
+
vec0 = normalize(np.cross(up, vec2))
|
108 |
+
vec1 = normalize(np.cross(vec2, vec0))
|
109 |
+
m = np.stack([vec0, vec1, vec2, position], axis=1)
|
110 |
+
return m
|
111 |
+
|
112 |
+
|
113 |
+
def normalize(x):
|
114 |
+
"""Normalization helper function."""
|
115 |
+
return x / np.linalg.norm(x)
|
116 |
+
|
117 |
+
|
118 |
+
def focus_point_fn(poses):
|
119 |
+
"""Calculate nearest point to all focal axes in poses."""
|
120 |
+
directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
|
121 |
+
m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
|
122 |
+
mt_m = np.transpose(m, [0, 2, 1]) @ m
|
123 |
+
focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
|
124 |
+
return focus_pt
|
125 |
+
|
126 |
+
|
127 |
+
# Constants for generate_spiral_path():
|
128 |
+
NEAR_STRETCH = .9 # Push forward near bound for forward facing render path.
|
129 |
+
FAR_STRETCH = 5. # Push back far bound for forward facing render path.
|
130 |
+
FOCUS_DISTANCE = .75 # Relative weighting of near, far bounds for render path.
|
131 |
+
|
132 |
+
|
133 |
+
def generate_spiral_path(poses, bounds, n_frames=120, n_rots=2, zrate=.5):
|
134 |
+
"""Calculates a forward facing spiral path for rendering."""
|
135 |
+
# Find a reasonable 'focus depth' for this dataset as a weighted average
|
136 |
+
# of conservative near and far bounds in disparity space.
|
137 |
+
near_bound = bounds.min() * NEAR_STRETCH
|
138 |
+
far_bound = bounds.max() * FAR_STRETCH
|
139 |
+
# All cameras will point towards the world space point (0, 0, -focal).
|
140 |
+
focal = 1 / (((1 - FOCUS_DISTANCE) / near_bound + FOCUS_DISTANCE / far_bound))
|
141 |
+
|
142 |
+
# Get radii for spiral path using 90th percentile of camera positions.
|
143 |
+
positions = poses[:, :3, 3]
|
144 |
+
radii = np.percentile(np.abs(positions), 90, 0)
|
145 |
+
radii = np.concatenate([radii, [1.]])
|
146 |
+
|
147 |
+
# Generate poses for spiral path.
|
148 |
+
render_poses = []
|
149 |
+
cam2world = average_pose(poses)
|
150 |
+
up = poses[:, :3, 1].mean(0)
|
151 |
+
for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
|
152 |
+
t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
|
153 |
+
position = cam2world @ t
|
154 |
+
lookat = cam2world @ [0, 0, -focal, 1.]
|
155 |
+
z_axis = position - lookat
|
156 |
+
render_poses.append(viewmatrix(z_axis, up, position))
|
157 |
+
render_poses = np.stack(render_poses, axis=0)
|
158 |
+
return render_poses
|
159 |
+
|
160 |
+
|
161 |
+
def transform_poses_pca(poses):
|
162 |
+
"""Transforms poses so principal components lie on XYZ axes.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
A tuple (poses, transform), with the transformed poses and the applied
|
169 |
+
camera_to_world transforms.
|
170 |
+
"""
|
171 |
+
t = poses[:, :3, 3]
|
172 |
+
t_mean = t.mean(axis=0)
|
173 |
+
t = t - t_mean
|
174 |
+
|
175 |
+
eigval, eigvec = np.linalg.eig(t.T @ t)
|
176 |
+
# Sort eigenvectors in order of largest to smallest eigenvalue.
|
177 |
+
inds = np.argsort(eigval)[::-1]
|
178 |
+
eigvec = eigvec[:, inds]
|
179 |
+
rot = eigvec.T
|
180 |
+
if np.linalg.det(rot) < 0:
|
181 |
+
rot = np.diag(np.array([1, 1, -1])) @ rot
|
182 |
+
|
183 |
+
transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
|
184 |
+
poses_recentered = unpad_poses(transform @ pad_poses(poses))
|
185 |
+
transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
|
186 |
+
|
187 |
+
# Flip coordinate system if z component of y-axis is negative
|
188 |
+
if poses_recentered.mean(axis=0)[2, 1] < 0:
|
189 |
+
poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
|
190 |
+
transform = np.diag(np.array([1, -1, -1, 1])) @ transform
|
191 |
+
|
192 |
+
# Just make sure it's in the [-1, 1]^3 cube.
|
193 |
+
scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
|
194 |
+
poses_recentered[:, :3, 3] *= scale_factor
|
195 |
+
transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
|
196 |
+
|
197 |
+
return poses_recentered, transform
|
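Because the returned transform is a full 4x4 similarity, the same normalization can be applied to poses or points that were not part of the PCA fit; a short illustrative sketch with dummy data:

import numpy as np

poses = np.tile(np.eye(4)[:3], (5, 1, 1))              # (5, 3, 4) camera-to-world poses
poses[:, :3, 3] = np.random.randn(5, 3)                # scatter the camera centers
poses_recentered, transform = transform_poses_pca(poses)

extra_pose = np.eye(4)[:3]                             # another (3, 4) pose
extra_recentered = unpad_poses(transform @ pad_poses(extra_pose))

points = np.random.randn(10, 3)                        # world-space 3D points
points_h = np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
points_recentered = (transform @ points_h.T).T[:, :3]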
198 |
+
|
199 |
+
|
200 |
+
def generate_ellipse_path(poses, n_frames=120, const_speed=True, z_variation=0., z_phase=0.):
|
201 |
+
"""Generate an elliptical render path based on the given poses."""
|
202 |
+
# Calculate the focal point for the path (cameras point toward this).
|
203 |
+
center = focus_point_fn(poses)
|
204 |
+
# Path height sits at z=0 (in middle of zero-mean capture pattern).
|
205 |
+
offset = np.array([center[0], center[1], 0])
|
206 |
+
|
207 |
+
# Calculate scaling for ellipse axes based on input camera positions.
|
208 |
+
sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
|
209 |
+
# Use ellipse that is symmetric about the focal point in xy.
|
210 |
+
low = -sc + offset
|
211 |
+
high = sc + offset
|
212 |
+
# Optional height variation need not be symmetric
|
213 |
+
z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
|
214 |
+
z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
|
215 |
+
|
216 |
+
def get_positions(theta):
|
217 |
+
# Interpolate between bounds with trig functions to get ellipse in x-y.
|
218 |
+
# Optionally also interpolate in z to change camera height along path.
|
219 |
+
return np.stack([
|
220 |
+
low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),
|
221 |
+
low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),
|
222 |
+
z_variation * (z_low[2] + (z_high - z_low)[2] *
|
223 |
+
(np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
|
224 |
+
], -1)
|
225 |
+
|
226 |
+
theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
|
227 |
+
positions = get_positions(theta)
|
228 |
+
|
229 |
+
if const_speed:
|
230 |
+
# Resample theta angles so that the velocity is closer to constant.
|
231 |
+
lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
|
232 |
+
theta = stepfun.sample_np(None, theta, np.log(lengths), n_frames + 1)
|
233 |
+
positions = get_positions(theta)
|
234 |
+
|
235 |
+
# Throw away duplicated last position.
|
236 |
+
positions = positions[:-1]
|
237 |
+
|
238 |
+
# Set path's up vector to axis closest to average of input pose up vectors.
|
239 |
+
avg_up = poses[:, :3, 1].mean(0)
|
240 |
+
avg_up = avg_up / np.linalg.norm(avg_up)
|
241 |
+
ind_up = np.argmax(np.abs(avg_up))
|
242 |
+
up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
|
243 |
+
|
244 |
+
return np.stack([viewmatrix(p - center, up, p) for p in positions])
|
245 |
+
|
246 |
+
|
247 |
+
def generate_interpolated_path(poses, n_interp, spline_degree=5,
|
248 |
+
smoothness=.03, rot_weight=.1):
|
249 |
+
"""Creates a smooth spline path between input keyframe camera poses.
|
250 |
+
|
251 |
+
Spline is calculated with poses in format (position, lookat-point, up-point).
|
252 |
+
|
253 |
+
Args:
|
254 |
+
poses: (n, 3, 4) array of input pose keyframes.
|
255 |
+
n_interp: returned path will have n_interp * (n - 1) total poses.
|
256 |
+
spline_degree: polynomial degree of B-spline.
|
257 |
+
smoothness: parameter for spline smoothing, 0 forces exact interpolation.
|
258 |
+
rot_weight: relative weighting of rotation/translation in spline solve.
|
259 |
+
|
260 |
+
Returns:
|
261 |
+
Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
|
262 |
+
"""
|
263 |
+
|
264 |
+
def poses_to_points(poses, dist):
|
265 |
+
"""Converts from pose matrices to (position, lookat, up) format."""
|
266 |
+
pos = poses[:, :3, -1]
|
267 |
+
lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
|
268 |
+
up = poses[:, :3, -1] + dist * poses[:, :3, 1]
|
269 |
+
return np.stack([pos, lookat, up], 1)
|
270 |
+
|
271 |
+
def points_to_poses(points):
|
272 |
+
"""Converts from (position, lookat, up) format to pose matrices."""
|
273 |
+
return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
|
274 |
+
|
275 |
+
def interp(points, n, k, s):
|
276 |
+
"""Runs multidimensional B-spline interpolation on the input points."""
|
277 |
+
sh = points.shape
|
278 |
+
pts = np.reshape(points, (sh[0], -1))
|
279 |
+
k = min(k, sh[0] - 1)
|
280 |
+
tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
|
281 |
+
u = np.linspace(0, 1, n, endpoint=False)
|
282 |
+
new_points = np.array(scipy.interpolate.splev(u, tck))
|
283 |
+
new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
|
284 |
+
return new_points
|
285 |
+
|
286 |
+
points = poses_to_points(poses, dist=rot_weight)
|
287 |
+
new_points = interp(points,
|
288 |
+
n_interp * (points.shape[0] - 1),
|
289 |
+
k=spline_degree,
|
290 |
+
s=smoothness)
|
291 |
+
return points_to_poses(new_points)
|
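Illustrative usage with a handful of dummy keyframes (values made up); the returned path has n_interp poses per keyframe gap:

import numpy as np

keyframes = np.tile(np.eye(4)[:3], (4, 1, 1))          # 4 identity-orientation poses
keyframes[:, :3, 3] = np.stack(
    [np.linspace(0., 3., 4), np.zeros(4), np.zeros(4)], axis=-1)  # spread along +x
path = generate_interpolated_path(keyframes, n_interp=10)
print(path.shape)                                      # (30, 3, 4)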
292 |
+
|
293 |
+
|
294 |
+
def interpolate_1d(x, n_interp, spline_degree, smoothness):
|
295 |
+
"""Interpolate 1d signal x (by a factor of n_interp times)."""
|
296 |
+
t = np.linspace(0, 1, len(x), endpoint=True)
|
297 |
+
tck = scipy.interpolate.splrep(t, x, s=smoothness, k=spline_degree)
|
298 |
+
n = n_interp * (len(x) - 1)
|
299 |
+
u = np.linspace(0, 1, n, endpoint=False)
|
300 |
+
return scipy.interpolate.splev(u, tck)
|
301 |
+
|
302 |
+
|
303 |
+
def create_render_spline_path(config, image_names, poses, exposures):
|
304 |
+
"""Creates spline interpolation render path from subset of dataset poses.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
config: configs.Config object.
|
308 |
+
image_names: either a directory of images or a text file of image names.
|
309 |
+
poses: [N, 3, 4] array of extrinsic camera pose matrices.
|
310 |
+
exposures: optional list of floating point exposure values.
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
spline_indices: list of indices used to select spline keyframe poses.
|
314 |
+
render_poses: array of interpolated extrinsic camera poses for the path.
|
315 |
+
render_exposures: optional list of interpolated exposures for the path.
|
316 |
+
"""
|
317 |
+
if utils.isdir(config.render_spline_keyframes):
|
318 |
+
# If directory, use image filenames.
|
319 |
+
keyframe_names = sorted(utils.listdir(config.render_spline_keyframes))
|
320 |
+
else:
|
321 |
+
# If text file, treat each line as an image filename.
|
322 |
+
with utils.open_file(config.render_spline_keyframes, 'r') as fp:
|
323 |
+
# Decode bytes into string and split into lines.
|
324 |
+
keyframe_names = fp.read().decode('utf-8').splitlines()
|
325 |
+
# Grab poses corresponding to the image filenames.
|
326 |
+
spline_indices = np.array(
|
327 |
+
[i for i, n in enumerate(image_names) if n in keyframe_names])
|
328 |
+
keyframes = poses[spline_indices]
|
329 |
+
render_poses = generate_interpolated_path(
|
330 |
+
keyframes,
|
331 |
+
n_interp=config.render_spline_n_interp,
|
332 |
+
spline_degree=config.render_spline_degree,
|
333 |
+
smoothness=config.render_spline_smoothness,
|
334 |
+
rot_weight=.1)
|
335 |
+
if config.render_spline_interpolate_exposure:
|
336 |
+
if exposures is None:
|
337 |
+
raise ValueError('config.render_spline_interpolate_exposure is True but '
|
338 |
+
'create_render_spline_path() was passed exposures=None.')
|
339 |
+
# Interpolate per-frame exposure value.
|
340 |
+
log_exposure = np.log(exposures[spline_indices])
|
341 |
+
# Use aggressive smoothing for exposure interpolation to avoid flickering.
|
342 |
+
log_exposure_interp = interpolate_1d(
|
343 |
+
log_exposure,
|
344 |
+
config.render_spline_n_interp,
|
345 |
+
spline_degree=5,
|
346 |
+
smoothness=20)
|
347 |
+
render_exposures = np.exp(log_exposure_interp)
|
348 |
+
else:
|
349 |
+
render_exposures = None
|
350 |
+
return spline_indices, render_poses, render_exposures
|
351 |
+
|
352 |
+
|
353 |
+
def intrinsic_matrix(fx, fy, cx, cy):
|
354 |
+
"""Intrinsic matrix for a pinhole camera in OpenCV coordinate system."""
|
355 |
+
return np.array([
|
356 |
+
[fx, 0, cx],
|
357 |
+
[0, fy, cy],
|
358 |
+
[0, 0, 1.],
|
359 |
+
])
|
360 |
+
|
361 |
+
|
362 |
+
def get_pixtocam(focal, width, height):
|
363 |
+
"""Inverse intrinsic matrix for a perfect pinhole camera."""
|
364 |
+
camtopix = intrinsic_matrix(focal, focal, width * .5, height * .5)
|
365 |
+
return np.linalg.inv(camtopix)
|
366 |
+
|
367 |
+
|
368 |
+
def pixel_coordinates(width, height):
|
369 |
+
"""Tuple of the x and y integer coordinates for a grid of pixels."""
|
370 |
+
return np.meshgrid(np.arange(width), np.arange(height), indexing='xy')
|
371 |
+
|
372 |
+
|
373 |
+
def _compute_residual_and_jacobian(x, y, xd, yd,
|
374 |
+
k1=0.0, k2=0.0, k3=0.0,
|
375 |
+
k4=0.0, p1=0.0, p2=0.0, ):
|
376 |
+
"""Auxiliary function of radial_and_tangential_undistort()."""
|
377 |
+
# Adapted from https://github.com/google/nerfies/blob/main/nerfies/camera.py
|
378 |
+
# let r(x, y) = x^2 + y^2;
|
379 |
+
# d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 +
|
380 |
+
# k4 * r(x, y)^4;
|
381 |
+
r = x * x + y * y
|
382 |
+
d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4)))
|
383 |
+
|
384 |
+
# The perfect projection is:
|
385 |
+
# xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
|
386 |
+
# yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
|
387 |
+
#
|
388 |
+
# Let's define
|
389 |
+
#
|
390 |
+
# fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
|
391 |
+
# fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
|
392 |
+
#
|
393 |
+
# We are looking for a solution that satisfies
|
394 |
+
# fx(x, y) = fy(x, y) = 0;
|
395 |
+
fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
|
396 |
+
fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd
|
397 |
+
|
398 |
+
# Compute derivative of d over [x, y]
|
399 |
+
d_r = (k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4)))
|
400 |
+
d_x = 2.0 * x * d_r
|
401 |
+
d_y = 2.0 * y * d_r
|
402 |
+
|
403 |
+
# Compute derivative of fx over x and y.
|
404 |
+
fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
|
405 |
+
fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y
|
406 |
+
|
407 |
+
# Compute derivative of fy over x and y.
|
408 |
+
fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
|
409 |
+
fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y
|
410 |
+
|
411 |
+
return fx, fy, fx_x, fx_y, fy_x, fy_y
|
412 |
+
|
413 |
+
|
414 |
+
def _radial_and_tangential_undistort(xd, yd, k1=0, k2=0,
|
415 |
+
k3=0, k4=0, p1=0,
|
416 |
+
p2=0, eps=1e-9, max_iterations=10):
|
417 |
+
"""Computes undistorted (x, y) from (xd, yd)."""
|
418 |
+
# From https://github.com/google/nerfies/blob/main/nerfies/camera.py
|
419 |
+
# Initialize from the distorted point.
|
420 |
+
x = np.copy(xd)
|
421 |
+
y = np.copy(yd)
|
422 |
+
|
423 |
+
for _ in range(max_iterations):
|
424 |
+
fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
|
425 |
+
x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, k4=k4, p1=p1, p2=p2)
|
426 |
+
denominator = fy_x * fx_y - fx_x * fy_y
|
427 |
+
x_numerator = fx * fy_y - fy * fx_y
|
428 |
+
y_numerator = fy * fx_x - fx * fy_x
|
429 |
+
step_x = np.where(
|
430 |
+
np.abs(denominator) > eps, x_numerator / denominator,
|
431 |
+
np.zeros_like(denominator))
|
432 |
+
step_y = np.where(
|
433 |
+
np.abs(denominator) > eps, y_numerator / denominator,
|
434 |
+
np.zeros_like(denominator))
|
435 |
+
|
436 |
+
x = x + step_x
|
437 |
+
y = y + step_y
|
438 |
+
|
439 |
+
return x, y
|
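A quick numerical sanity check, assuming small distortion coefficients: distort a few normalized image-plane points with the polynomial model spelled out in _compute_residual_and_jacobian, then confirm the Newton iterations above recover the originals.

import numpy as np

x = np.linspace(-0.3, 0.3, 5)
y = np.linspace(-0.2, 0.2, 5)
k1, k2, p1, p2 = 0.1, -0.05, 0.01, -0.02

r = x * x + y * y
d = 1.0 + r * (k1 + r * k2)                          # k3 = k4 = 0 here
xd = x * d + 2 * p1 * x * y + p2 * (r + 2 * x * x)   # forward distortion
yd = y * d + 2 * p2 * x * y + p1 * (r + 2 * y * y)

x_hat, y_hat = _radial_and_tangential_undistort(xd, yd, k1=k1, k2=k2, p1=p1, p2=p2)
print(np.max(np.abs(x_hat - x)), np.max(np.abs(y_hat - y)))  # both should be tiny (~1e-10)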
440 |
+
|
441 |
+
|
442 |
+
class ProjectionType(enum.Enum):
|
443 |
+
"""Camera projection type (standard perspective pinhole or fisheye model)."""
|
444 |
+
PERSPECTIVE = 'perspective'
|
445 |
+
FISHEYE = 'fisheye'
|
446 |
+
|
447 |
+
|
448 |
+
def pixels_to_rays(pix_x_int, pix_y_int, pixtocams,
|
449 |
+
camtoworlds,
|
450 |
+
distortion_params=None,
|
451 |
+
pixtocam_ndc=None,
|
452 |
+
camtype=ProjectionType.PERSPECTIVE):
|
453 |
+
"""Calculates rays given pixel coordinates, intrinisics, and extrinsics.
|
454 |
+
|
455 |
+
Given 2D pixel coordinates pix_x_int, pix_y_int for cameras with
|
456 |
+
inverse intrinsics pixtocams and extrinsics camtoworlds (and optional
|
457 |
+
distortion coefficients distortion_params and NDC space projection matrix
|
458 |
+
pixtocam_ndc), computes the corresponding 3D camera rays.
|
459 |
+
|
460 |
+
Vectorized over the leading dimensions of the first four arguments.
|
461 |
+
|
462 |
+
Args:
|
463 |
+
pix_x_int: int array, shape SH, x coordinates of image pixels.
|
464 |
+
pix_y_int: int array, shape SH, y coordinates of image pixels.
|
465 |
+
pixtocams: float array, broadcastable to SH + [3, 3], inverse intrinsics.
|
466 |
+
camtoworlds: float array, broadcastable to SH + [3, 4], camera extrinsics.
|
467 |
+
distortion_params: dict of floats, optional camera distortion parameters.
|
468 |
+
pixtocam_ndc: float array, [3, 3], optional inverse intrinsics for NDC.
|
469 |
+
camtype: camera_utils.ProjectionType, fisheye or perspective camera.
|
470 |
+
|
471 |
+
Returns:
|
472 |
+
origins: float array, shape SH + [3], ray origin points.
|
473 |
+
directions: float array, shape SH + [3], ray direction vectors.
|
474 |
+
viewdirs: float array, shape SH + [3], normalized ray direction vectors.
|
475 |
+
radii: float array, shape SH + [1], ray differential radii.
|
476 |
+
imageplane: float array, shape SH + [2], xy coordinates on the image plane.
|
477 |
+
If the image plane is at world space distance 1 from the pinhole, then
|
478 |
+
imageplane will be the xy coordinates of a pixel in that space (so the
|
479 |
+
camera ray direction at the origin would be (x, y, -1) in OpenGL coords).
|
480 |
+
"""
|
481 |
+
|
482 |
+
# Must add half pixel offset to shoot rays through pixel centers.
|
483 |
+
def pix_to_dir(x, y):
|
484 |
+
return np.stack([x + .5, y + .5, np.ones_like(x)], axis=-1)
|
485 |
+
|
486 |
+
# We need the dx and dy rays to calculate ray radii for mip-NeRF cones.
|
487 |
+
pixel_dirs_stacked = np.stack([
|
488 |
+
pix_to_dir(pix_x_int, pix_y_int),
|
489 |
+
pix_to_dir(pix_x_int + 1, pix_y_int),
|
490 |
+
pix_to_dir(pix_x_int, pix_y_int + 1)
|
491 |
+
], axis=0)
|
492 |
+
|
493 |
+
matmul = np.matmul
|
494 |
+
mat_vec_mul = lambda A, b: matmul(A, b[..., None])[..., 0]
|
495 |
+
|
496 |
+
# Apply inverse intrinsic matrices.
|
497 |
+
camera_dirs_stacked = mat_vec_mul(pixtocams, pixel_dirs_stacked)
|
498 |
+
|
499 |
+
if distortion_params is not None:
|
500 |
+
# Correct for distortion.
|
501 |
+
x, y = _radial_and_tangential_undistort(
|
502 |
+
camera_dirs_stacked[..., 0],
|
503 |
+
camera_dirs_stacked[..., 1],
|
504 |
+
**distortion_params)
|
505 |
+
camera_dirs_stacked = np.stack([x, y, np.ones_like(x)], -1)
|
506 |
+
|
507 |
+
if camtype == ProjectionType.FISHEYE:
|
508 |
+
theta = np.sqrt(np.sum(np.square(camera_dirs_stacked[..., :2]), axis=-1))
|
509 |
+
theta = np.minimum(np.pi, theta)
|
510 |
+
|
511 |
+
sin_theta_over_theta = np.sin(theta) / theta
|
512 |
+
camera_dirs_stacked = np.stack([
|
513 |
+
camera_dirs_stacked[..., 0] * sin_theta_over_theta,
|
514 |
+
camera_dirs_stacked[..., 1] * sin_theta_over_theta,
|
515 |
+
np.cos(theta),
|
516 |
+
], axis=-1)
|
517 |
+
|
518 |
+
# Flip from OpenCV to OpenGL coordinate system.
|
519 |
+
camera_dirs_stacked = matmul(camera_dirs_stacked,
|
520 |
+
np.diag(np.array([1., -1., -1.])))
|
521 |
+
|
522 |
+
# Extract 2D image plane (x, y) coordinates.
|
523 |
+
imageplane = camera_dirs_stacked[0, ..., :2]
|
524 |
+
|
525 |
+
# Apply camera rotation matrices.
|
526 |
+
directions_stacked = mat_vec_mul(camtoworlds[..., :3, :3],
|
527 |
+
camera_dirs_stacked)
|
528 |
+
# Extract the offset rays.
|
529 |
+
directions, dx, dy = directions_stacked
|
530 |
+
|
531 |
+
origins = np.broadcast_to(camtoworlds[..., :3, -1], directions.shape)
|
532 |
+
viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
|
533 |
+
|
534 |
+
if pixtocam_ndc is None:
|
535 |
+
# Distance from each unit-norm direction vector to its neighbors.
|
536 |
+
dx_norm = np.linalg.norm(dx - directions, axis=-1)
|
537 |
+
dy_norm = np.linalg.norm(dy - directions, axis=-1)
|
538 |
+
|
539 |
+
else:
|
540 |
+
# Convert ray origins and directions into projective NDC space.
|
541 |
+
origins_dx, _ = convert_to_ndc(origins, dx, pixtocam_ndc)
|
542 |
+
origins_dy, _ = convert_to_ndc(origins, dy, pixtocam_ndc)
|
543 |
+
origins, directions = convert_to_ndc(origins, directions, pixtocam_ndc)
|
544 |
+
|
545 |
+
# In NDC space, we use the offset between origins instead of directions.
|
546 |
+
dx_norm = np.linalg.norm(origins_dx - origins, axis=-1)
|
547 |
+
dy_norm = np.linalg.norm(origins_dy - origins, axis=-1)
|
548 |
+
|
549 |
+
# Cut the distance in half, multiply it to match the variance of a uniform
|
550 |
+
# distribution the size of a pixel (1/12, see the original mipnerf paper).
|
551 |
+
radii = (0.5 * (dx_norm + dy_norm))[..., None] * 2 / np.sqrt(12)
|
552 |
+
return origins, directions, viewdirs, radii, imageplane
|
553 |
+
|
554 |
+
|
555 |
+
def cast_ray_batch(cameras, pixels, camtype):
|
556 |
+
"""Maps from input cameras and Pixel batch to output Ray batch.
|
557 |
+
|
558 |
+
`cameras` is a Tuple of four sets of camera parameters.
|
559 |
+
pixtocams: 1 or N stacked [3, 3] inverse intrinsic matrices.
|
560 |
+
camtoworlds: 1 or N stacked [3, 4] extrinsic pose matrices.
|
561 |
+
distortion_params: optional, dict[str, float] containing pinhole model
|
562 |
+
distortion parameters.
|
563 |
+
pixtocam_ndc: optional, [3, 3] inverse intrinsic matrix for mapping to NDC.
|
564 |
+
|
565 |
+
Args:
|
566 |
+
cameras: described above.
|
567 |
+
pixels: integer pixel coordinates and camera indices, plus ray metadata.
|
568 |
+
These fields can be an arbitrary batch shape.
|
569 |
+
camtype: camera_utils.ProjectionType, fisheye or perspective camera.
|
570 |
+
|
571 |
+
Returns:
|
572 |
+
rays: Rays dataclass with computed 3D world space ray data.
|
573 |
+
"""
|
574 |
+
pixtocams, camtoworlds, distortion_params, pixtocam_ndc = cameras
|
575 |
+
|
576 |
+
# pixels.cam_idx has shape [..., 1], remove this hanging dimension.
|
577 |
+
cam_idx = pixels['cam_idx'][..., 0]
|
578 |
+
batch_index = lambda arr: arr if arr.ndim == 2 else arr[cam_idx]
|
579 |
+
|
580 |
+
# Compute rays from pixel coordinates.
|
581 |
+
origins, directions, viewdirs, radii, imageplane = pixels_to_rays(
|
582 |
+
pixels['pix_x_int'],
|
583 |
+
pixels['pix_y_int'],
|
584 |
+
batch_index(pixtocams),
|
585 |
+
batch_index(camtoworlds),
|
586 |
+
distortion_params=distortion_params,
|
587 |
+
pixtocam_ndc=pixtocam_ndc,
|
588 |
+
camtype=camtype)
|
589 |
+
|
590 |
+
# Create Rays data structure.
|
591 |
+
return dict(
|
592 |
+
origins=origins,
|
593 |
+
directions=directions,
|
594 |
+
viewdirs=viewdirs,
|
595 |
+
radii=radii,
|
596 |
+
imageplane=imageplane,
|
597 |
+
lossmult=pixels.get('lossmult'),
|
598 |
+
near=pixels.get('near'),
|
599 |
+
far=pixels.get('far'),
|
600 |
+
cam_idx=pixels.get('cam_idx'),
|
601 |
+
exposure_idx=pixels.get('exposure_idx'),
|
602 |
+
exposure_values=pixels.get('exposure_values'),
|
603 |
+
)
|
604 |
+
|
605 |
+
|
606 |
+
def cast_pinhole_rays(camtoworld, height, width, focal, near, far):
|
607 |
+
"""Wrapper for generating a pinhole camera ray batch (w/o distortion)."""
|
608 |
+
|
609 |
+
pix_x_int, pix_y_int = pixel_coordinates(width, height)
|
610 |
+
pixtocam = get_pixtocam(focal, width, height)
|
611 |
+
|
612 |
+
origins, directions, viewdirs, radii, imageplane = pixels_to_rays(pix_x_int, pix_y_int, pixtocam, camtoworld)
|
613 |
+
|
614 |
+
broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
|
615 |
+
ray_kwargs = {
|
616 |
+
'lossmult': broadcast_scalar(1.),
|
617 |
+
'near': broadcast_scalar(near),
|
618 |
+
'far': broadcast_scalar(far),
|
619 |
+
'cam_idx': broadcast_scalar(0),
|
620 |
+
}
|
621 |
+
|
622 |
+
return dict(origins=origins,
|
623 |
+
directions=directions,
|
624 |
+
viewdirs=viewdirs,
|
625 |
+
radii=radii,
|
626 |
+
imageplane=imageplane,
|
627 |
+
**ray_kwargs)
|
628 |
+
|
629 |
+
|
630 |
+
def cast_spherical_rays(camtoworld, height, width, near, far):
|
631 |
+
"""Generates a spherical camera ray batch."""
|
632 |
+
|
633 |
+
theta_vals = np.linspace(0, 2 * np.pi, width + 1)
|
634 |
+
phi_vals = np.linspace(0, np.pi, height + 1)
|
635 |
+
theta, phi = np.meshgrid(theta_vals, phi_vals, indexing='xy')
|
636 |
+
|
637 |
+
# Spherical coordinates in camera reference frame (y is up).
|
638 |
+
directions = np.stack([
|
639 |
+
-np.sin(phi) * np.sin(theta),
|
640 |
+
np.cos(phi),
|
641 |
+
np.sin(phi) * np.cos(theta),
|
642 |
+
], axis=-1)
|
643 |
+
|
644 |
+
matmul = np.matmul
|
645 |
+
directions = matmul(camtoworld[:3, :3], directions[..., None])[..., 0]
|
646 |
+
|
647 |
+
dy = np.diff(directions[:, :-1], axis=0)
|
648 |
+
dx = np.diff(directions[:-1, :], axis=1)
|
649 |
+
directions = directions[:-1, :-1]
|
650 |
+
viewdirs = directions
|
651 |
+
|
652 |
+
origins = np.broadcast_to(camtoworld[:3, -1], directions.shape)
|
653 |
+
|
654 |
+
dx_norm = np.linalg.norm(dx, axis=-1)
|
655 |
+
dy_norm = np.linalg.norm(dy, axis=-1)
|
656 |
+
radii = (0.5 * (dx_norm + dy_norm))[..., None] * 2 / np.sqrt(12)
|
657 |
+
|
658 |
+
imageplane = np.zeros_like(directions[..., :2])
|
659 |
+
|
660 |
+
broadcast_scalar = lambda x: np.broadcast_to(x, radii.shape[:-1])[..., None]
|
661 |
+
ray_kwargs = {
|
662 |
+
'lossmult': broadcast_scalar(1.),
|
663 |
+
'near': broadcast_scalar(near),
|
664 |
+
'far': broadcast_scalar(far),
|
665 |
+
'cam_idx': broadcast_scalar(0),
|
666 |
+
}
|
667 |
+
|
668 |
+
return dict(origins=origins,
|
669 |
+
directions=directions,
|
670 |
+
viewdirs=viewdirs,
|
671 |
+
radii=radii,
|
672 |
+
imageplane=imageplane,
|
673 |
+
**ray_kwargs)
|
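A tiny end-to-end sketch of this module: build a complete ray batch for a single pinhole camera and inspect the field shapes (numbers made up).

import numpy as np

camtoworld = np.eye(4)[:3]                      # identity pose, shape (3, 4)
rays = cast_pinhole_rays(camtoworld, height=120, width=160,
                         focal=100., near=0.1, far=100.)
print(rays['origins'].shape)                    # (120, 160, 3)
print(rays['directions'].shape)                 # (120, 160, 3)
print(rays['radii'].shape)                      # (120, 160, 1)
print(rays['near'][0, 0], rays['far'][0, 0])    # [0.1] [100.]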
internal/checkpoints.py
ADDED
@@ -0,0 +1,38 @@
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
|
4 |
+
import accelerate
|
5 |
+
import torch
|
6 |
+
import glob
|
7 |
+
|
8 |
+
|
9 |
+
def restore_checkpoint(
|
10 |
+
checkpoint_dir,
|
11 |
+
accelerator: accelerate.Accelerator,
|
12 |
+
logger=None
|
13 |
+
):
|
14 |
+
dirs = glob.glob(os.path.join(checkpoint_dir, "*"))
|
15 |
+
dirs.sort()
|
16 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
17 |
+
if path is None:
|
18 |
+
if logger is not None:
|
19 |
+
logger.info("Checkpoint does not exist. Starting a new training run.")
|
20 |
+
init_step = 0
|
21 |
+
else:
|
22 |
+
if logger is not None:
|
23 |
+
logger.info(f"Resuming from checkpoint {path}")
|
24 |
+
accelerator.load_state(path)
|
25 |
+
init_step = int(os.path.basename(path))
|
26 |
+
return init_step
|
27 |
+
|
28 |
+
|
29 |
+
def save_checkpoint(save_dir,
|
30 |
+
accelerator: accelerate.Accelerator,
|
31 |
+
step=0,
|
32 |
+
total_limit=3):
|
33 |
+
if total_limit > 0:
|
34 |
+
folders = glob.glob(os.path.join(save_dir, "*"))
|
35 |
+
folders.sort()
|
36 |
+
for folder in folders[: len(folders) + 1 - total_limit]:
|
37 |
+
shutil.rmtree(folder)
|
38 |
+
accelerator.save_state(os.path.join(save_dir, f"{step:06d}"))
|
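A minimal, hypothetical training-loop sketch showing the intended save/restore cycle with an accelerate.Accelerator; the linear model, optimizer, and directory name are stand-ins, not the repository's actual training code.

import accelerate
import torch
from internal.checkpoints import restore_checkpoint, save_checkpoint

accelerator = accelerate.Accelerator()
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

checkpoint_dir = 'exp/demo/checkpoints'                       # any writable directory
init_step = restore_checkpoint(checkpoint_dir, accelerator)   # 0 on a fresh run

for step in range(init_step, init_step + 10):
    loss = model(torch.randn(8, 3)).square().mean()
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if (step + 1) % 5 == 0:
        # Folders are named by zero-padded step so restore_checkpoint can sort them.
        save_checkpoint(checkpoint_dir, accelerator, step=step + 1, total_limit=3)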
internal/configs.py
ADDED
@@ -0,0 +1,177 @@
1 |
+
import dataclasses
|
2 |
+
import os
|
3 |
+
from typing import Any, Callable, Optional, Tuple, List
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from absl import flags
|
8 |
+
import gin
|
9 |
+
from internal import utils
|
10 |
+
|
11 |
+
gin.add_config_file_search_path('configs/')
|
12 |
+
|
13 |
+
configurables = {
|
14 |
+
'torch': [torch.reciprocal, torch.log, torch.log1p, torch.exp, torch.sqrt, torch.square],
|
15 |
+
}
|
16 |
+
|
17 |
+
for module, configurables in configurables.items():
|
18 |
+
for configurable in configurables:
|
19 |
+
gin.config.external_configurable(configurable, module=module)
|
20 |
+
|
21 |
+
|
22 |
+
@gin.configurable()
|
23 |
+
@dataclasses.dataclass
|
24 |
+
class Config:
|
25 |
+
"""Configuration flags for everything."""
|
26 |
+
seed = 0
|
27 |
+
dataset_loader: str = 'llff' # The type of dataset loader to use.
|
28 |
+
batching: str = 'all_images' # Batch composition, [single_image, all_images].
|
29 |
+
batch_size: int = 2 ** 16 # The number of rays/pixels in each batch.
|
30 |
+
patch_size: int = 1 # Resolution of patches sampled for training batches.
|
31 |
+
factor: int = 4 # The downsample factor of images, 0 for no downsampling.
|
32 |
+
multiscale: bool = False # use multiscale data for training.
|
33 |
+
multiscale_levels: int = 4 # number of multiscale levels.
|
34 |
+
# Note: image ordering affects which images land in the heldout test set.
|
35 |
+
forward_facing: bool = False # Set to True for forward-facing LLFF captures.
|
36 |
+
render_path: bool = False # If True, render a path. Used only by LLFF.
|
37 |
+
llffhold: int = 8 # Use every Nth image for the test set. Used only by LLFF.
|
38 |
+
# If true, use all input images for training.
|
39 |
+
llff_use_all_images_for_training: bool = False
|
40 |
+
llff_use_all_images_for_testing: bool = False
|
41 |
+
use_tiffs: bool = False # If True, use 32-bit TIFFs. Used only by Blender.
|
42 |
+
compute_disp_metrics: bool = False # If True, load and compute disparity MSE.
|
43 |
+
compute_normal_metrics: bool = False # If True, load and compute normal MAE.
|
44 |
+
disable_multiscale_loss: bool = False # If True, disable multiscale loss.
|
45 |
+
randomized: bool = True # Use randomized stratified sampling.
|
46 |
+
near: float = 2. # Near plane distance.
|
47 |
+
far: float = 6. # Far plane distance.
|
48 |
+
exp_name: str = "test" # experiment name
|
49 |
+
data_dir: Optional[str] = "/SSD_DISK/datasets/360_v2/bicycle" # Input data directory.
|
50 |
+
vocab_tree_path: Optional[str] = None # Path to vocab tree for COLMAP.
|
51 |
+
render_chunk_size: int = 65536 # Chunk size for whole-image renderings.
|
52 |
+
num_showcase_images: int = 5 # The number of test-set images to showcase.
|
53 |
+
deterministic_showcase: bool = True # If True, showcase the same images.
|
54 |
+
vis_num_rays: int = 16 # The number of rays to visualize.
|
55 |
+
# Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage.
|
56 |
+
vis_decimate: int = 0
|
57 |
+
|
58 |
+
# Only used by train.py:
|
59 |
+
max_steps: int = 25000 # The number of optimization steps.
|
60 |
+
early_exit_steps: Optional[int] = None # Early stopping, for debugging.
|
61 |
+
checkpoint_every: int = 5000 # The number of steps to save a checkpoint.
|
62 |
+
resume_from_checkpoint: bool = True # whether to resume from checkpoint.
|
63 |
+
checkpoints_total_limit: int = 1
|
64 |
+
gradient_scaling: bool = False # If True, scale gradients as in https://gradient-scaling.github.io/.
|
65 |
+
print_every: int = 100 # The number of steps between reports to tensorboard.
|
66 |
+
train_render_every: int = 500 # Steps between test set renders when training
|
67 |
+
data_loss_type: str = 'charb' # What kind of loss to use ('mse' or 'charb').
|
68 |
+
charb_padding: float = 0.001 # The padding used for Charbonnier loss.
|
69 |
+
data_loss_mult: float = 1.0 # Mult for the finest data term in the loss.
|
70 |
+
data_coarse_loss_mult: float = 0. # Multiplier for the coarser data terms.
|
71 |
+
interlevel_loss_mult: float = 0.0 # Mult. for the loss on the proposal MLP.
|
72 |
+
anti_interlevel_loss_mult: float = 0.01 # Mult. for the loss on the proposal MLP.
|
73 |
+
pulse_width = [0.03, 0.003] # Pulse widths for the anti-aliased interlevel loss on the proposal MLPs.
|
74 |
+
orientation_loss_mult: float = 0.0 # Multiplier on the orientation loss.
|
75 |
+
orientation_coarse_loss_mult: float = 0.0 # Coarser orientation loss weights.
|
76 |
+
# What that loss is imposed on, options are 'normals' or 'normals_pred'.
|
77 |
+
orientation_loss_target: str = 'normals_pred'
|
78 |
+
predicted_normal_loss_mult: float = 0.0 # Mult. on the predicted normal loss.
|
79 |
+
# Mult. on the coarser predicted normal loss.
|
80 |
+
predicted_normal_coarse_loss_mult: float = 0.0
|
81 |
+
hash_decay_mults: float = 0.1
|
82 |
+
|
83 |
+
lr_init: float = 0.01 # The initial learning rate.
|
84 |
+
lr_final: float = 0.001 # The final learning rate.
|
85 |
+
lr_delay_steps: int = 5000 # The number of "warmup" learning steps.
|
86 |
+
lr_delay_mult: float = 1e-8 # How severe the "warmup" should be.
|
87 |
+
adam_beta1: float = 0.9 # Adam's beta1 hyperparameter.
|
88 |
+
adam_beta2: float = 0.99 # Adam's beta2 hyperparameter.
|
89 |
+
adam_eps: float = 1e-15 # Adam's epsilon hyperparameter.
|
90 |
+
grad_max_norm: float = 0. # Gradient clipping magnitude, disabled if == 0.
|
91 |
+
grad_max_val: float = 0. # Gradient clipping value, disabled if == 0.
|
92 |
+
distortion_loss_mult: float = 0.005 # Multiplier on the distortion loss.
|
93 |
+
opacity_loss_mult: float = 0. # Multiplier on the opacity loss.
|
94 |
+
|
95 |
+
# Only used by eval.py:
|
96 |
+
eval_only_once: bool = True # If True evaluate the model only once, otherwise loop.
|
97 |
+
eval_save_output: bool = True # If True save predicted images to disk.
|
98 |
+
eval_save_ray_data: bool = False # If True save individual ray traces.
|
99 |
+
eval_render_interval: int = 1 # The interval between images saved to disk.
|
100 |
+
eval_dataset_limit: int = np.iinfo(np.int32).max # Num test images to eval.
|
101 |
+
eval_quantize_metrics: bool = True # If True, run metrics on 8-bit images.
|
102 |
+
eval_crop_borders: int = 0 # Ignore c border pixels in eval (x[c:-c, c:-c]).
|
103 |
+
|
104 |
+
# Only used by render.py
|
105 |
+
render_video_fps: int = 60 # Framerate in frames-per-second.
|
106 |
+
render_video_crf: int = 18 # Constant rate factor for ffmpeg video quality.
|
107 |
+
render_path_frames: int = 120 # Number of frames in render path.
|
108 |
+
z_variation: float = 0. # How much height variation in render path.
|
109 |
+
z_phase: float = 0. # Phase offset for height variation in render path.
|
110 |
+
render_dist_percentile: float = 0.5 # How much to trim from near/far planes.
|
111 |
+
render_dist_curve_fn: Callable[..., Any] = np.log # How depth is curved.
|
112 |
+
render_path_file: Optional[str] = None # Numpy render pose file to load.
|
113 |
+
render_resolution: Optional[Tuple[int, int]] = None # Render resolution, as
|
114 |
+
# (width, height).
|
115 |
+
render_focal: Optional[float] = None # Render focal length.
|
116 |
+
render_camtype: Optional[str] = None # 'perspective', 'fisheye', or 'pano'.
|
117 |
+
render_spherical: bool = False # Render spherical 360 panoramas.
|
118 |
+
render_save_async: bool = True # Save to CNS using a separate thread.
|
119 |
+
|
120 |
+
render_spline_keyframes: Optional[str] = None # Text file containing names of
|
121 |
+
# images to be used as spline
|
122 |
+
# keyframes, OR directory
|
123 |
+
# containing those images.
|
124 |
+
render_spline_n_interp: int = 30 # Num. frames to interpolate per keyframe.
|
125 |
+
render_spline_degree: int = 5 # Polynomial degree of B-spline interpolation.
|
126 |
+
render_spline_smoothness: float = .03 # B-spline smoothing factor, 0 for
|
127 |
+
# exact interpolation of keyframes.
|
128 |
+
# Interpolate per-frame exposure value from spline keyframes.
|
129 |
+
render_spline_interpolate_exposure: bool = False
|
130 |
+
|
131 |
+
# Flags for raw datasets.
|
132 |
+
rawnerf_mode: bool = False # Load raw images and train in raw color space.
|
133 |
+
exposure_percentile: float = 97. # Image percentile to expose as white.
|
134 |
+
num_border_pixels_to_mask: int = 0 # During training, discard N-pixel border
|
135 |
+
# around each input image.
|
136 |
+
apply_bayer_mask: bool = False # During training, apply Bayer mosaic mask.
|
137 |
+
autoexpose_renders: bool = False # During rendering, autoexpose each image.
|
138 |
+
# For raw test scenes, use affine raw-space color correction.
|
139 |
+
eval_raw_affine_cc: bool = False
|
140 |
+
|
141 |
+
zero_glo: bool = False
|
142 |
+
|
143 |
+
# marching cubes
|
144 |
+
valid_weight_thresh: float = 0.05
|
145 |
+
isosurface_threshold: float = 20
|
146 |
+
mesh_voxels: int = 512 ** 3
|
147 |
+
visibility_resolution: int = 512
|
148 |
+
mesh_radius: float = 1.0 # mesh radius * 2 = in contract space
|
149 |
+
mesh_max_radius: float = 10.0 # in world space
|
150 |
+
std_value: float = 0.0 # std of the sampled points
|
151 |
+
compute_visibility: bool = False
|
152 |
+
extract_visibility: bool = True
|
153 |
+
decimate_target: int = -1
|
154 |
+
vertex_color: bool = True
|
155 |
+
vertex_projection: bool = True
|
156 |
+
|
157 |
+
# tsdf
|
158 |
+
tsdf_radius: float = 2.0
|
159 |
+
tsdf_resolution: int = 512
|
160 |
+
truncation_margin: float = 5.0
|
161 |
+
tsdf_max_radius: float = 10.0 # in world space
|
162 |
+
|
163 |
+
|
164 |
+
def define_common_flags():
|
165 |
+
# Define the flags used by both train.py and eval.py
|
166 |
+
flags.DEFINE_string('mode', None, 'Required by GINXM, not used.')
|
167 |
+
flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.')
|
168 |
+
flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.')
|
169 |
+
flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.')
|
170 |
+
|
171 |
+
|
172 |
+
def load_config():
|
173 |
+
"""Load the config, and optionally checkpoint it."""
|
174 |
+
gin.parse_config_files_and_bindings(
|
175 |
+
flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True)
|
176 |
+
config = Config()
|
177 |
+
return config
|
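A minimal usage sketch (not part of the diff; the entry-point wiring is an assumption): define_common_flags() registers the gin flags and load_config() parses them into the Config dataclass above.

# Hypothetical train-script wiring (assumption): absl flags + gin parsing.
from absl import app

from internal import configs

configs.define_common_flags()

def main(unused_argv):
    config = configs.load_config()           # parses --gin_configs / --gin_bindings
    print(config.distortion_loss_mult)       # fields of the Config dataclass above

if __name__ == '__main__':
    app.run(main)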
internal/coord.py
ADDED
@@ -0,0 +1,225 @@
1 |
+
from internal import math
|
2 |
+
from internal import utils
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
# from torch.func import vmap, jacrev
|
6 |
+
|
7 |
+
|
8 |
+
def contract(x):
|
9 |
+
"""Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077)."""
|
10 |
+
eps = torch.finfo(x.dtype).eps
|
11 |
+
# eps = 1e-3
|
12 |
+
# Clamping to eps prevents non-finite gradients when x == 0.
|
13 |
+
x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
|
14 |
+
z = torch.where(x_mag_sq <= 1, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
|
15 |
+
return z
|
16 |
+
|
17 |
+
|
18 |
+
def inv_contract(z):
|
19 |
+
"""The inverse of contract()."""
|
20 |
+
eps = torch.finfo(z.dtype).eps
|
21 |
+
|
22 |
+
# Clamping to eps prevents non-finite gradients when z == 0.
|
23 |
+
z_mag_sq = torch.sum(z ** 2, dim=-1, keepdim=True).clamp_min(eps)
|
24 |
+
x = torch.where(z_mag_sq <= 1, z, z / (2 * torch.sqrt(z_mag_sq) - z_mag_sq).clamp_min(eps))
|
25 |
+
return x
|
26 |
+
|
27 |
+
|
28 |
+
def inv_contract_np(z):
|
29 |
+
"""The inverse of contract()."""
|
30 |
+
eps = np.finfo(z.dtype).eps
|
31 |
+
|
32 |
+
# Clamping to eps prevents non-finite gradients when z == 0.
|
33 |
+
z_mag_sq = np.maximum(np.sum(z ** 2, axis=-1, keepdims=True), eps)
|
34 |
+
x = np.where(z_mag_sq <= 1, z, z / np.maximum(2 * np.sqrt(z_mag_sq) - z_mag_sq, eps))
|
35 |
+
return x
|
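A small sanity-check sketch (an assumption, not part of the diff): contract() maps every point into the ball of radius 2, and inv_contract() should undo it up to floating-point error.

# Hypothetical round-trip check for the contraction (Eq. 10 of mip-NeRF 360).
import torch
from internal import coord

x = torch.randn(1024, 3) * 5.0                   # points far outside the unit ball
z = coord.contract(x)
assert torch.all(torch.linalg.norm(z, dim=-1) < 2.0 + 1e-5)
x_rec = coord.inv_contract(z)
print(torch.abs(x_rec - x).max())                # expected to be small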
36 |
+
|
37 |
+
|
38 |
+
def contract_tuple(x):
|
39 |
+
res = contract(x)
|
40 |
+
return res, res
|
41 |
+
|
42 |
+
|
43 |
+
def contract_mean_jacobi(x):
|
44 |
+
eps = torch.finfo(x.dtype).eps
|
45 |
+
# eps = 1e-3
|
46 |
+
|
47 |
+
# Clamping to eps prevents non-finite gradients when x == 0.
|
48 |
+
x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
|
49 |
+
x_mag_sqrt = torch.sqrt(x_mag_sq)
|
50 |
+
x_xT = math.matmul(x[..., None], x[..., None, :])
|
51 |
+
mask = x_mag_sq <= 1
|
52 |
+
z = torch.where(x_mag_sq <= 1, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
|
53 |
+
|
54 |
+
eye = torch.broadcast_to(torch.eye(3, device=x.device), z.shape[:-1] + z.shape[-1:] * 2)
|
55 |
+
jacobi = (2 * x_xT * (1 - x_mag_sqrt[..., None]) + (2 * x_mag_sqrt[..., None] ** 3 - x_mag_sqrt[..., None] ** 2) * eye) / x_mag_sqrt[..., None] ** 4
|
56 |
+
jacobi = torch.where(mask[..., None], eye, jacobi)
|
57 |
+
return z, jacobi
|
58 |
+
|
59 |
+
|
60 |
+
def contract_mean_std(x, std):
|
61 |
+
eps = torch.finfo(x.dtype).eps
|
62 |
+
# eps = 1e-3
|
63 |
+
# Clamping to eps prevents non-finite gradients when x == 0.
|
64 |
+
x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
|
65 |
+
x_mag_sqrt = torch.sqrt(x_mag_sq)
|
66 |
+
mask = x_mag_sq <= 1
|
67 |
+
z = torch.where(mask, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
|
68 |
+
# det_13 = ((1 / x_mag_sq) * ((2 / x_mag_sqrt - 1 / x_mag_sq) ** 2)) ** (1 / 3)
|
69 |
+
det_13 = (torch.pow(2 * x_mag_sqrt - 1, 1/3) / x_mag_sqrt) ** 2
|
70 |
+
|
71 |
+
std = torch.where(mask[..., 0], std, det_13[..., 0] * std)
|
72 |
+
return z, std
|
73 |
+
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def track_linearize(fn, mean, std):
|
77 |
+
"""Apply function `fn` to a set of means and covariances, ala a Kalman filter.
|
78 |
+
|
79 |
+
We can analytically transform a Gaussian parameterized by `mean` and `cov`
|
80 |
+
with a function `fn` by linearizing `fn` around `mean`, and taking advantage
|
81 |
+
of the fact that Covar[Ax + y] = A(Covar[x])A^T (see
|
82 |
+
https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details).
|
83 |
+
|
84 |
+
Args:
|
85 |
+
fn: the function applied to the Gaussians parameterized by (mean, cov).
|
86 |
+
mean: a tensor of means, where the last axis is the dimension.
|
87 |
+
std: a tensor of per-point standard deviations (isotropic covariances), matching the leading shape of `mean`.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
fn_mean: the transformed means.
|
91 |
+
fn_std: the transformed standard deviations.
|
92 |
+
"""
|
93 |
+
if fn == 'contract':
|
94 |
+
fn = contract_mean_jacobi
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
pre_shape = mean.shape[:-1]
|
99 |
+
mean = mean.reshape(-1, 3)
|
100 |
+
std = std.reshape(-1)
|
101 |
+
|
102 |
+
# jvp_1, mean_1 = vmap(jacrev(contract_tuple, has_aux=True))(mean)
|
103 |
+
# std_1 = std * torch.linalg.det(jvp_1) ** (1 / mean.shape[-1])
|
104 |
+
#
|
105 |
+
# mean_2, jvp_2 = fn(mean)
|
106 |
+
# std_2 = std * torch.linalg.det(jvp_2) ** (1 / mean.shape[-1])
|
107 |
+
#
|
108 |
+
# mean_3, std_3 = contract_mean_std(mean, std) # calculate det explicitly by using eigenvalues
|
109 |
+
# torch.allclose(std_1, std_3, atol=1e-7) # True
|
110 |
+
# torch.allclose(mean_1, mean_3) # True
|
111 |
+
# import ipdb; ipdb.set_trace()
|
112 |
+
mean, std = contract_mean_std(mean, std) # calculate det explicitly by using eigenvalues
|
113 |
+
|
114 |
+
mean = mean.reshape(*pre_shape, 3)
|
115 |
+
std = std.reshape(*pre_shape)
|
116 |
+
return mean, std
|
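A usage sketch (assumption): contracting a batch of isotropic Gaussians; contract_mean_std() rescales each standard deviation by the cube root of the contraction's Jacobian determinant.

# Hypothetical example of pushing Gaussians through the contraction.
import torch
from internal import coord

mean = torch.randn(16, 64, 3) * 4.0              # [..., 3] sample means
std = torch.full(mean.shape[:-1], 0.1)           # one std per Gaussian
c_mean, c_std = coord.track_linearize('contract', mean, std)
print(c_mean.shape, c_std.shape)                 # shapes are preserved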
117 |
+
|
118 |
+
|
119 |
+
def power_transformation(x, lam):
|
120 |
+
"""
|
121 |
+
Power transformation from Eq. (4) of Zip-NeRF.
|
122 |
+
"""
|
123 |
+
lam_1 = np.abs(lam - 1)
|
124 |
+
return lam_1 / lam * ((x / lam_1 + 1) ** lam - 1)
|
125 |
+
|
126 |
+
|
127 |
+
def inv_power_transformation(x, lam):
|
128 |
+
"""
|
129 |
+
inverse power transformation
|
130 |
+
"""
|
131 |
+
lam_1 = np.abs(lam - 1)
|
132 |
+
eps = torch.finfo(x.dtype).eps # may cause inf
|
133 |
+
# eps = 1e-3
|
134 |
+
return ((x * lam / lam_1 + 1 + eps) ** (1 / lam) - 1) * lam_1
|
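For reference, the pair above corresponds to Eq. (4) of Zip-NeRF and its inverse (the eps term in the code is only a numerical safeguard):

\mathcal{P}(x, \lambda) = \frac{|\lambda - 1|}{\lambda}
    \left( \left( \frac{x}{|\lambda - 1|} + 1 \right)^{\lambda} - 1 \right),
\qquad
\mathcal{P}^{-1}(y, \lambda) = |\lambda - 1|
    \left( \left( \frac{\lambda y}{|\lambda - 1|} + 1 \right)^{1/\lambda} - 1 \right).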
135 |
+
|
136 |
+
|
137 |
+
def construct_ray_warps(fn, t_near, t_far, lam=None):
|
138 |
+
"""Construct a bijection between metric distances and normalized distances.
|
139 |
+
|
140 |
+
See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
|
141 |
+
detailed explanation.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
fn: the function applied to ray distances; None, 'piecewise', 'power_transformation', or a torch function such as torch.reciprocal.
|
145 |
+
t_near: a tensor of near-plane distances.
|
146 |
+
t_far: a tensor of far-plane distances.
|
147 |
+
lam: the lambda parameter of Eq. (4) in Zip-NeRF; only used when fn == 'power_transformation'.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
t_to_s: a function that maps distances to normalized distances in [0, 1].
|
151 |
+
s_to_t: the inverse of t_to_s.
|
152 |
+
"""
|
153 |
+
if fn is None:
|
154 |
+
fn_fwd = lambda x: x
|
155 |
+
fn_inv = lambda x: x
|
156 |
+
elif fn == 'piecewise':
|
157 |
+
# Piecewise spacing combining identity and 1/x functions to allow t_near=0.
|
158 |
+
fn_fwd = lambda x: torch.where(x < 1, .5 * x, 1 - .5 / x)
|
159 |
+
fn_inv = lambda x: torch.where(x < .5, 2 * x, .5 / (1 - x))
|
160 |
+
elif fn == 'power_transformation':
|
161 |
+
fn_fwd = lambda x: power_transformation(x * 2, lam=lam)
|
162 |
+
fn_inv = lambda y: inv_power_transformation(y, lam=lam) / 2
|
163 |
+
else:
|
164 |
+
inv_mapping = {
|
165 |
+
'reciprocal': torch.reciprocal,
|
166 |
+
'log': torch.exp,
|
167 |
+
'exp': torch.log,
|
168 |
+
'sqrt': torch.square,
|
169 |
+
'square': torch.sqrt,
|
170 |
+
}
|
171 |
+
fn_fwd = fn
|
172 |
+
fn_inv = inv_mapping[fn.__name__]
|
173 |
+
|
174 |
+
s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)]
|
175 |
+
t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near)
|
176 |
+
s_to_t = lambda s: fn_inv(s * s_far + (1 - s) * s_near)
|
177 |
+
return t_to_s, s_to_t
|
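A usage sketch (assumption): building the metric-to-normalized distance warp with reciprocal spacing, so samples uniform in s are uniform in disparity along the ray.

# Hypothetical example of the s <-> t warp used for ray sampling.
import torch
from internal import coord

t_near, t_far = torch.tensor(0.2), torch.tensor(1e6)
t_to_s, s_to_t = coord.construct_ray_warps(torch.reciprocal, t_near, t_far)
s = torch.linspace(0., 1., 5)
t = s_to_t(s)                      # metric distances: 0.2, ~0.267, 0.4, 0.8, 1e6
print(t_to_s(t))                   # round trip back to [0, 1]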
178 |
+
|
179 |
+
|
180 |
+
def expected_sin(mean, var):
|
181 |
+
"""Compute the mean of sin(x), x ~ N(mean, var)."""
|
182 |
+
return torch.exp(-0.5 * var) * math.safe_sin(mean) # large var -> small value.
|
183 |
+
|
184 |
+
|
185 |
+
def integrated_pos_enc(mean, var, min_deg, max_deg):
|
186 |
+
"""Encode `x` with sinusoids scaled by 2^[min_deg, max_deg).
|
187 |
+
|
188 |
+
Args:
|
189 |
+
mean: tensor, the mean coordinates to be encoded
|
190 |
+
var: tensor, the variance of the coordinates to be encoded.
|
191 |
+
min_deg: int, the min degree of the encoding.
|
192 |
+
max_deg: int, the max degree of the encoding.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
encoded: tensor, encoded variables.
|
196 |
+
"""
|
197 |
+
scales = 2 ** torch.arange(min_deg, max_deg, device=mean.device)
|
198 |
+
shape = mean.shape[:-1] + (-1,)
|
199 |
+
scaled_mean = (mean[..., None, :] * scales[:, None]).reshape(*shape)
|
200 |
+
scaled_var = (var[..., None, :] * scales[:, None] ** 2).reshape(*shape)
|
201 |
+
|
202 |
+
return expected_sin(
|
203 |
+
torch.cat([scaled_mean, scaled_mean + 0.5 * torch.pi], dim=-1),
|
204 |
+
torch.cat([scaled_var] * 2, dim=-1))
|
205 |
+
|
206 |
+
|
207 |
+
def lift_and_diagonalize(mean, cov, basis):
|
208 |
+
"""Project `mean` and `cov` onto basis and diagonalize the projected cov."""
|
209 |
+
fn_mean = math.matmul(mean, basis)
|
210 |
+
fn_cov_diag = torch.sum(basis * math.matmul(cov, basis), dim=-2)
|
211 |
+
return fn_mean, fn_cov_diag
|
212 |
+
|
213 |
+
|
214 |
+
def pos_enc(x, min_deg, max_deg, append_identity=True):
|
215 |
+
"""The positional encoding used by the original NeRF paper."""
|
216 |
+
scales = 2 ** torch.arange(min_deg, max_deg, device=x.device)
|
217 |
+
shape = x.shape[:-1] + (-1,)
|
218 |
+
scaled_x = (x[..., None, :] * scales[:, None]).reshape(*shape)
|
219 |
+
# Note that we're not using safe_sin, unlike IPE.
|
220 |
+
four_feat = torch.sin(
|
221 |
+
torch.cat([scaled_x, scaled_x + 0.5 * torch.pi], dim=-1))
|
222 |
+
if append_identity:
|
223 |
+
return torch.cat([x] + [four_feat], dim=-1)
|
224 |
+
else:
|
225 |
+
return four_feat
|
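A small shape sketch (assumption): with min_deg=0, max_deg=4 and 3-D inputs, pos_enc returns 3 + 2*3*4 = 27 features per point when the identity is appended.

# Hypothetical feature-size check for the classic NeRF positional encoding.
import torch
from internal import coord

x = torch.rand(8, 3)
feat = coord.pos_enc(x, min_deg=0, max_deg=4, append_identity=True)
print(feat.shape)                  # torch.Size([8, 27])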
internal/datasets.py
ADDED
@@ -0,0 +1,1016 @@
1 |
+
import abc
|
2 |
+
import copy
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
from internal import camera_utils
|
7 |
+
from internal import configs
|
8 |
+
from internal import image as lib_image
|
9 |
+
from internal import raw_utils
|
10 |
+
from internal import utils
|
11 |
+
from collections import defaultdict
|
12 |
+
import numpy as np
|
13 |
+
import cv2
|
14 |
+
from PIL import Image
|
15 |
+
import torch
|
16 |
+
from tqdm import tqdm
|
17 |
+
# This is ugly, but it works.
|
18 |
+
import sys
|
19 |
+
|
20 |
+
sys.path.insert(0, 'internal/pycolmap')
|
21 |
+
sys.path.insert(0, 'internal/pycolmap/pycolmap')
|
22 |
+
import pycolmap
|
23 |
+
|
24 |
+
|
25 |
+
def load_dataset(split, train_dir, config: configs.Config):
|
26 |
+
"""Loads a split of a dataset using the data_loader specified by `config`."""
|
27 |
+
if config.multiscale:
|
28 |
+
dataset_dict = {
|
29 |
+
'llff': MultiLLFF,
|
30 |
+
}
|
31 |
+
else:
|
32 |
+
dataset_dict = {
|
33 |
+
'blender': Blender,
|
34 |
+
'llff': LLFF,
|
35 |
+
'tat_nerfpp': TanksAndTemplesNerfPP,
|
36 |
+
'tat_fvs': TanksAndTemplesFVS,
|
37 |
+
'dtu': DTU,
|
38 |
+
}
|
39 |
+
return dataset_dict[config.dataset_loader](split, train_dir, config)
|
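A usage sketch (assumption; the scene path and the Config fields used here are placeholders): picking a loader by name and instantiating the training split.

# Hypothetical example of constructing a dataset split from a Config.
from internal import configs, datasets

config = configs.Config(dataset_loader='llff', factor=4)   # fields from internal/configs.py
train_data = datasets.load_dataset('train', '/path/to/scene', config)
print(train_data.size, train_data.height, train_data.width)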
40 |
+
|
41 |
+
|
42 |
+
class NeRFSceneManager(pycolmap.SceneManager):
|
43 |
+
"""COLMAP pose loader.
|
44 |
+
|
45 |
+
Minor NeRF-specific extension to the third_party Python COLMAP loader:
|
46 |
+
google3/third_party/py/pycolmap/scene_manager.py
|
47 |
+
"""
|
48 |
+
|
49 |
+
def process(self):
|
50 |
+
"""Applies NeRF-specific postprocessing to the loaded pose data.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
a tuple [image_names, poses, pixtocam, distortion_params].
|
54 |
+
image_names: contains only the basename of the images.
|
55 |
+
poses: [N, 4, 4] array containing the camera to world matrices.
|
56 |
+
pixtocam: [N, 3, 3] array containing the camera to pixel space matrices.
|
57 |
+
distortion_params: mapping of distortion param name to distortion
|
58 |
+
parameters. Cameras share intrinsics. Valid keys are k1, k2, p1 and p2.
|
59 |
+
"""
|
60 |
+
|
61 |
+
self.load_cameras()
|
62 |
+
self.load_images()
|
63 |
+
# self.load_points3D() # For now, we do not need the point cloud data.
|
64 |
+
|
65 |
+
# Assume shared intrinsics between all cameras.
|
66 |
+
cam = self.cameras[1]
|
67 |
+
|
68 |
+
# Extract focal lengths and principal point parameters.
|
69 |
+
fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
|
70 |
+
pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))
|
71 |
+
|
72 |
+
# Extract extrinsic matrices in world-to-camera format.
|
73 |
+
imdata = self.images
|
74 |
+
w2c_mats = []
|
75 |
+
bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
|
76 |
+
for k in imdata:
|
77 |
+
im = imdata[k]
|
78 |
+
rot = im.R()
|
79 |
+
trans = im.tvec.reshape(3, 1)
|
80 |
+
w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
|
81 |
+
w2c_mats.append(w2c)
|
82 |
+
w2c_mats = np.stack(w2c_mats, axis=0)
|
83 |
+
|
84 |
+
# Convert extrinsics to camera-to-world.
|
85 |
+
c2w_mats = np.linalg.inv(w2c_mats)
|
86 |
+
poses = c2w_mats[:, :3, :4]
|
87 |
+
|
88 |
+
# Image names from COLMAP. No need for permuting the poses according to
|
89 |
+
# image names anymore.
|
90 |
+
names = [imdata[k].name for k in imdata]
|
91 |
+
|
92 |
+
# Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame.
|
93 |
+
poses = poses @ np.diag([1, -1, -1, 1])
|
94 |
+
|
95 |
+
# Get distortion parameters.
|
96 |
+
type_ = cam.camera_type
|
97 |
+
|
98 |
+
if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
|
99 |
+
params = None
|
100 |
+
camtype = camera_utils.ProjectionType.PERSPECTIVE
|
101 |
+
|
102 |
+
elif type_ == 1 or type_ == 'PINHOLE':
|
103 |
+
params = None
|
104 |
+
camtype = camera_utils.ProjectionType.PERSPECTIVE
|
105 |
+
|
106 |
+
if type_ == 2 or type_ == 'SIMPLE_RADIAL':
|
107 |
+
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
|
108 |
+
params['k1'] = cam.k1
|
109 |
+
camtype = camera_utils.ProjectionType.PERSPECTIVE
|
110 |
+
|
111 |
+
elif type_ == 3 or type_ == 'RADIAL':
|
112 |
+
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
|
113 |
+
params['k1'] = cam.k1
|
114 |
+
params['k2'] = cam.k2
|
115 |
+
camtype = camera_utils.ProjectionType.PERSPECTIVE
|
116 |
+
|
117 |
+
elif type_ == 4 or type_ == 'OPENCV':
|
118 |
+
params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']}
|
119 |
+
params['k1'] = cam.k1
|
120 |
+
params['k2'] = cam.k2
|
121 |
+
params['p1'] = cam.p1
|
122 |
+
params['p2'] = cam.p2
|
123 |
+
camtype = camera_utils.ProjectionType.PERSPECTIVE
|
124 |
+
|
125 |
+
elif type_ == 5 or type_ == 'OPENCV_FISHEYE':
|
126 |
+
params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']}
|
127 |
+
params['k1'] = cam.k1
|
128 |
+
params['k2'] = cam.k2
|
129 |
+
params['k3'] = cam.k3
|
130 |
+
params['k4'] = cam.k4
|
131 |
+
camtype = camera_utils.ProjectionType.FISHEYE
|
132 |
+
|
133 |
+
return names, poses, pixtocam, params, camtype
|
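A usage sketch (assumption; the COLMAP directory is a placeholder): loading poses and intrinsics from a sparse reconstruction with the scene-manager extension above.

# Hypothetical example of reading a COLMAP model into NeRF conventions.
from internal import datasets

manager = datasets.NeRFSceneManager('/path/to/scene/sparse/0/')
names, poses, pixtocam, distortion_params, camtype = manager.process()
print(len(names), poses.shape, pixtocam.shape)   # N, (N, 3, 4), (3, 3)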
134 |
+
|
135 |
+
|
136 |
+
def load_blender_posedata(data_dir, split=None):
|
137 |
+
"""Load poses from `transforms.json` file, as used in Blender/NGP datasets."""
|
138 |
+
suffix = '' if split is None else f'_{split}'
|
139 |
+
pose_file = os.path.join(data_dir, f'transforms{suffix}.json')
|
140 |
+
with utils.open_file(pose_file, 'r') as fp:
|
141 |
+
meta = json.load(fp)
|
142 |
+
names = []
|
143 |
+
poses = []
|
144 |
+
for _, frame in enumerate(meta['frames']):
|
145 |
+
filepath = os.path.join(data_dir, frame['file_path'])
|
146 |
+
if utils.file_exists(filepath):
|
147 |
+
names.append(frame['file_path'].split('/')[-1])
|
148 |
+
poses.append(np.array(frame['transform_matrix'], dtype=np.float32))
|
149 |
+
poses = np.stack(poses, axis=0)
|
150 |
+
|
151 |
+
w = meta['w']
|
152 |
+
h = meta['h']
|
153 |
+
cx = meta['cx'] if 'cx' in meta else w / 2.
|
154 |
+
cy = meta['cy'] if 'cy' in meta else h / 2.
|
155 |
+
if 'fl_x' in meta:
|
156 |
+
fx = meta['fl_x']
|
157 |
+
else:
|
158 |
+
fx = 0.5 * w / np.tan(0.5 * float(meta['camera_angle_x']))
|
159 |
+
if 'fl_y' in meta:
|
160 |
+
fy = meta['fl_y']
|
161 |
+
else:
|
162 |
+
fy = 0.5 * h / np.tan(0.5 * float(meta['camera_angle_y']))
|
163 |
+
pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy))
|
164 |
+
coeffs = ['k1', 'k2', 'p1', 'p2']
|
165 |
+
if not any([c in meta for c in coeffs]):
|
166 |
+
params = None
|
167 |
+
else:
|
168 |
+
params = {c: (meta[c] if c in meta else 0.) for c in coeffs}
|
169 |
+
camtype = camera_utils.ProjectionType.PERSPECTIVE
|
170 |
+
return names, poses, pixtocam, params, camtype
|
171 |
+
|
172 |
+
|
173 |
+
class Dataset(torch.utils.data.Dataset):
|
174 |
+
"""Dataset Base Class.
|
175 |
+
|
176 |
+
Base class for a NeRF dataset. Creates batches of ray and color data used for
|
177 |
+
training or rendering a NeRF model.
|
178 |
+
|
179 |
+
Each subclass is responsible for loading images and camera poses from disk by
|
180 |
+
implementing the _load_renderings() method. This data is used to generate
|
181 |
+
train and test batches of ray + color data for feeding through the NeRF model.
|
182 |
+
The ray parameters are calculated in _generate_rays().
|
183 |
+
|
184 |
+
The public interface mimics the behavior of a standard machine learning
|
185 |
+
pipeline dataset provider that can provide infinite batches of data to the
|
186 |
+
training/testing pipelines without exposing any details of how the batches are
|
187 |
+
loaded/created or how this is parallelized. Therefore, the initializer runs
|
188 |
+
all setup, including data loading from disk using _load_renderings(), and
|
189 |
+
begins the thread using its parent start() method. After the initializer
|
190 |
+
returns, the caller can request batches of data straight away.
|
191 |
+
|
192 |
+
The internal self._queue is initialized as queue.Queue(3), so the infinite
|
193 |
+
loop in run() will block on the call self._queue.put(self._next_fn()) once
|
194 |
+
there are 3 elements. The main thread training job runs in a loop that pops 1
|
195 |
+
element at a time off the front of the queue. The Dataset thread's run() loop
|
196 |
+
will populate the queue with 3 elements, then wait until a batch has been
|
197 |
+
removed and push one more onto the end.
|
198 |
+
|
199 |
+
This repeats indefinitely until the main thread's training loop completes
|
200 |
+
(typically hundreds of thousands of iterations), then the main thread will
|
201 |
+
exit and the Dataset thread will automatically be killed since it is a daemon.
|
202 |
+
|
203 |
+
Attributes:
|
204 |
+
alphas: np.ndarray, optional array of alpha channel data.
|
205 |
+
cameras: tuple summarizing all camera extrinsic/intrinsic/distortion params.
|
206 |
+
camtoworlds: np.ndarray, a list of extrinsic camera pose matrices.
|
207 |
+
camtype: camera_utils.ProjectionType, fisheye or perspective camera.
|
208 |
+
data_dir: str, location of the dataset on disk.
|
209 |
+
disp_images: np.ndarray, optional array of disparity (inverse depth) data.
|
210 |
+
distortion_params: dict, the camera distortion model parameters.
|
211 |
+
exposures: optional per-image exposure value (shutter * ISO / 1000).
|
212 |
+
far: float, far plane value for rays.
|
213 |
+
focal: float, focal length from camera intrinsics.
|
214 |
+
height: int, height of images.
|
215 |
+
images: np.ndarray, array of RGB image data.
|
216 |
+
metadata: dict, optional metadata for raw datasets.
|
217 |
+
near: float, near plane value for rays.
|
218 |
+
normal_images: np.ndarray, optional array of surface normal vector data.
|
219 |
+
pixtocams: np.ndarray, one or a list of inverse intrinsic camera matrices.
|
220 |
+
pixtocam_ndc: np.ndarray, the inverse intrinsic matrix used for NDC space.
|
221 |
+
poses: np.ndarray, optional array of auxiliary camera pose data.
|
222 |
+
rays: utils.Rays, ray data for every pixel in the dataset.
|
223 |
+
render_exposures: optional list of exposure values for the render path.
|
224 |
+
render_path: bool, indicates if a smooth camera path should be generated.
|
225 |
+
size: int, number of images in the dataset.
|
226 |
+
split: str, indicates if this is a "train" or "test" dataset.
|
227 |
+
width: int, width of images.
|
228 |
+
"""
|
229 |
+
|
230 |
+
def __init__(self,
|
231 |
+
split: str,
|
232 |
+
data_dir: str,
|
233 |
+
config: configs.Config):
|
234 |
+
super().__init__()
|
235 |
+
|
236 |
+
# Initialize attributes
|
237 |
+
self._patch_size = max(config.patch_size, 1)
|
238 |
+
self._batch_size = config.batch_size // config.world_size
|
239 |
+
if self._patch_size ** 2 > self._batch_size:
|
240 |
+
raise ValueError(f'Patch size {self._patch_size}^2 too large for ' +
|
241 |
+
f'per-process batch size {self._batch_size}')
|
242 |
+
self._batching = utils.BatchingMethod(config.batching)
|
243 |
+
self._use_tiffs = config.use_tiffs
|
244 |
+
self._load_disps = config.compute_disp_metrics
|
245 |
+
self._load_normals = config.compute_normal_metrics
|
246 |
+
self._num_border_pixels_to_mask = config.num_border_pixels_to_mask
|
247 |
+
self._apply_bayer_mask = config.apply_bayer_mask
|
248 |
+
self._render_spherical = False
|
249 |
+
|
250 |
+
self.config = config
|
251 |
+
self.global_rank = config.global_rank
|
252 |
+
self.world_size = config.world_size
|
253 |
+
self.split = utils.DataSplit(split)
|
254 |
+
self.data_dir = data_dir
|
255 |
+
self.near = config.near
|
256 |
+
self.far = config.far
|
257 |
+
self.render_path = config.render_path
|
258 |
+
self.distortion_params = None
|
259 |
+
self.disp_images = None
|
260 |
+
self.normal_images = None
|
261 |
+
self.alphas = None
|
262 |
+
self.poses = None
|
263 |
+
self.pixtocam_ndc = None
|
264 |
+
self.metadata = None
|
265 |
+
self.camtype = camera_utils.ProjectionType.PERSPECTIVE
|
266 |
+
self.exposures = None
|
267 |
+
self.render_exposures = None
|
268 |
+
|
269 |
+
# Providing type comments for these attributes, they must be correctly
|
270 |
+
# initialized by _load_renderings() (see docstring) in any subclass.
|
271 |
+
self.images: np.ndarray = None
|
272 |
+
self.camtoworlds: np.ndarray = None
|
273 |
+
self.pixtocams: np.ndarray = None
|
274 |
+
self.height: int = None
|
275 |
+
self.width: int = None
|
276 |
+
|
277 |
+
# Load data from disk using provided config parameters.
|
278 |
+
self._load_renderings(config)
|
279 |
+
|
280 |
+
if self.render_path:
|
281 |
+
if config.render_path_file is not None:
|
282 |
+
with utils.open_file(config.render_path_file, 'rb') as fp:
|
283 |
+
render_poses = np.load(fp)
|
284 |
+
self.camtoworlds = render_poses
|
285 |
+
if config.render_resolution is not None:
|
286 |
+
self.width, self.height = config.render_resolution
|
287 |
+
if config.render_focal is not None:
|
288 |
+
self.focal = config.render_focal
|
289 |
+
if config.render_camtype is not None:
|
290 |
+
if config.render_camtype == 'pano':
|
291 |
+
self._render_spherical = True
|
292 |
+
else:
|
293 |
+
self.camtype = camera_utils.ProjectionType(config.render_camtype)
|
294 |
+
|
295 |
+
self.distortion_params = None
|
296 |
+
self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
|
297 |
+
self.height)
|
298 |
+
|
299 |
+
self._n_examples = self.camtoworlds.shape[0]
|
300 |
+
|
301 |
+
self.cameras = (self.pixtocams,
|
302 |
+
self.camtoworlds,
|
303 |
+
self.distortion_params,
|
304 |
+
self.pixtocam_ndc)
|
305 |
+
|
306 |
+
# Seed the queue with one batch to avoid race condition.
|
307 |
+
if self.split == utils.DataSplit.TRAIN and not config.compute_visibility:
|
308 |
+
self._next_fn = self._next_train
|
309 |
+
else:
|
310 |
+
self._next_fn = self._next_test
|
311 |
+
|
312 |
+
@property
|
313 |
+
def size(self):
|
314 |
+
return self._n_examples
|
315 |
+
|
316 |
+
def __len__(self):
|
317 |
+
if self.split == utils.DataSplit.TRAIN and not self.config.compute_visibility:
|
318 |
+
return 1000
|
319 |
+
else:
|
320 |
+
return self._n_examples
|
321 |
+
|
322 |
+
@abc.abstractmethod
|
323 |
+
def _load_renderings(self, config):
|
324 |
+
"""Load images and poses from disk.
|
325 |
+
|
326 |
+
Args:
|
327 |
+
config: utils.Config, user-specified config parameters.
|
328 |
+
In inherited classes, this method must set the following public attributes:
|
329 |
+
images: [N, height, width, 3] array for RGB images.
|
330 |
+
disp_images: [N, height, width] array for depth data (optional).
|
331 |
+
normal_images: [N, height, width, 3] array for normals (optional).
|
332 |
+
camtoworlds: [N, 3, 4] array of extrinsic pose matrices.
|
333 |
+
poses: [..., 3, 4] array of auxiliary pose data (optional).
|
334 |
+
pixtocams: [N, 3, 3] array of inverse intrinsic matrices.
|
335 |
+
distortion_params: dict, camera lens distortion model parameters.
|
336 |
+
height: int, height of images.
|
337 |
+
width: int, width of images.
|
338 |
+
focal: float, focal length to use for ideal pinhole rendering.
|
339 |
+
"""
|
340 |
+
|
341 |
+
def _make_ray_batch(self,
|
342 |
+
pix_x_int,
|
343 |
+
pix_y_int,
|
344 |
+
cam_idx,
|
345 |
+
lossmult=None
|
346 |
+
):
|
347 |
+
"""Creates ray data batch from pixel coordinates and camera indices.
|
348 |
+
|
349 |
+
All arguments must have broadcastable shapes. If the arguments together
|
350 |
+
broadcast to a shape [a, b, c, ..., z] then the returned utils.Rays object
|
351 |
+
will have array attributes with shape [a, b, c, ..., z, N], where N=3 for
|
352 |
+
3D vectors and N=1 for per-ray scalar attributes.
|
353 |
+
|
354 |
+
Args:
|
355 |
+
pix_x_int: int array, x coordinates of image pixels.
|
356 |
+
pix_y_int: int array, y coordinates of image pixels.
|
357 |
+
cam_idx: int or int array, camera indices.
|
358 |
+
lossmult: float array, weight to apply to each ray when computing loss fn.
|
359 |
+
|
360 |
+
Returns:
|
361 |
+
A dict mapping from strings to utils.Rays or arrays of image data.
|
362 |
+
This is the batch provided for one NeRF train or test iteration.
|
363 |
+
"""
|
364 |
+
|
365 |
+
broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
|
366 |
+
ray_kwargs = {
|
367 |
+
'lossmult': broadcast_scalar(1.) if lossmult is None else lossmult,
|
368 |
+
'near': broadcast_scalar(self.near),
|
369 |
+
'far': broadcast_scalar(self.far),
|
370 |
+
'cam_idx': broadcast_scalar(cam_idx),
|
371 |
+
}
|
372 |
+
# Collect per-camera information needed for each ray.
|
373 |
+
if self.metadata is not None:
|
374 |
+
# Exposure index and relative shutter speed, needed for RawNeRF.
|
375 |
+
for key in ['exposure_idx', 'exposure_values']:
|
376 |
+
idx = 0 if self.render_path else cam_idx
|
377 |
+
ray_kwargs[key] = broadcast_scalar(self.metadata[key][idx])
|
378 |
+
if self.exposures is not None:
|
379 |
+
idx = 0 if self.render_path else cam_idx
|
380 |
+
ray_kwargs['exposure_values'] = broadcast_scalar(self.exposures[idx])
|
381 |
+
if self.render_path and self.render_exposures is not None:
|
382 |
+
ray_kwargs['exposure_values'] = broadcast_scalar(
|
383 |
+
self.render_exposures[cam_idx])
|
384 |
+
|
385 |
+
pixels = dict(pix_x_int=pix_x_int, pix_y_int=pix_y_int, **ray_kwargs)
|
386 |
+
|
387 |
+
# Slow path, do ray computation using numpy (on CPU).
|
388 |
+
batch = camera_utils.cast_ray_batch(self.cameras, pixels, self.camtype)
|
389 |
+
batch['cam_dirs'] = -self.camtoworlds[ray_kwargs['cam_idx'][..., 0]][..., :3, 2]
|
390 |
+
|
391 |
+
# import trimesh
|
392 |
+
# pts = batch['origins'][..., None, :] + batch['directions'][..., None, :] * np.linspace(0, 1, 5)[:, None]
|
393 |
+
# trimesh.Trimesh(vertices=pts.reshape(-1, 3)).export("test.ply", "ply")
|
394 |
+
#
|
395 |
+
# pts = batch['origins'][0, 0, None, :] - self.camtoworlds[cam_idx][:, 2] * np.linspace(0, 1, 100)[:, None]
|
396 |
+
# trimesh.Trimesh(vertices=pts.reshape(-1, 3)).export("test2.ply", "ply")
|
397 |
+
|
398 |
+
if not self.render_path:
|
399 |
+
batch['rgb'] = self.images[cam_idx, pix_y_int, pix_x_int]
|
400 |
+
if self._load_disps:
|
401 |
+
batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int]
|
402 |
+
if self._load_normals:
|
403 |
+
batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int]
|
404 |
+
batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int]
|
405 |
+
return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}
|
406 |
+
|
407 |
+
def _next_train(self, item):
|
408 |
+
"""Sample next training batch (random rays)."""
|
409 |
+
# We assume all images in the dataset are the same resolution, so we can use
|
410 |
+
# the same width/height for sampling all pixels coordinates in the batch.
|
411 |
+
# Batch/patch sampling parameters.
|
412 |
+
num_patches = self._batch_size // self._patch_size ** 2
|
413 |
+
lower_border = self._num_border_pixels_to_mask
|
414 |
+
upper_border = self._num_border_pixels_to_mask + self._patch_size - 1
|
415 |
+
# Random pixel patch x-coordinates.
|
416 |
+
pix_x_int = np.random.randint(lower_border, self.width - upper_border,
|
417 |
+
(num_patches, 1, 1))
|
418 |
+
# Random pixel patch y-coordinates.
|
419 |
+
pix_y_int = np.random.randint(lower_border, self.height - upper_border,
|
420 |
+
(num_patches, 1, 1))
|
421 |
+
# Add patch coordinate offsets.
|
422 |
+
# Shape will broadcast to (num_patches, _patch_size, _patch_size).
|
423 |
+
patch_dx_int, patch_dy_int = camera_utils.pixel_coordinates(
|
424 |
+
self._patch_size, self._patch_size)
|
425 |
+
pix_x_int = pix_x_int + patch_dx_int
|
426 |
+
pix_y_int = pix_y_int + patch_dy_int
|
427 |
+
# Random camera indices.
|
428 |
+
if self._batching == utils.BatchingMethod.ALL_IMAGES:
|
429 |
+
cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1))
|
430 |
+
else:
|
431 |
+
cam_idx = np.random.randint(0, self._n_examples, (1,))
|
432 |
+
|
433 |
+
if self._apply_bayer_mask:
|
434 |
+
# Compute the Bayer mosaic mask for each pixel in the batch.
|
435 |
+
lossmult = raw_utils.pixels_to_bayer_mask(pix_x_int, pix_y_int)
|
436 |
+
else:
|
437 |
+
lossmult = None
|
438 |
+
|
439 |
+
return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx,
|
440 |
+
lossmult=lossmult)
|
441 |
+
|
442 |
+
def generate_ray_batch(self, cam_idx: int):
|
443 |
+
"""Generate ray batch for a specified camera in the dataset."""
|
444 |
+
if self._render_spherical:
|
445 |
+
camtoworld = self.camtoworlds[cam_idx]
|
446 |
+
rays = camera_utils.cast_spherical_rays(
|
447 |
+
camtoworld, self.height, self.width, self.near, self.far)
|
448 |
+
return rays
|
449 |
+
else:
|
450 |
+
# Generate rays for all pixels in the image.
|
451 |
+
pix_x_int, pix_y_int = camera_utils.pixel_coordinates(
|
452 |
+
self.width, self.height)
|
453 |
+
return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx)
|
454 |
+
|
455 |
+
def _next_test(self, item):
|
456 |
+
"""Sample next test batch (one full image)."""
|
457 |
+
return self.generate_ray_batch(item)
|
458 |
+
|
459 |
+
def collate_fn(self, item):
|
460 |
+
return self._next_fn(item[0])
|
461 |
+
|
462 |
+
def __getitem__(self, item):
|
463 |
+
return self._next_fn(item)
|
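A wiring sketch (an assumption mirroring how collate_fn is typically used with this base class): since _next_train() already builds a full ray batch, the DataLoader iterates over indices with batch_size=1 and lets the custom collate do the work.

# Hypothetical DataLoader wiring around the Dataset base class.
import numpy as np
import torch
from internal import configs, datasets

config = configs.Config(dataset_loader='llff', factor=4)
dataset = datasets.load_dataset('train', '/path/to/scene', config)
loader = torch.utils.data.DataLoader(
    np.arange(len(dataset)),          # items are just indices
    batch_size=1,
    shuffle=True,
    num_workers=4,
    collate_fn=dataset.collate_fn,    # unwraps the index and calls _next_train
)
batch = next(iter(loader))            # dict of tensors: 'origins', 'directions', 'rgb', ...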
464 |
+
|
465 |
+
|
466 |
+
class Blender(Dataset):
|
467 |
+
"""Blender Dataset."""
|
468 |
+
|
469 |
+
def _load_renderings(self, config):
|
470 |
+
"""Load images from disk."""
|
471 |
+
if config.render_path:
|
472 |
+
raise ValueError('render_path cannot be used for the blender dataset.')
|
473 |
+
pose_file = os.path.join(self.data_dir, f'transforms_{self.split.value}.json')
|
474 |
+
with utils.open_file(pose_file, 'r') as fp:
|
475 |
+
meta = json.load(fp)
|
476 |
+
images = []
|
477 |
+
disp_images = []
|
478 |
+
normal_images = []
|
479 |
+
cams = []
|
480 |
+
for idx, frame in enumerate(tqdm(meta['frames'], desc='Loading Blender dataset', disable=self.global_rank != 0, leave=False)):
|
481 |
+
fprefix = os.path.join(self.data_dir, frame['file_path'])
|
482 |
+
|
483 |
+
def get_img(f, fprefix=fprefix):
|
484 |
+
image = utils.load_img(fprefix + f)
|
485 |
+
if config.factor > 1:
|
486 |
+
image = lib_image.downsample(image, config.factor)
|
487 |
+
return image
|
488 |
+
|
489 |
+
if self._use_tiffs:
|
490 |
+
channels = [get_img(f'_{ch}.tiff') for ch in ['R', 'G', 'B', 'A']]
|
491 |
+
# Convert image to sRGB color space.
|
492 |
+
image = lib_image.linear_to_srgb_np(np.stack(channels, axis=-1))
|
493 |
+
else:
|
494 |
+
image = get_img('.png') / 255.
|
495 |
+
images.append(image)
|
496 |
+
|
497 |
+
if self._load_disps:
|
498 |
+
disp_image = get_img('_disp.tiff')
|
499 |
+
disp_images.append(disp_image)
|
500 |
+
if self._load_normals:
|
501 |
+
normal_image = get_img('_normal.png')[..., :3] * 2. / 255. - 1.
|
502 |
+
normal_images.append(normal_image)
|
503 |
+
|
504 |
+
cams.append(np.array(frame['transform_matrix'], dtype=np.float32))
|
505 |
+
|
506 |
+
self.images = np.stack(images, axis=0)
|
507 |
+
if self._load_disps:
|
508 |
+
self.disp_images = np.stack(disp_images, axis=0)
|
509 |
+
if self._load_normals:
|
510 |
+
self.normal_images = np.stack(normal_images, axis=0)
|
511 |
+
self.alphas = self.images[..., -1]
|
512 |
+
|
513 |
+
rgb, alpha = self.images[..., :3], self.images[..., -1:]
|
514 |
+
self.images = rgb * alpha + (1. - alpha) # Use a white background.
|
515 |
+
self.height, self.width = self.images.shape[1:3]
|
516 |
+
self.camtoworlds = np.stack(cams, axis=0)
|
517 |
+
self.focal = .5 * self.width / np.tan(.5 * float(meta['camera_angle_x']))
|
518 |
+
self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
|
519 |
+
self.height)
|
520 |
+
|
521 |
+
|
522 |
+
class LLFF(Dataset):
|
523 |
+
"""LLFF Dataset."""
|
524 |
+
|
525 |
+
def _load_renderings(self, config):
|
526 |
+
"""Load images from disk."""
|
527 |
+
# Set up scaling factor.
|
528 |
+
image_dir_suffix = ''
|
529 |
+
# Use downsampling factor (unless loading training split for raw dataset,
|
530 |
+
# we train raw at full resolution because of the Bayer mosaic pattern).
|
531 |
+
if config.factor > 0 and not (config.rawnerf_mode and
|
532 |
+
self.split == utils.DataSplit.TRAIN):
|
533 |
+
image_dir_suffix = f'_{config.factor}'
|
534 |
+
factor = config.factor
|
535 |
+
else:
|
536 |
+
factor = 1
|
537 |
+
|
538 |
+
# Copy COLMAP data to local disk for faster loading.
|
539 |
+
colmap_dir = os.path.join(self.data_dir, 'sparse/0/')
|
540 |
+
|
541 |
+
# Load poses.
|
542 |
+
if utils.file_exists(colmap_dir):
|
543 |
+
pose_data = NeRFSceneManager(colmap_dir).process()
|
544 |
+
else:
|
545 |
+
# # Attempt to load Blender/NGP format if COLMAP data not present.
|
546 |
+
# pose_data = load_blender_posedata(self.data_dir)
|
547 |
+
raise ValueError('COLMAP data not found.')
|
548 |
+
image_names, poses, pixtocam, distortion_params, camtype = pose_data
|
549 |
+
|
550 |
+
# Previous NeRF results were generated with images sorted by filename,
|
551 |
+
# use this flag to ensure metrics are reported on the same test set.
|
552 |
+
inds = np.argsort(image_names)
|
553 |
+
image_names = [image_names[i] for i in inds]
|
554 |
+
poses = poses[inds]
|
555 |
+
|
556 |
+
# Load bounds if possible (only used in forward facing scenes).
|
557 |
+
posefile = os.path.join(self.data_dir, 'poses_bounds.npy')
|
558 |
+
if utils.file_exists(posefile):
|
559 |
+
with utils.open_file(posefile, 'rb') as fp:
|
560 |
+
poses_arr = np.load(fp)
|
561 |
+
bounds = poses_arr[:, -2:]
|
562 |
+
else:
|
563 |
+
bounds = np.array([0.01, 1.])
|
564 |
+
self.colmap_to_world_transform = np.eye(4)
|
565 |
+
|
566 |
+
# Scale the inverse intrinsics matrix by the image downsampling factor.
|
567 |
+
pixtocam = pixtocam @ np.diag([factor, factor, 1.])
|
568 |
+
self.pixtocams = pixtocam.astype(np.float32)
|
569 |
+
self.focal = 1. / self.pixtocams[0, 0]
|
570 |
+
self.distortion_params = distortion_params
|
571 |
+
self.camtype = camtype
|
572 |
+
|
573 |
+
# Separate out 360 versus forward facing scenes.
|
574 |
+
if config.forward_facing:
|
575 |
+
# Set the projective matrix defining the NDC transformation.
|
576 |
+
self.pixtocam_ndc = self.pixtocams.reshape(-1, 3, 3)[0]
|
577 |
+
# Rescale according to a default bd factor.
|
578 |
+
scale = 1. / (bounds.min() * .75)
|
579 |
+
poses[:, :3, 3] *= scale
|
580 |
+
self.colmap_to_world_transform = np.diag([scale] * 3 + [1])
|
581 |
+
bounds *= scale
|
582 |
+
# Recenter poses.
|
583 |
+
poses, transform = camera_utils.recenter_poses(poses)
|
584 |
+
self.colmap_to_world_transform = (
|
585 |
+
transform @ self.colmap_to_world_transform)
|
586 |
+
# Forward-facing spiral render path.
|
587 |
+
self.render_poses = camera_utils.generate_spiral_path(
|
588 |
+
poses, bounds, n_frames=config.render_path_frames)
|
589 |
+
else:
|
590 |
+
# Rotate/scale poses to align ground with xy plane and fit to unit cube.
|
591 |
+
poses, transform = camera_utils.transform_poses_pca(poses)
|
592 |
+
self.colmap_to_world_transform = transform
|
593 |
+
if config.render_spline_keyframes is not None:
|
594 |
+
rets = camera_utils.create_render_spline_path(config, image_names,
|
595 |
+
poses, self.exposures)
|
596 |
+
self.spline_indices, self.render_poses, self.render_exposures = rets
|
597 |
+
else:
|
598 |
+
# Automatically generated inward-facing elliptical render path.
|
599 |
+
self.render_poses = camera_utils.generate_ellipse_path(
|
600 |
+
poses,
|
601 |
+
n_frames=config.render_path_frames,
|
602 |
+
z_variation=config.z_variation,
|
603 |
+
z_phase=config.z_phase)
|
604 |
+
|
605 |
+
# Select the split.
|
606 |
+
all_indices = np.arange(len(image_names))
|
607 |
+
if config.llff_use_all_images_for_training:
|
608 |
+
train_indices = all_indices
|
609 |
+
else:
|
610 |
+
train_indices = all_indices % config.llffhold != 0
|
611 |
+
if config.llff_use_all_images_for_testing:
|
612 |
+
test_indices = all_indices
|
613 |
+
else:
|
614 |
+
test_indices = all_indices % config.llffhold == 0
|
615 |
+
split_indices = {
|
616 |
+
utils.DataSplit.TEST: all_indices[test_indices],
|
617 |
+
utils.DataSplit.TRAIN: all_indices[train_indices],
|
618 |
+
}
|
619 |
+
indices = split_indices[self.split]
|
620 |
+
image_names = [image_names[i] for i in indices]
|
621 |
+
poses = poses[indices]
|
622 |
+
# if self.split == utils.DataSplit.TRAIN:
|
623 |
+
# # load different training data on different rank
|
624 |
+
# local_indices = [i for i in range(len(image_names)) if (i + self.global_rank) % self.world_size == 0]
|
625 |
+
# image_names = [image_names[i] for i in local_indices]
|
626 |
+
# poses = poses[local_indices]
|
627 |
+
# indices = local_indices
|
628 |
+
|
629 |
+
raw_testscene = False
|
630 |
+
if config.rawnerf_mode:
|
631 |
+
# Load raw images and metadata.
|
632 |
+
images, metadata, raw_testscene = raw_utils.load_raw_dataset(
|
633 |
+
self.split,
|
634 |
+
self.data_dir,
|
635 |
+
image_names,
|
636 |
+
config.exposure_percentile,
|
637 |
+
factor)
|
638 |
+
self.metadata = metadata
|
639 |
+
|
640 |
+
else:
|
641 |
+
# Load images.
|
642 |
+
colmap_image_dir = os.path.join(self.data_dir, 'images')
|
643 |
+
image_dir = os.path.join(self.data_dir, 'images' + image_dir_suffix)
|
644 |
+
for d in [image_dir, colmap_image_dir]:
|
645 |
+
if not utils.file_exists(d):
|
646 |
+
raise ValueError(f'Image folder {d} does not exist.')
|
647 |
+
# Downsampled images may have different names vs images used for COLMAP,
|
648 |
+
# so we need to map between the two sorted lists of files.
|
649 |
+
colmap_files = sorted(utils.listdir(colmap_image_dir))
|
650 |
+
image_files = sorted(utils.listdir(image_dir))
|
651 |
+
colmap_to_image = dict(zip(colmap_files, image_files))
|
652 |
+
image_paths = [os.path.join(image_dir, colmap_to_image[f])
|
653 |
+
for f in image_names]
|
654 |
+
images = [utils.load_img(x) for x in tqdm(image_paths, desc='Loading LLFF dataset', disable=self.global_rank != 0, leave=False)]
|
655 |
+
images = np.stack(images, axis=0) / 255.
|
656 |
+
|
657 |
+
# EXIF data is usually only present in the original JPEG images.
|
658 |
+
jpeg_paths = [os.path.join(colmap_image_dir, f) for f in image_names]
|
659 |
+
exifs = [utils.load_exif(x) for x in jpeg_paths]
|
660 |
+
self.exifs = exifs
|
661 |
+
if 'ExposureTime' in exifs[0] and 'ISOSpeedRatings' in exifs[0]:
|
662 |
+
gather_exif_value = lambda k: np.array([float(x[k]) for x in exifs])
|
663 |
+
shutters = gather_exif_value('ExposureTime')
|
664 |
+
isos = gather_exif_value('ISOSpeedRatings')
|
665 |
+
self.exposures = shutters * isos / 1000.
|
666 |
+
|
667 |
+
if raw_testscene:
|
668 |
+
# For raw testscene, the first image sent to COLMAP has the same pose as
|
669 |
+
# the ground truth test image. The remaining images form the training set.
|
670 |
+
raw_testscene_poses = {
|
671 |
+
utils.DataSplit.TEST: poses[:1],
|
672 |
+
utils.DataSplit.TRAIN: poses[1:],
|
673 |
+
}
|
674 |
+
poses = raw_testscene_poses[self.split]
|
675 |
+
|
676 |
+
self.poses = poses
|
677 |
+
self.images = images
|
678 |
+
self.camtoworlds = self.render_poses if config.render_path else poses
|
679 |
+
self.height, self.width = images.shape[1:3]
|
680 |
+
|
681 |
+
|
682 |
+
class TanksAndTemplesNerfPP(Dataset):
|
683 |
+
"""Subset of Tanks and Temples Dataset as processed by NeRF++."""
|
684 |
+
|
685 |
+
def _load_renderings(self, config):
|
686 |
+
"""Load images from disk."""
|
687 |
+
if config.render_path:
|
688 |
+
split_str = 'camera_path'
|
689 |
+
else:
|
690 |
+
split_str = self.split.value
|
691 |
+
|
692 |
+
basedir = os.path.join(self.data_dir, split_str)
|
693 |
+
|
694 |
+
# TODO: need to rewrite this to put different data on different rank
|
695 |
+
def load_files(dirname, load_fn, shape=None):
|
696 |
+
files = [
|
697 |
+
os.path.join(basedir, dirname, f)
|
698 |
+
for f in sorted(utils.listdir(os.path.join(basedir, dirname)))
|
699 |
+
]
|
700 |
+
mats = np.array([load_fn(utils.open_file(f, 'rb')) for f in files])
|
701 |
+
if shape is not None:
|
702 |
+
mats = mats.reshape(mats.shape[:1] + shape)
|
703 |
+
return mats
|
704 |
+
|
705 |
+
poses = load_files('pose', np.loadtxt, (4, 4))
|
706 |
+
# Flip Y and Z axes to get correct coordinate frame.
|
707 |
+
poses = np.matmul(poses, np.diag(np.array([1, -1, -1, 1])))
|
708 |
+
|
709 |
+
# For now, ignore all but the first focal length in intrinsics
|
710 |
+
intrinsics = load_files('intrinsics', np.loadtxt, (4, 4))
|
711 |
+
|
712 |
+
if not config.render_path:
|
713 |
+
images = load_files('rgb', lambda f: np.array(Image.open(f))) / 255.
|
714 |
+
self.images = images
|
715 |
+
self.height, self.width = self.images.shape[1:3]
|
716 |
+
|
717 |
+
else:
|
718 |
+
# Hack to grab the image resolution from a test image
|
719 |
+
d = os.path.join(self.data_dir, 'test', 'rgb')
|
720 |
+
f = os.path.join(d, sorted(utils.listdir(d))[0])
|
721 |
+
shape = utils.load_img(f).shape
|
722 |
+
self.height, self.width = shape[:2]
|
723 |
+
self.images = None
|
724 |
+
|
725 |
+
self.camtoworlds = poses
|
726 |
+
self.focal = intrinsics[0, 0, 0]
|
727 |
+
self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
|
728 |
+
self.height)
|
729 |
+
|
730 |
+
|
731 |
+
class TanksAndTemplesFVS(Dataset):
|
732 |
+
"""Subset of Tanks and Temples Dataset as processed by Free View Synthesis."""
|
733 |
+
|
734 |
+
def _load_renderings(self, config):
|
735 |
+
"""Load images from disk."""
|
736 |
+
render_only = config.render_path and self.split == utils.DataSplit.TEST
|
737 |
+
|
738 |
+
basedir = os.path.join(self.data_dir, 'dense')
|
739 |
+
sizes = [f for f in sorted(utils.listdir(basedir)) if f.startswith('ibr3d')]
|
740 |
+
sizes = sizes[::-1]
|
741 |
+
|
742 |
+
if config.factor >= len(sizes):
|
743 |
+
raise ValueError(f'Factor {config.factor} larger than {len(sizes)}')
|
744 |
+
|
745 |
+
basedir = os.path.join(basedir, sizes[config.factor])
|
746 |
+
open_fn = lambda f: utils.open_file(os.path.join(basedir, f), 'rb')
|
747 |
+
|
748 |
+
files = [f for f in sorted(utils.listdir(basedir)) if f.startswith('im_')]
|
749 |
+
if render_only:
|
750 |
+
files = files[:1]
|
751 |
+
images = np.array([np.array(Image.open(open_fn(f))) for f in files]) / 255.
|
752 |
+
|
753 |
+
names = ['Ks', 'Rs', 'ts']
|
754 |
+
intrinsics, rot, trans = (np.load(open_fn(f'{n}.npy')) for n in names)
|
755 |
+
|
756 |
+
# Convert poses from colmap world-to-cam into our cam-to-world.
|
757 |
+
w2c = np.concatenate([rot, trans[..., None]], axis=-1)
|
758 |
+
c2w_colmap = np.linalg.inv(camera_utils.pad_poses(w2c))[:, :3, :4]
|
759 |
+
c2w = c2w_colmap @ np.diag(np.array([1, -1, -1, 1]))
|
760 |
+
|
761 |
+
# Reorient poses so z-axis is up
|
762 |
+
poses, _ = camera_utils.transform_poses_pca(c2w)
|
763 |
+
self.poses = poses
|
764 |
+
|
765 |
+
self.images = images
|
766 |
+
self.height, self.width = self.images.shape[1:3]
|
767 |
+
self.camtoworlds = poses
|
768 |
+
# For now, ignore all but the first focal length in intrinsics
|
769 |
+
self.focal = intrinsics[0, 0, 0]
|
770 |
+
self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width,
|
771 |
+
self.height)
|
772 |
+
|
773 |
+
if render_only:
|
774 |
+
render_path = camera_utils.generate_ellipse_path(
|
775 |
+
poses,
|
776 |
+
config.render_path_frames,
|
777 |
+
z_variation=config.z_variation,
|
778 |
+
z_phase=config.z_phase)
|
779 |
+
self.images = None
|
780 |
+
self.camtoworlds = render_path
|
781 |
+
self.render_poses = render_path
|
782 |
+
else:
|
783 |
+
# Select the split.
|
784 |
+
all_indices = np.arange(images.shape[0])
|
785 |
+
indices = {
|
786 |
+
utils.DataSplit.TEST:
|
787 |
+
all_indices[all_indices % config.llffhold == 0],
|
788 |
+
utils.DataSplit.TRAIN:
|
789 |
+
all_indices[all_indices % config.llffhold != 0],
|
790 |
+
}[self.split]
|
791 |
+
|
792 |
+
self.images = self.images[indices]
|
793 |
+
self.camtoworlds = self.camtoworlds[indices]
|
794 |
+
|
795 |
+
|
796 |
+
class DTU(Dataset):
|
797 |
+
"""DTU Dataset."""
|
798 |
+
|
799 |
+
def _load_renderings(self, config):
|
800 |
+
"""Load images from disk."""
|
801 |
+
if config.render_path:
|
802 |
+
raise ValueError('render_path cannot be used for the DTU dataset.')
|
803 |
+
|
804 |
+
images = []
|
805 |
+
pixtocams = []
|
806 |
+
camtoworlds = []
|
807 |
+
|
808 |
+
# Find out whether the particular scan has 49 or 65 images.
|
809 |
+
n_images = len(utils.listdir(self.data_dir)) // 8
|
810 |
+
|
811 |
+
# Loop over all images.
|
812 |
+
for i in range(1, n_images + 1):
|
813 |
+
# Set light condition string accordingly.
|
814 |
+
if config.dtu_light_cond < 7:
|
815 |
+
light_str = f'{config.dtu_light_cond}_r' + ('5000'
|
816 |
+
if i < 50 else '7000')
|
817 |
+
else:
|
818 |
+
light_str = 'max'
|
819 |
+
|
820 |
+
# Load image.
|
821 |
+
fname = os.path.join(self.data_dir, f'rect_{i:03d}_{light_str}.png')
|
822 |
+
image = utils.load_img(fname) / 255.
|
823 |
+
if config.factor > 1:
|
824 |
+
image = lib_image.downsample(image, config.factor)
|
825 |
+
images.append(image)
|
826 |
+
|
827 |
+
# Load projection matrix from file.
|
828 |
+
fname = os.path.join(self.data_dir, f'../../cal18/pos_{i:03d}.txt')
|
829 |
+
with utils.open_file(fname, 'rb') as f:
|
830 |
+
projection = np.loadtxt(f, dtype=np.float32)
|
831 |
+
|
832 |
+
# Decompose projection matrix into pose and camera matrix.
|
833 |
+
camera_mat, rot_mat, t = cv2.decomposeProjectionMatrix(projection)[:3]
|
834 |
+
camera_mat = camera_mat / camera_mat[2, 2]
|
835 |
+
pose = np.eye(4, dtype=np.float32)
|
836 |
+
pose[:3, :3] = rot_mat.transpose()
|
837 |
+
pose[:3, 3] = (t[:3] / t[3])[:, 0]
|
838 |
+
pose = pose[:3]
|
839 |
+
camtoworlds.append(pose)
|
840 |
+
|
841 |
+
if config.factor > 0:
|
842 |
+
# Scale camera matrix according to downsampling factor.
|
843 |
+
camera_mat = np.diag([1. / config.factor, 1. / config.factor, 1.
|
844 |
+
            ]).astype(np.float32) @ camera_mat
            pixtocams.append(np.linalg.inv(camera_mat))

        pixtocams = np.stack(pixtocams)
        camtoworlds = np.stack(camtoworlds)
        images = np.stack(images)

        def rescale_poses(poses):
            """Rescales camera poses according to maximum x/y/z value."""
            s = np.max(np.abs(poses[:, :3, -1]))
            out = np.copy(poses)
            out[:, :3, -1] /= s
            return out

        # Center and scale poses.
        camtoworlds, _ = camera_utils.recenter_poses(camtoworlds)
        camtoworlds = rescale_poses(camtoworlds)
        # Flip y and z axes to get poses in OpenGL coordinate system.
        camtoworlds = camtoworlds @ np.diag([1., -1., -1., 1.]).astype(np.float32)

        all_indices = np.arange(images.shape[0])
        split_indices = {
            utils.DataSplit.TEST: all_indices[all_indices % config.dtuhold == 0],
            utils.DataSplit.TRAIN: all_indices[all_indices % config.dtuhold != 0],
        }
        indices = split_indices[self.split]

        self.images = images[indices]
        self.height, self.width = images.shape[1:3]
        self.camtoworlds = camtoworlds[indices]
        self.pixtocams = pixtocams[indices]


class Multicam(Dataset):
    def __init__(self,
                 split: str,
                 data_dir: str,
                 config: configs.Config):
        super().__init__(split, data_dir, config)

        self.multiscale_levels = config.multiscale_levels

        images, camtoworlds, pixtocams, pixtocam_ndc = \
            self.images, self.camtoworlds, self.pixtocams, self.pixtocam_ndc
        self.heights, self.widths, self.focals, self.images, self.camtoworlds, self.pixtocams, self.lossmults = [], [], [], [], [], [], []
        if pixtocam_ndc is not None:
            self.pixtocam_ndc = []
        else:
            self.pixtocam_ndc = None

        for i in range(self._n_examples):
            for j in range(self.multiscale_levels):
                self.heights.append(self.height // 2 ** j)
                self.widths.append(self.width // 2 ** j)

                self.pixtocams.append(pixtocams @ np.diag([self.height / self.heights[-1],
                                                           self.width / self.widths[-1],
                                                           1.]))
                self.focals.append(1. / self.pixtocams[-1][0, 0])
                if config.forward_facing:
                    # Set the projective matrix defining the NDC transformation.
                    self.pixtocam_ndc.append(pixtocams.reshape(3, 3))

                self.camtoworlds.append(camtoworlds[i])
                self.lossmults.append(2. ** j)
                self.images.append(self.down2(images[i], (self.heights[-1], self.widths[-1])))
        self.pixtocams = np.stack(self.pixtocams)
        self.camtoworlds = np.stack(self.camtoworlds)
        self.cameras = (self.pixtocams,
                        self.camtoworlds,
                        self.distortion_params,
                        np.stack(self.pixtocam_ndc) if self.pixtocam_ndc is not None else None)
        self._generate_rays()

        if self.split == utils.DataSplit.TRAIN:
            # Always flatten out the height x width dimensions
            def flatten(x):
                if x[0] is not None:
                    x = [y.reshape([-1, y.shape[-1]]) for y in x]
                    if self._batching == utils.BatchingMethod.ALL_IMAGES:
                        # If global batching, also concatenate all data into one list
                        x = np.concatenate(x, axis=0)
                    return x
                else:
                    return None

            self.batches = {k: flatten(v) for k, v in self.batches.items()}
        self._n_examples = len(self.camtoworlds)

        # Seed the queue with one batch to avoid race condition.
        if self.split == utils.DataSplit.TRAIN:
            self._next_fn = self._next_train
        else:
            self._next_fn = self._next_test

    def _generate_rays(self):
        if self.global_rank == 0:
            tbar = tqdm(range(len(self.camtoworlds)), desc='Generating rays', leave=False)
        else:
            tbar = range(len(self.camtoworlds))

        self.batches = defaultdict(list)
        for cam_idx in tbar:
            pix_x_int, pix_y_int = camera_utils.pixel_coordinates(
                self.widths[cam_idx], self.heights[cam_idx])
            broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None]
            ray_kwargs = {
                'lossmult': broadcast_scalar(self.lossmults[cam_idx]),
                'near': broadcast_scalar(self.near),
                'far': broadcast_scalar(self.far),
                'cam_idx': broadcast_scalar(cam_idx),
            }

            pixels = dict(pix_x_int=pix_x_int, pix_y_int=pix_y_int, **ray_kwargs)

            batch = camera_utils.cast_ray_batch(self.cameras, pixels, self.camtype)
            if not self.render_path:
                batch['rgb'] = self.images[cam_idx]
                if self._load_disps:
                    batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int]
                if self._load_normals:
                    batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int]
                    batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int]
            for k, v in batch.items():
                self.batches[k].append(v)

    def _next_train(self, item):
        """Sample next training batch (random rays)."""
        # We assume all images in the dataset are the same resolution, so we can use
        # the same width/height for sampling all pixel coordinates in the batch.
        # Batch/patch sampling parameters.
        num_patches = self._batch_size // self._patch_size ** 2
        # Random camera indices.
        if self._batching == utils.BatchingMethod.ALL_IMAGES:
            ray_indices = np.random.randint(0, self.batches['origins'].shape[0], (num_patches, 1, 1))
            batch = {k: v[ray_indices] if v is not None else None for k, v in self.batches.items()}
        else:
            image_index = np.random.randint(0, self._n_examples, ())
            ray_indices = np.random.randint(0, self.batches['origins'][image_index].shape[0], (num_patches,))
            batch = {k: v[image_index][ray_indices] if v is not None else None for k, v in self.batches.items()}
        batch['cam_dirs'] = -self.camtoworlds[batch['cam_idx'][..., 0]][..., 2]
        return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}

    def _next_test(self, item):
        """Sample next test batch (one full image)."""
        batch = {k: v[item] for k, v in self.batches.items()}
        batch['cam_dirs'] = -self.camtoworlds[batch['cam_idx'][..., 0]][..., 2]
        return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()}

    @staticmethod
    def down2(img, sh):
        return cv2.resize(img, sh[::-1], interpolation=cv2.INTER_CUBIC)


class MultiLLFF(Multicam, LLFF):
    pass


if __name__ == '__main__':
    from internal import configs
    import accelerate

    config = configs.Config()
    accelerator = accelerate.Accelerator()
    config.world_size = accelerator.num_processes
    config.global_rank = accelerator.process_index
    config.factor = 8
    dataset = LLFF('test', '/SSD_DISK/datasets/360_v2/bicycle', config)
    print(len(dataset))
    for _ in tqdm(dataset):
        pass
    print('done')
    # print(accelerator.process_index)
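The per-level loop in Multicam above rescales the shared inverse intrinsics so that each downsampled image keeps a consistent projection. Below is a minimal sketch of that arithmetic, assuming a simple pinhole pixtocam matrix; the image dimensions and focal length are illustrative values, not taken from the repository.

import numpy as np

height, width, focal = 400, 600, 500.0
pixtocam = np.linalg.inv(np.array([[focal, 0., width / 2.],
                                   [0., focal, height / 2.],
                                   [0., 0., 1.]]))
for j in range(3):  # three multiscale levels, as with multiscale_levels = 3
    h, w = height // 2 ** j, width // 2 ** j
    # Scaling the inverse intrinsics by the resolution ratio halves the
    # effective focal length at each level, matching the downsampled image.
    pixtocam_j = pixtocam @ np.diag([height / h, width / w, 1.])
    print(j, 1. / pixtocam_j[0, 0])  # effective focal length per level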
internal/geopoly.py
ADDED
@@ -0,0 +1,108 @@
import itertools
import numpy as np


def compute_sq_dist(mat0, mat1=None):
    """Compute the squared Euclidean distance between all pairs of columns."""
    if mat1 is None:
        mat1 = mat0
    # Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y.
    sq_norm0 = np.sum(mat0 ** 2, 0)
    sq_norm1 = np.sum(mat1 ** 2, 0)
    sq_dist = sq_norm0[:, None] + sq_norm1[None, :] - 2 * mat0.T @ mat1
    sq_dist = np.maximum(0, sq_dist)  # Negative values must be numerical errors.
    return sq_dist


def compute_tesselation_weights(v):
    """Tesselate the vertices of a triangle by a factor of `v`."""
    if v < 1:
        raise ValueError(f'v {v} must be >= 1')
    int_weights = []
    for i in range(v + 1):
        for j in range(v + 1 - i):
            int_weights.append((i, j, v - (i + j)))
    int_weights = np.array(int_weights)
    weights = int_weights / v  # Barycentric weights.
    return weights


def tesselate_geodesic(base_verts, base_faces, v, eps=1e-4):
    """Tesselate the vertices of a geodesic polyhedron.

    Args:
      base_verts: tensor of floats, the vertex coordinates of the geodesic.
      base_faces: tensor of ints, the indices of the vertices of base_verts that
        constitute each face of the polyhedron.
      v: int, the factor of the tesselation (v==1 is a no-op).
      eps: float, a small value used to determine if two vertices are the same.

    Returns:
      verts: a tensor of floats, the coordinates of the tesselated vertices.
    """
    if not isinstance(v, int):
        raise ValueError(f'v {v} must be an integer')
    tri_weights = compute_tesselation_weights(v)

    verts = []
    for base_face in base_faces:
        new_verts = np.matmul(tri_weights, base_verts[base_face, :])
        new_verts /= np.sqrt(np.sum(new_verts ** 2, 1, keepdims=True))
        verts.append(new_verts)
    verts = np.concatenate(verts, 0)

    sq_dist = compute_sq_dist(verts.T)
    assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist])
    unique = np.unique(assignment)
    verts = verts[unique, :]

    return verts


def generate_basis(base_shape,
                   angular_tesselation,
                   remove_symmetries=True,
                   eps=1e-4):
    """Generates a 3D basis by tesselating a geometric polyhedron.

    Args:
      base_shape: string, the name of the starting polyhedron, must be either
        'icosahedron' or 'octahedron'.
      angular_tesselation: int, the number of times to tesselate the polyhedron,
        must be >= 1 (a value of 1 is a no-op to the polyhedron).
      remove_symmetries: bool, if True then remove the symmetric basis columns,
        which is usually a good idea because otherwise projections onto the basis
        will have redundant negative copies of each other.
      eps: float, a small number used to determine symmetries.

    Returns:
      basis: a matrix with shape [3, n].
    """
    if base_shape == 'icosahedron':
        a = (np.sqrt(5) + 1) / 2
        verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1),
                          (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0),
                          (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2)
        faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1),
                          (8, 10, 1), (8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3),
                          (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6),
                          (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5),
                          (7, 2, 11)])
        verts = tesselate_geodesic(verts, faces, angular_tesselation)
    elif base_shape == 'octahedron':
        verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0),
                          (1, 0, 0)])
        corners = np.array(list(itertools.product([-1, 1], repeat=3)))
        pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)
        faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)
        verts = tesselate_geodesic(verts, faces, angular_tesselation)
    else:
        raise ValueError(f'base_shape {base_shape} not supported')

    if remove_symmetries:
        # Remove elements of `verts` that are reflections of each other.
        match = compute_sq_dist(verts.T, -verts.T) < eps
        verts = verts[np.any(np.triu(match), 1), :]

    basis = verts[:, ::-1]
    return basis
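generate_basis() above is the entry point used to build a direction basis from a tesselated polyhedron. An illustrative usage sketch follows, assuming the module is importable as internal.geopoly; the printed shape is only an expectation for this particular setting, not a guarantee.

import numpy as np
from internal import geopoly

basis = geopoly.generate_basis('icosahedron', 2)
# Each row is a unit-norm direction; negated (symmetric) copies are removed.
print(basis.shape)
print(np.allclose(np.linalg.norm(basis, axis=-1), 1.0))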
internal/image.py
ADDED
@@ -0,0 +1,126 @@
import torch
import numpy as np
from internal import math
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
import cv2


def mse_to_psnr(mse):
    """Compute PSNR given an MSE (we assume the maximum pixel value is 1)."""
    return -10. / np.log(10.) * np.log(mse)


def psnr_to_mse(psnr):
    """Compute MSE given a PSNR (we assume the maximum pixel value is 1)."""
    return np.exp(-0.1 * np.log(10.) * psnr)


def ssim_to_dssim(ssim):
    """Compute DSSIM given an SSIM."""
    return (1 - ssim) / 2


def dssim_to_ssim(dssim):
    """Compute SSIM given a DSSIM."""
    return 1 - 2 * dssim


def linear_to_srgb(linear, eps=None):
    """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
    if eps is None:
        eps = torch.finfo(linear.dtype).eps
        # eps = 1e-3

    srgb0 = 323 / 25 * linear
    srgb1 = (211 * linear.clamp_min(eps) ** (5 / 12) - 11) / 200
    return torch.where(linear <= 0.0031308, srgb0, srgb1)


def linear_to_srgb_np(linear, eps=None):
    """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
    if eps is None:
        eps = np.finfo(linear.dtype).eps
    srgb0 = 323 / 25 * linear
    srgb1 = (211 * np.maximum(eps, linear) ** (5 / 12) - 11) / 200
    return np.where(linear <= 0.0031308, srgb0, srgb1)


def srgb_to_linear(srgb, eps=None):
    """Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
    if eps is None:
        eps = np.finfo(srgb.dtype).eps
    linear0 = 25 / 323 * srgb
    linear1 = np.maximum(eps, ((200 * srgb + 11) / (211))) ** (12 / 5)
    return np.where(srgb <= 0.04045, linear0, linear1)


def downsample(img, factor):
    """Area downsample img (factor must evenly divide img height and width)."""
    sh = img.shape
    if not (sh[0] % factor == 0 and sh[1] % factor == 0):
        raise ValueError(f'Downsampling factor {factor} does not '
                         f'evenly divide image shape {sh[:2]}')
    img = img.reshape((sh[0] // factor, factor, sh[1] // factor, factor) + sh[2:])
    img = img.mean((1, 3))
    return img


def color_correct(img, ref, num_iters=5, eps=0.5 / 255):
    """Warp `img` to match the colors in `ref_img`."""
    if img.shape[-1] != ref.shape[-1]:
        raise ValueError(
            f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match'
        )
    num_channels = img.shape[-1]
    img_mat = img.reshape([-1, num_channels])
    ref_mat = ref.reshape([-1, num_channels])
    is_unclipped = lambda z: (z >= eps) & (z <= (1 - eps))  # z \in [eps, 1-eps].
    mask0 = is_unclipped(img_mat)
    # Because the set of saturated pixels may change after solving for a
    # transformation, we repeatedly solve a system `num_iters` times and update
    # our estimate of which pixels are saturated.
    for _ in range(num_iters):
        # Construct the left hand side of a linear system that contains a quadratic
        # expansion of each pixel of `img`.
        a_mat = []
        for c in range(num_channels):
            a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:])  # Quadratic term.
        a_mat.append(img_mat)  # Linear term.
        a_mat.append(torch.ones_like(img_mat[:, :1]))  # Bias term.
        a_mat = torch.cat(a_mat, dim=-1)
        warp = []
        for c in range(num_channels):
            # Construct the right hand side of a linear system containing each color
            # of `ref`.
            b = ref_mat[:, c]
            # Ignore rows of the linear system that were saturated in the input or are
            # saturated in the current corrected color estimate.
            mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
            ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
            mb = torch.where(mask, b, torch.zeros_like(b))
            w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
            assert torch.all(torch.isfinite(w))
            warp.append(w)
        warp = torch.stack(warp, dim=-1)
        # Apply the warp to update img_mat.
        img_mat = torch.clip(math.matmul(a_mat, warp), 0, 1)
    corrected_img = torch.reshape(img_mat, img.shape)
    return corrected_img


class MetricHarness:
    """A helper class for evaluating several error metrics."""

    def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s):
        """Evaluate the error between a predicted rgb image and the true image."""
        rgb_pred = (rgb_pred * 255).astype(np.uint8)
        rgb_gt = (rgb_gt * 255).astype(np.uint8)
        rgb_pred_gray = cv2.cvtColor(rgb_pred, cv2.COLOR_RGB2GRAY)
        rgb_gt_gray = cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2GRAY)
        psnr = float(peak_signal_noise_ratio(rgb_pred, rgb_gt, data_range=255))
        ssim = float(structural_similarity(rgb_pred_gray, rgb_gt_gray, data_range=255))

        return {
            name_fn('psnr'): psnr,
            name_fn('ssim'): ssim,
        }
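mse_to_psnr() and psnr_to_mse() above are exact inverses under the maximum-pixel-value-of-1 convention. A small sanity-check sketch, assuming the module is importable as internal.image; the MSE value is made up for illustration.

import numpy as np
from internal import image

mse = 1e-3
psnr = image.mse_to_psnr(mse)   # -10/ln(10) * ln(1e-3) = 30 dB
assert np.isclose(psnr, 30.0)
assert np.isclose(image.psnr_to_mse(psnr), mse)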
internal/math.py
ADDED
@@ -0,0 +1,133 @@
import numpy as np
import torch


@torch.jit.script
def erf(x):
    return torch.sign(x) * torch.sqrt(1 - torch.exp(-4 / torch.pi * x ** 2))


def matmul(a, b):
    return (a[..., None] * b[..., None, :, :]).sum(dim=-2)
    # B,3,4,1 B,1,4,3

    # cause nan when fp16
    # return torch.matmul(a, b)


def safe_trig_helper(x, fn, t=100 * torch.pi):
    """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
    return fn(torch.where(torch.abs(x) < t, x, x % t))


def safe_cos(x):
    return safe_trig_helper(x, torch.cos)


def safe_sin(x):
    return safe_trig_helper(x, torch.sin)


def safe_exp(x):
    return torch.exp(x.clamp_max(88.))


def safe_exp_jvp(primals, tangents):
    """Override safe_exp()'s gradient so that it's large when inputs are large."""
    x, = primals
    x_dot, = tangents
    exp_x = safe_exp(x)
    exp_x_dot = exp_x * x_dot
    return exp_x, exp_x_dot


def log_lerp(t, v0, v1):
    """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
    if v0 <= 0 or v1 <= 0:
        raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
    lv0 = np.log(v0)
    lv1 = np.log(v1)
    return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)


def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
    """Continuous learning rate decay function.

    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
    function of lr_delay_mult, such that the initial learning rate is
    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
    to the normal learning rate when steps>lr_delay_steps.

    Args:
      step: int, the current optimization step.
      lr_init: float, the initial learning rate.
      lr_final: float, the final learning rate.
      max_steps: int, the number of steps during optimization.
      lr_delay_steps: int, the number of steps to delay the full learning rate.
      lr_delay_mult: float, the multiplier on the rate when delaying it.

    Returns:
      lr: the learning rate for the current step 'step'.
    """
    if lr_delay_steps > 0:
        # A kind of reverse cosine decay.
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
    else:
        delay_rate = 1.
    return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)


def sorted_interp(x, xp, fp):
    """A TPU-friendly version of interp(), where xp and fp must be sorted."""

    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0 = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2).values
        x1 = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2).values
        return x0, x1

    fp0, fp1 = find_interval(fp)
    xp0, xp1 = find_interval(xp)

    offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    ret = fp0 + offset * (fp1 - fp0)
    return ret


def sorted_interp_quad(x, xp, fpdf, fcdf):
    """Interpolate quadratically between sorted knots (xp, fcdf) with slopes fpdf."""

    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x, return_idx=False):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0, x0_idx = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2)
        x1, x1_idx = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2)
        if return_idx:
            return x0, x1, x0_idx, x1_idx
        return x0, x1

    fcdf0, fcdf1, fcdf0_idx, fcdf1_idx = find_interval(fcdf, return_idx=True)
    fpdf0 = fpdf.take_along_dim(fcdf0_idx, dim=-1)
    fpdf1 = fpdf.take_along_dim(fcdf1_idx, dim=-1)
    xp0, xp1 = find_interval(xp)

    offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    ret = fcdf0 + (x - xp0) * (fpdf0 + fpdf1 * offset + fpdf0 * (1 - offset)) / 2
    return ret
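learning_rate_decay() above interpolates log-linearly between lr_init and lr_final, which is equivalent to exponential decay. An illustrative sketch of the resulting schedule, assuming the module is importable as internal.math and using made-up hyperparameters:

from internal import math

lr_init, lr_final, max_steps = 1e-2, 1e-3, 25000
for step in [0, 12500, 25000]:
    lr = math.learning_rate_decay(step, lr_init, lr_final, max_steps)
    print(step, lr)  # 1e-2 at step 0, ~3.16e-3 halfway, 1e-3 at the end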
internal/models.py
ADDED
@@ -0,0 +1,740 @@
1 |
+
import accelerate
|
2 |
+
import gin
|
3 |
+
from internal import coord
|
4 |
+
from internal import geopoly
|
5 |
+
from internal import image
|
6 |
+
from internal import math
|
7 |
+
from internal import ref_utils
|
8 |
+
from internal import train_utils
|
9 |
+
from internal import render
|
10 |
+
from internal import stepfun
|
11 |
+
from internal import utils
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.utils._pytree import tree_map
|
17 |
+
from tqdm import tqdm
|
18 |
+
from gridencoder import GridEncoder
|
19 |
+
from torch_scatter import segment_coo
|
20 |
+
|
21 |
+
gin.config.external_configurable(math.safe_exp, module='math')
|
22 |
+
|
23 |
+
|
24 |
+
def set_kwargs(self, kwargs):
|
25 |
+
for k, v in kwargs.items():
|
26 |
+
setattr(self, k, v)
|
27 |
+
|
28 |
+
|
29 |
+
@gin.configurable
|
30 |
+
class Model(nn.Module):
|
31 |
+
"""A mip-Nerf360 model containing all MLPs."""
|
32 |
+
num_prop_samples: int = 64 # The number of samples for each proposal level.
|
33 |
+
num_nerf_samples: int = 32 # The number of samples the final nerf level.
|
34 |
+
num_levels: int = 3 # The number of sampling levels (3==2 proposals, 1 nerf).
|
35 |
+
bg_intensity_range = (1., 1.) # The range of background colors.
|
36 |
+
anneal_slope: float = 10 # Higher = more rapid annealing.
|
37 |
+
stop_level_grad: bool = True # If True, don't backprop across levels.
|
38 |
+
use_viewdirs: bool = True # If True, use view directions as input.
|
39 |
+
raydist_fn = None # The curve used for ray dists.
|
40 |
+
single_jitter: bool = True # If True, jitter whole rays instead of samples.
|
41 |
+
dilation_multiplier: float = 0.5 # How much to dilate intervals relatively.
|
42 |
+
dilation_bias: float = 0.0025 # How much to dilate intervals absolutely.
|
43 |
+
num_glo_features: int = 0 # GLO vector length, disabled if 0.
|
44 |
+
num_glo_embeddings: int = 1000 # Upper bound on max number of train images.
|
45 |
+
learned_exposure_scaling: bool = False # Learned exposure scaling (RawNeRF).
|
46 |
+
near_anneal_rate = None # How fast to anneal in near bound.
|
47 |
+
near_anneal_init: float = 0.95 # Where to initialize near bound (in [0, 1]).
|
48 |
+
single_mlp: bool = False # Use the NerfMLP for all rounds of sampling.
|
49 |
+
distinct_prop: bool = True # Use the NerfMLP for all rounds of sampling.
|
50 |
+
resample_padding: float = 0.0 # Dirichlet/alpha "padding" on the histogram.
|
51 |
+
opaque_background: bool = False # If true, make the background opaque.
|
52 |
+
power_lambda: float = -1.5
|
53 |
+
std_scale: float = 0.5
|
54 |
+
prop_desired_grid_size = [512, 2048]
|
55 |
+
|
56 |
+
def __init__(self, config=None, **kwargs):
|
57 |
+
super().__init__()
|
58 |
+
set_kwargs(self, kwargs)
|
59 |
+
self.config = config
|
60 |
+
|
61 |
+
# Construct MLPs. WARNING: Construction order may matter, if MLP weights are
|
62 |
+
# being regularized.
|
63 |
+
self.nerf_mlp = NerfMLP(num_glo_features=self.num_glo_features,
|
64 |
+
num_glo_embeddings=self.num_glo_embeddings)
|
65 |
+
if self.single_mlp:
|
66 |
+
self.prop_mlp = self.nerf_mlp
|
67 |
+
elif not self.distinct_prop:
|
68 |
+
self.prop_mlp = PropMLP()
|
69 |
+
else:
|
70 |
+
for i in range(self.num_levels - 1):
|
71 |
+
self.register_module(f'prop_mlp_{i}', PropMLP(grid_disired_resolution=self.prop_desired_grid_size[i]))
|
72 |
+
if self.num_glo_features > 0 and not config.zero_glo:
|
73 |
+
# Construct/grab GLO vectors for the cameras of each input ray.
|
74 |
+
self.glo_vecs = nn.Embedding(self.num_glo_embeddings, self.num_glo_features)
|
75 |
+
|
76 |
+
if self.learned_exposure_scaling:
|
77 |
+
# Setup learned scaling factors for output colors.
|
78 |
+
max_num_exposures = self.num_glo_embeddings
|
79 |
+
# Initialize the learned scaling offsets at 0.
|
80 |
+
self.exposure_scaling_offsets = nn.Embedding(max_num_exposures, 3)
|
81 |
+
torch.nn.init.zeros_(self.exposure_scaling_offsets.weight)
|
82 |
+
|
83 |
+
def forward(
|
84 |
+
self,
|
85 |
+
rand,
|
86 |
+
batch,
|
87 |
+
train_frac,
|
88 |
+
compute_extras,
|
89 |
+
zero_glo=True,
|
90 |
+
):
|
91 |
+
"""The mip-NeRF Model.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
rand: random number generator (or None for deterministic output).
|
95 |
+
batch: util.Rays, a pytree of ray origins, directions, and viewdirs.
|
96 |
+
train_frac: float in [0, 1], what fraction of training is complete.
|
97 |
+
compute_extras: bool, if True, compute extra quantities besides color.
|
98 |
+
zero_glo: bool, if True, when using GLO pass in vector of zeros.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
ret: list, [*(rgb, distance, acc)]
|
102 |
+
"""
|
103 |
+
device = batch['origins'].device
|
104 |
+
if self.num_glo_features > 0:
|
105 |
+
if not zero_glo:
|
106 |
+
# Construct/grab GLO vectors for the cameras of each input ray.
|
107 |
+
cam_idx = batch['cam_idx'][..., 0]
|
108 |
+
glo_vec = self.glo_vecs(cam_idx.long())
|
109 |
+
else:
|
110 |
+
glo_vec = torch.zeros(batch['origins'].shape[:-1] + (self.num_glo_features,), device=device)
|
111 |
+
else:
|
112 |
+
glo_vec = None
|
113 |
+
|
114 |
+
# Define the mapping from normalized to metric ray distance.
|
115 |
+
_, s_to_t = coord.construct_ray_warps(self.raydist_fn, batch['near'], batch['far'], self.power_lambda)
|
116 |
+
|
117 |
+
# Initialize the range of (normalized) distances for each ray to [0, 1],
|
118 |
+
# and assign that single interval a weight of 1. These distances and weights
|
119 |
+
# will be repeatedly updated as we proceed through sampling levels.
|
120 |
+
# `near_anneal_rate` can be used to anneal in the near bound at the start
|
121 |
+
# of training, eg. 0.1 anneals in the bound over the first 10% of training.
|
122 |
+
if self.near_anneal_rate is None:
|
123 |
+
init_s_near = 0.
|
124 |
+
else:
|
125 |
+
init_s_near = np.clip(1 - train_frac / self.near_anneal_rate, 0,
|
126 |
+
self.near_anneal_init)
|
127 |
+
init_s_far = 1.
|
128 |
+
sdist = torch.cat([
|
129 |
+
torch.full_like(batch['near'], init_s_near),
|
130 |
+
torch.full_like(batch['far'], init_s_far)
|
131 |
+
], dim=-1)
|
132 |
+
weights = torch.ones_like(batch['near'])
|
133 |
+
prod_num_samples = 1
|
134 |
+
|
135 |
+
ray_history = []
|
136 |
+
renderings = []
|
137 |
+
for i_level in range(self.num_levels):
|
138 |
+
is_prop = i_level < (self.num_levels - 1)
|
139 |
+
num_samples = self.num_prop_samples if is_prop else self.num_nerf_samples
|
140 |
+
|
141 |
+
# Dilate by some multiple of the expected span of each current interval,
|
142 |
+
# with some bias added in.
|
143 |
+
dilation = self.dilation_bias + self.dilation_multiplier * (
|
144 |
+
init_s_far - init_s_near) / prod_num_samples
|
145 |
+
|
146 |
+
# Record the product of the number of samples seen so far.
|
147 |
+
prod_num_samples *= num_samples
|
148 |
+
|
149 |
+
# After the first level (where dilation would be a no-op) optionally
|
150 |
+
# dilate the interval weights along each ray slightly so that they're
|
151 |
+
# overestimates, which can reduce aliasing.
|
152 |
+
use_dilation = self.dilation_bias > 0 or self.dilation_multiplier > 0
|
153 |
+
if i_level > 0 and use_dilation:
|
154 |
+
sdist, weights = stepfun.max_dilate_weights(
|
155 |
+
sdist,
|
156 |
+
weights,
|
157 |
+
dilation,
|
158 |
+
domain=(init_s_near, init_s_far),
|
159 |
+
renormalize=True)
|
160 |
+
sdist = sdist[..., 1:-1]
|
161 |
+
weights = weights[..., 1:-1]
|
162 |
+
|
163 |
+
# Optionally anneal the weights as a function of training iteration.
|
164 |
+
if self.anneal_slope > 0:
|
165 |
+
# Schlick's bias function, see https://arxiv.org/abs/2010.09714
|
166 |
+
bias = lambda x, s: (s * x) / ((s - 1) * x + 1)
|
167 |
+
anneal = bias(train_frac, self.anneal_slope)
|
168 |
+
else:
|
169 |
+
anneal = 1.
|
170 |
+
|
171 |
+
# A slightly more stable way to compute weights**anneal. If the distance
|
172 |
+
# between adjacent intervals is zero then its weight is fixed to 0.
|
173 |
+
logits_resample = torch.where(
|
174 |
+
sdist[..., 1:] > sdist[..., :-1],
|
175 |
+
anneal * torch.log(weights + self.resample_padding),
|
176 |
+
torch.full_like(sdist[..., :-1], -torch.inf))
|
177 |
+
|
178 |
+
# Draw sampled intervals from each ray's current weights.
|
179 |
+
sdist = stepfun.sample_intervals(
|
180 |
+
rand,
|
181 |
+
sdist,
|
182 |
+
logits_resample,
|
183 |
+
num_samples,
|
184 |
+
single_jitter=self.single_jitter,
|
185 |
+
domain=(init_s_near, init_s_far))
|
186 |
+
|
187 |
+
# Optimization will usually go nonlinear if you propagate gradients
|
188 |
+
# through sampling.
|
189 |
+
if self.stop_level_grad:
|
190 |
+
sdist = sdist.detach()
|
191 |
+
|
192 |
+
# Convert normalized distances to metric distances.
|
193 |
+
tdist = s_to_t(sdist)
|
194 |
+
|
195 |
+
# Cast our rays, by turning our distance intervals into Gaussians.
|
196 |
+
means, stds, ts = render.cast_rays(
|
197 |
+
tdist,
|
198 |
+
batch['origins'],
|
199 |
+
batch['directions'],
|
200 |
+
batch['cam_dirs'],
|
201 |
+
batch['radii'],
|
202 |
+
rand,
|
203 |
+
std_scale=self.std_scale)
|
204 |
+
|
205 |
+
# Push our Gaussians through one of our two MLPs.
|
206 |
+
mlp = (self.get_submodule(
|
207 |
+
f'prop_mlp_{i_level}') if self.distinct_prop else self.prop_mlp) if is_prop else self.nerf_mlp
|
208 |
+
ray_results = mlp(
|
209 |
+
rand,
|
210 |
+
means, stds,
|
211 |
+
viewdirs=batch['viewdirs'] if self.use_viewdirs else None,
|
212 |
+
imageplane=batch.get('imageplane'),
|
213 |
+
glo_vec=None if is_prop else glo_vec,
|
214 |
+
exposure=batch.get('exposure_values'),
|
215 |
+
)
|
216 |
+
if self.config.gradient_scaling:
|
217 |
+
ray_results['rgb'], ray_results['density'] = train_utils.GradientScaler.apply(
|
218 |
+
ray_results['rgb'], ray_results['density'], ts.mean(dim=-1))
|
219 |
+
|
220 |
+
# Get the weights used by volumetric rendering (and our other losses).
|
221 |
+
weights = render.compute_alpha_weights(
|
222 |
+
ray_results['density'],
|
223 |
+
tdist,
|
224 |
+
batch['directions'],
|
225 |
+
opaque_background=self.opaque_background,
|
226 |
+
)[0]
|
227 |
+
|
228 |
+
# Define or sample the background color for each ray.
|
229 |
+
if self.bg_intensity_range[0] == self.bg_intensity_range[1]:
|
230 |
+
# If the min and max of the range are equal, just take it.
|
231 |
+
bg_rgbs = self.bg_intensity_range[0]
|
232 |
+
elif rand is None:
|
233 |
+
# If rendering is deterministic, use the midpoint of the range.
|
234 |
+
bg_rgbs = (self.bg_intensity_range[0] + self.bg_intensity_range[1]) / 2
|
235 |
+
else:
|
236 |
+
# Sample RGB values from the range for each ray.
|
237 |
+
minval = self.bg_intensity_range[0]
|
238 |
+
maxval = self.bg_intensity_range[1]
|
239 |
+
bg_rgbs = torch.rand(weights.shape[:-1] + (3,), device=device) * (maxval - minval) + minval
|
240 |
+
|
241 |
+
# RawNeRF exposure logic.
|
242 |
+
if batch.get('exposure_idx') is not None:
|
243 |
+
# Scale output colors by the exposure.
|
244 |
+
ray_results['rgb'] *= batch['exposure_values'][..., None, :]
|
245 |
+
if self.learned_exposure_scaling:
|
246 |
+
exposure_idx = batch['exposure_idx'][..., 0]
|
247 |
+
# Force scaling offset to always be zero when exposure_idx is 0.
|
248 |
+
# This constraint fixes a reference point for the scene's brightness.
|
249 |
+
mask = exposure_idx > 0
|
250 |
+
# Scaling is parameterized as an offset from 1.
|
251 |
+
scaling = 1 + mask[..., None] * self.exposure_scaling_offsets(exposure_idx.long())
|
252 |
+
ray_results['rgb'] *= scaling[..., None, :]
|
253 |
+
|
254 |
+
# Render each ray.
|
255 |
+
rendering = render.volumetric_rendering(
|
256 |
+
ray_results['rgb'],
|
257 |
+
weights,
|
258 |
+
tdist,
|
259 |
+
bg_rgbs,
|
260 |
+
batch['far'],
|
261 |
+
compute_extras,
|
262 |
+
extras={
|
263 |
+
k: v
|
264 |
+
for k, v in ray_results.items()
|
265 |
+
if k.startswith('normals') or k in ['roughness']
|
266 |
+
})
|
267 |
+
|
268 |
+
if compute_extras:
|
269 |
+
# Collect some rays to visualize directly. By naming these quantities
|
270 |
+
# with `ray_` they get treated differently downstream --- they're
|
271 |
+
# treated as bags of rays, rather than image chunks.
|
272 |
+
n = self.config.vis_num_rays
|
273 |
+
rendering['ray_sdist'] = sdist.reshape([-1, sdist.shape[-1]])[:n, :]
|
274 |
+
rendering['ray_weights'] = (
|
275 |
+
weights.reshape([-1, weights.shape[-1]])[:n, :])
|
276 |
+
rgb = ray_results['rgb']
|
277 |
+
rendering['ray_rgbs'] = (rgb.reshape((-1,) + rgb.shape[-2:]))[:n, :, :]
|
278 |
+
|
279 |
+
if self.training:
|
280 |
+
# Compute the hash decay loss for this level.
|
281 |
+
idx = mlp.encoder.idx
|
282 |
+
param = mlp.encoder.embeddings
|
283 |
+
loss_hash_decay = segment_coo(param ** 2,
|
284 |
+
idx,
|
285 |
+
torch.zeros(idx.max() + 1, param.shape[-1], device=param.device),
|
286 |
+
reduce='mean'
|
287 |
+
).mean()
|
288 |
+
ray_results['loss_hash_decay'] = loss_hash_decay
|
289 |
+
|
290 |
+
renderings.append(rendering)
|
291 |
+
ray_results['sdist'] = sdist.clone()
|
292 |
+
ray_results['weights'] = weights.clone()
|
293 |
+
ray_history.append(ray_results)
|
294 |
+
|
295 |
+
if compute_extras:
|
296 |
+
# Because the proposal network doesn't produce meaningful colors, for
|
297 |
+
# easier visualization we replace their colors with the final average
|
298 |
+
# color.
|
299 |
+
weights = [r['ray_weights'] for r in renderings]
|
300 |
+
rgbs = [r['ray_rgbs'] for r in renderings]
|
301 |
+
final_rgb = torch.sum(rgbs[-1] * weights[-1][..., None], dim=-2)
|
302 |
+
avg_rgbs = [
|
303 |
+
torch.broadcast_to(final_rgb[:, None, :], r.shape) for r in rgbs[:-1]
|
304 |
+
]
|
305 |
+
for i in range(len(avg_rgbs)):
|
306 |
+
renderings[i]['ray_rgbs'] = avg_rgbs[i]
|
307 |
+
|
308 |
+
return renderings, ray_history
|
309 |
+
|
310 |
+
|
311 |
+
class MLP(nn.Module):
|
312 |
+
"""A PosEnc MLP."""
|
313 |
+
bottleneck_width: int = 256 # The width of the bottleneck vector.
|
314 |
+
net_depth_viewdirs: int = 2 # The depth of the second part of ML.
|
315 |
+
net_width_viewdirs: int = 256 # The width of the second part of MLP.
|
316 |
+
skip_layer_dir: int = 0 # Add a skip connection to 2nd MLP after Nth layers.
|
317 |
+
num_rgb_channels: int = 3 # The number of RGB channels.
|
318 |
+
deg_view: int = 4 # Degree of encoding for viewdirs or refdirs.
|
319 |
+
use_reflections: bool = False # If True, use refdirs instead of viewdirs.
|
320 |
+
use_directional_enc: bool = False # If True, use IDE to encode directions.
|
321 |
+
# If False and if use_directional_enc is True, use zero roughness in IDE.
|
322 |
+
enable_pred_roughness: bool = False
|
323 |
+
roughness_bias: float = -1. # Shift added to raw roughness pre-activation.
|
324 |
+
use_diffuse_color: bool = False # If True, predict diffuse & specular colors.
|
325 |
+
use_specular_tint: bool = False # If True, predict tint.
|
326 |
+
use_n_dot_v: bool = False # If True, feed dot(n * viewdir) to 2nd MLP.
|
327 |
+
bottleneck_noise: float = 0.0 # Std. deviation of noise added to bottleneck.
|
328 |
+
density_bias: float = -1. # Shift added to raw densities pre-activation.
|
329 |
+
density_noise: float = 0. # Standard deviation of noise added to raw density.
|
330 |
+
rgb_premultiplier: float = 1. # Premultiplier on RGB before activation.
|
331 |
+
rgb_bias: float = 0. # The shift added to raw colors pre-activation.
|
332 |
+
rgb_padding: float = 0.001 # Padding added to the RGB outputs.
|
333 |
+
enable_pred_normals: bool = False # If True compute predicted normals.
|
334 |
+
disable_density_normals: bool = False # If True don't compute normals.
|
335 |
+
disable_rgb: bool = False # If True don't output RGB.
|
336 |
+
warp_fn = 'contract'
|
337 |
+
num_glo_features: int = 0 # GLO vector length, disabled if 0.
|
338 |
+
num_glo_embeddings: int = 1000 # Upper bound on max number of train images.
|
339 |
+
scale_featurization: bool = False
|
340 |
+
grid_num_levels: int = 10
|
341 |
+
grid_level_interval: int = 2
|
342 |
+
grid_level_dim: int = 4
|
343 |
+
grid_base_resolution: int = 16
|
344 |
+
grid_disired_resolution: int = 8192
|
345 |
+
grid_log2_hashmap_size: int = 21
|
346 |
+
net_width_glo: int = 128 # The width of the second part of MLP.
|
347 |
+
net_depth_glo: int = 2 # The width of the second part of MLP.
|
348 |
+
|
349 |
+
def __init__(self, **kwargs):
|
350 |
+
super().__init__()
|
351 |
+
set_kwargs(self, kwargs)
|
352 |
+
# Make sure that normals are computed if reflection direction is used.
|
353 |
+
if self.use_reflections and not (self.enable_pred_normals or
|
354 |
+
not self.disable_density_normals):
|
355 |
+
raise ValueError('Normals must be computed for reflection directions.')
|
356 |
+
|
357 |
+
# Precompute and define viewdir or refdir encoding function.
|
358 |
+
if self.use_directional_enc:
|
359 |
+
self.dir_enc_fn = ref_utils.generate_ide_fn(self.deg_view)
|
360 |
+
dim_dir_enc = self.dir_enc_fn(torch.zeros(1, 3), torch.zeros(1, 1)).shape[-1]
|
361 |
+
else:
|
362 |
+
|
363 |
+
def dir_enc_fn(direction, _):
|
364 |
+
return coord.pos_enc(
|
365 |
+
direction, min_deg=0, max_deg=self.deg_view, append_identity=True)
|
366 |
+
|
367 |
+
self.dir_enc_fn = dir_enc_fn
|
368 |
+
dim_dir_enc = self.dir_enc_fn(torch.zeros(1, 3), None).shape[-1]
|
369 |
+
self.grid_num_levels = int(
|
370 |
+
np.log(self.grid_disired_resolution / self.grid_base_resolution) / np.log(self.grid_level_interval)) + 1
|
371 |
+
self.encoder = GridEncoder(input_dim=3,
|
372 |
+
num_levels=self.grid_num_levels,
|
373 |
+
level_dim=self.grid_level_dim,
|
374 |
+
base_resolution=self.grid_base_resolution,
|
375 |
+
desired_resolution=self.grid_disired_resolution,
|
376 |
+
log2_hashmap_size=self.grid_log2_hashmap_size,
|
377 |
+
gridtype='hash',
|
378 |
+
align_corners=False)
|
379 |
+
last_dim = self.encoder.output_dim
|
380 |
+
if self.scale_featurization:
|
381 |
+
last_dim += self.encoder.num_levels
|
382 |
+
self.density_layer = nn.Sequential(nn.Linear(last_dim, 64),
|
383 |
+
nn.ReLU(),
|
384 |
+
nn.Linear(64,
|
385 |
+
1 if self.disable_rgb else self.bottleneck_width)) # Hardcoded to a single channel.
|
386 |
+
last_dim = 1 if self.disable_rgb and not self.enable_pred_normals else self.bottleneck_width
|
387 |
+
if self.enable_pred_normals:
|
388 |
+
self.normal_layer = nn.Linear(last_dim, 3)
|
389 |
+
|
390 |
+
if not self.disable_rgb:
|
391 |
+
if self.use_diffuse_color:
|
392 |
+
self.diffuse_layer = nn.Linear(last_dim, self.num_rgb_channels)
|
393 |
+
|
394 |
+
if self.use_specular_tint:
|
395 |
+
self.specular_layer = nn.Linear(last_dim, 3)
|
396 |
+
|
397 |
+
if self.enable_pred_roughness:
|
398 |
+
self.roughness_layer = nn.Linear(last_dim, 1)
|
399 |
+
|
400 |
+
# Output of the first part of MLP.
|
401 |
+
if self.bottleneck_width > 0:
|
402 |
+
last_dim_rgb = self.bottleneck_width
|
403 |
+
else:
|
404 |
+
last_dim_rgb = 0
|
405 |
+
|
406 |
+
last_dim_rgb += dim_dir_enc
|
407 |
+
|
408 |
+
if self.use_n_dot_v:
|
409 |
+
last_dim_rgb += 1
|
410 |
+
|
411 |
+
if self.num_glo_features > 0:
|
412 |
+
last_dim_glo = self.num_glo_features
|
413 |
+
for i in range(self.net_depth_glo - 1):
|
414 |
+
self.register_module(f"lin_glo_{i}", nn.Linear(last_dim_glo, self.net_width_glo))
|
415 |
+
last_dim_glo = self.net_width_glo
|
416 |
+
self.register_module(f"lin_glo_{self.net_depth_glo - 1}",
|
417 |
+
nn.Linear(last_dim_glo, self.bottleneck_width * 2))
|
418 |
+
|
419 |
+
input_dim_rgb = last_dim_rgb
|
420 |
+
for i in range(self.net_depth_viewdirs):
|
421 |
+
lin = nn.Linear(last_dim_rgb, self.net_width_viewdirs)
|
422 |
+
torch.nn.init.kaiming_uniform_(lin.weight)
|
423 |
+
self.register_module(f"lin_second_stage_{i}", lin)
|
424 |
+
last_dim_rgb = self.net_width_viewdirs
|
425 |
+
if i == self.skip_layer_dir:
|
426 |
+
last_dim_rgb += input_dim_rgb
|
427 |
+
self.rgb_layer = nn.Linear(last_dim_rgb, self.num_rgb_channels)
|
428 |
+
|
429 |
+
def predict_density(self, means, stds, rand=False, no_warp=False):
|
430 |
+
"""Helper function to output density."""
|
431 |
+
# Encode input positions
|
432 |
+
if self.warp_fn is not None and not no_warp:
|
433 |
+
means, stds = coord.track_linearize(self.warp_fn, means, stds)
|
434 |
+
# contract [-2, 2] to [-1, 1]
|
435 |
+
bound = 2
|
436 |
+
means = means / bound
|
437 |
+
stds = stds / bound
|
438 |
+
features = self.encoder(means, bound=1).unflatten(-1, (self.encoder.num_levels, -1))
|
439 |
+
weights = torch.erf(1 / torch.sqrt(8 * stds[..., None] ** 2 * self.encoder.grid_sizes ** 2))
|
440 |
+
features = (features * weights[..., None]).mean(dim=-3).flatten(-2, -1)
|
441 |
+
if self.scale_featurization:
|
442 |
+
with torch.no_grad():
|
443 |
+
vl2mean = segment_coo((self.encoder.embeddings ** 2).sum(-1),
|
444 |
+
self.encoder.idx,
|
445 |
+
torch.zeros(self.grid_num_levels, device=weights.device),
|
446 |
+
self.grid_num_levels,
|
447 |
+
reduce='mean'
|
448 |
+
)
|
449 |
+
featurized_w = (2 * weights.mean(dim=-2) - 1) * (self.encoder.init_std ** 2 + vl2mean).sqrt()
|
450 |
+
features = torch.cat([features, featurized_w], dim=-1)
|
451 |
+
x = self.density_layer(features)
|
452 |
+
raw_density = x[..., 0] # Hardcoded to a single channel.
|
453 |
+
# Add noise to regularize the density predictions if needed.
|
454 |
+
if rand and (self.density_noise > 0):
|
455 |
+
raw_density += self.density_noise * torch.randn_like(raw_density)
|
456 |
+
return raw_density, x, means.mean(dim=-2)
|
457 |
+
|
458 |
+
def forward(self,
|
459 |
+
rand,
|
460 |
+
means, stds,
|
461 |
+
viewdirs=None,
|
462 |
+
imageplane=None,
|
463 |
+
glo_vec=None,
|
464 |
+
exposure=None,
|
465 |
+
no_warp=False):
|
466 |
+
"""Evaluate the MLP.
|
467 |
+
|
468 |
+
Args:
|
469 |
+
rand: if random .
|
470 |
+
means: [..., n, 3], coordinate means.
|
471 |
+
stds: [..., n], coordinate stds.
|
472 |
+
viewdirs: [..., 3], if not None, this variable will
|
473 |
+
be part of the input to the second part of the MLP concatenated with the
|
474 |
+
output vector of the first part of the MLP. If None, only the first part
|
475 |
+
of the MLP will be used with input x. In the original paper, this
|
476 |
+
variable is the view direction.
|
477 |
+
imageplane:[batch, 2], xy image plane coordinates
|
478 |
+
for each ray in the batch. Useful for image plane operations such as a
|
479 |
+
learned vignette mapping.
|
480 |
+
glo_vec: [..., num_glo_features], The GLO vector for each ray.
|
481 |
+
exposure: [..., 1], exposure value (shutter_speed * ISO) for each ray.
|
482 |
+
|
483 |
+
Returns:
|
484 |
+
rgb: [..., num_rgb_channels].
|
485 |
+
density: [...].
|
486 |
+
normals: [..., 3], or None.
|
487 |
+
normals_pred: [..., 3], or None.
|
488 |
+
roughness: [..., 1], or None.
|
489 |
+
"""
|
490 |
+
if self.disable_density_normals:
|
491 |
+
raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp)
|
492 |
+
raw_grad_density = None
|
493 |
+
normals = None
|
494 |
+
else:
|
495 |
+
with torch.enable_grad():
|
496 |
+
means.requires_grad_(True)
|
497 |
+
raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp)
|
498 |
+
d_output = torch.ones_like(raw_density, requires_grad=False, device=raw_density.device)
|
499 |
+
raw_grad_density = torch.autograd.grad(
|
500 |
+
outputs=raw_density,
|
501 |
+
inputs=means,
|
502 |
+
grad_outputs=d_output,
|
503 |
+
create_graph=True,
|
504 |
+
retain_graph=True,
|
505 |
+
only_inputs=True)[0]
|
506 |
+
raw_grad_density = raw_grad_density.mean(-2)
|
507 |
+
# Compute normal vectors as negative normalized density gradient.
|
508 |
+
# We normalize the gradient of raw (pre-activation) density because
|
509 |
+
# it's the same as post-activation density, but is more numerically stable
|
510 |
+
# when the activation function has a steep or flat gradient.
|
511 |
+
normals = -ref_utils.l2_normalize(raw_grad_density)
|
512 |
+
|
513 |
+
if self.enable_pred_normals:
|
514 |
+
grad_pred = self.normal_layer(x)
|
515 |
+
|
516 |
+
# Normalize negative predicted gradients to get predicted normal vectors.
|
517 |
+
normals_pred = -ref_utils.l2_normalize(grad_pred)
|
518 |
+
normals_to_use = normals_pred
|
519 |
+
else:
|
520 |
+
grad_pred = None
|
521 |
+
normals_pred = None
|
522 |
+
normals_to_use = normals
|
523 |
+
|
524 |
+
# Apply bias and activation to raw density
|
525 |
+
density = F.softplus(raw_density + self.density_bias)
|
526 |
+
|
527 |
+
roughness = None
|
528 |
+
if self.disable_rgb:
|
529 |
+
rgb = torch.zeros(density.shape + (3,), device=density.device)
|
530 |
+
else:
|
531 |
+
if viewdirs is not None:
|
532 |
+
# Predict diffuse color.
|
533 |
+
if self.use_diffuse_color:
|
534 |
+
raw_rgb_diffuse = self.diffuse_layer(x)
|
535 |
+
|
536 |
+
if self.use_specular_tint:
|
537 |
+
tint = torch.sigmoid(self.specular_layer(x))
|
538 |
+
|
539 |
+
if self.enable_pred_roughness:
|
540 |
+
raw_roughness = self.roughness_layer(x)
|
541 |
+
roughness = (F.softplus(raw_roughness + self.roughness_bias))
|
542 |
+
|
543 |
+
# Output of the first part of MLP.
|
544 |
+
if self.bottleneck_width > 0:
|
545 |
+
bottleneck = x
|
546 |
+
# Add bottleneck noise.
|
547 |
+
if rand and (self.bottleneck_noise > 0):
|
548 |
+
bottleneck += self.bottleneck_noise * torch.randn_like(bottleneck)
|
549 |
+
|
550 |
+
# Append GLO vector if used.
|
551 |
+
if glo_vec is not None:
|
552 |
+
for i in range(self.net_depth_glo):
|
553 |
+
glo_vec = self.get_submodule(f"lin_glo_{i}")(glo_vec)
|
554 |
+
if i != self.net_depth_glo - 1:
|
555 |
+
glo_vec = F.relu(glo_vec)
|
556 |
+
glo_vec = torch.broadcast_to(glo_vec[..., None, :],
|
557 |
+
bottleneck.shape[:-1] + glo_vec.shape[-1:])
|
558 |
+
scale, shift = glo_vec.chunk(2, dim=-1)
|
559 |
+
bottleneck = bottleneck * torch.exp(scale) + shift
|
560 |
+
|
561 |
+
x = [bottleneck]
|
562 |
+
else:
|
563 |
+
x = []
|
564 |
+
|
565 |
+
# Encode view (or reflection) directions.
|
566 |
+
if self.use_reflections:
|
567 |
+
# Compute reflection directions. Note that we flip viewdirs before
|
568 |
+
# reflecting, because they point from the camera to the point,
|
569 |
+
# whereas ref_utils.reflect() assumes they point toward the camera.
|
570 |
+
# Returned refdirs then point from the point to the environment.
|
571 |
+
refdirs = ref_utils.reflect(-viewdirs[..., None, :], normals_to_use)
|
572 |
+
# Encode reflection directions.
|
573 |
+
dir_enc = self.dir_enc_fn(refdirs, roughness)
|
574 |
+
else:
|
575 |
+
# Encode view directions.
|
576 |
+
dir_enc = self.dir_enc_fn(viewdirs, roughness)
|
577 |
+
dir_enc = torch.broadcast_to(
|
578 |
+
dir_enc[..., None, :],
|
579 |
+
bottleneck.shape[:-1] + (dir_enc.shape[-1],))
|
580 |
+
|
581 |
+
# Append view (or reflection) direction encoding to bottleneck vector.
|
582 |
+
x.append(dir_enc)
|
583 |
+
|
584 |
+
# Append dot product between normal vectors and view directions.
|
585 |
+
if self.use_n_dot_v:
|
586 |
+
dotprod = torch.sum(
|
587 |
+
normals_to_use * viewdirs[..., None, :], dim=-1, keepdim=True)
|
588 |
+
x.append(dotprod)
|
589 |
+
|
590 |
+
# Concatenate bottleneck, directional encoding, and GLO.
|
591 |
+
x = torch.cat(x, dim=-1)
|
592 |
+
# Output of the second part of MLP.
|
593 |
+
inputs = x
|
594 |
+
for i in range(self.net_depth_viewdirs):
|
595 |
+
x = self.get_submodule(f"lin_second_stage_{i}")(x)
|
596 |
+
x = F.relu(x)
|
597 |
+
if i == self.skip_layer_dir:
|
598 |
+
x = torch.cat([x, inputs], dim=-1)
|
599 |
+
# If using diffuse/specular colors, then `rgb` is treated as linear
|
600 |
+
# specular color. Otherwise it's treated as the color itself.
|
601 |
+
rgb = torch.sigmoid(self.rgb_premultiplier *
|
602 |
+
self.rgb_layer(x) +
|
603 |
+
self.rgb_bias)
|
604 |
+
|
605 |
+
if self.use_diffuse_color:
|
606 |
+
# Initialize linear diffuse color around 0.25, so that the combined
|
607 |
+
# linear color is initialized around 0.5.
|
608 |
+
diffuse_linear = torch.sigmoid(raw_rgb_diffuse - np.log(3.0))
|
609 |
+
if self.use_specular_tint:
|
610 |
+
specular_linear = tint * rgb
|
611 |
+
else:
|
612 |
+
specular_linear = 0.5 * rgb
|
613 |
+
|
614 |
+
# Combine specular and diffuse components and tone map to sRGB.
|
615 |
+
rgb = torch.clip(image.linear_to_srgb(specular_linear + diffuse_linear), 0.0, 1.0)
|
616 |
+
|
617 |
+
# Apply padding, mapping color to [-rgb_padding, 1+rgb_padding].
|
618 |
+
rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding
|
619 |
+
|
620 |
+
return dict(
|
621 |
+
coord=means_contract,
|
622 |
+
density=density,
|
623 |
+
rgb=rgb,
|
624 |
+
raw_grad_density=raw_grad_density,
|
625 |
+
grad_pred=grad_pred,
|
626 |
+
normals=normals,
|
627 |
+
normals_pred=normals_pred,
|
628 |
+
roughness=roughness,
|
629 |
+
)
|
630 |
+
|
631 |
+
|
632 |
+
@gin.configurable
|
633 |
+
class NerfMLP(MLP):
|
634 |
+
pass
|
635 |
+
|
636 |
+
|
637 |
+
@gin.configurable
|
638 |
+
class PropMLP(MLP):
|
639 |
+
pass
|
640 |
+
|
641 |
+
|
642 |
+
@torch.no_grad()
|
643 |
+
def render_image(model,
|
644 |
+
accelerator: accelerate.Accelerator,
|
645 |
+
batch,
|
646 |
+
rand,
|
647 |
+
train_frac,
|
648 |
+
config,
|
649 |
+
verbose=True,
|
650 |
+
return_weights=False):
|
651 |
+
"""Render all the pixels of an image (in test mode).
|
652 |
+
|
653 |
+
Args:
|
654 |
+
      render_fn: function, jit-ed render function mapping (rand, batch) -> pytree.
      accelerator: used for DDP.
      batch: a `Rays` pytree, the rays to be rendered.
      rand: if random
      config: A Config class.

    Returns:
      rgb: rendered color image.
      disp: rendered disparity image.
      acc: rendered accumulated weights per pixel.
    """
    model.eval()

    height, width = batch['origins'].shape[:2]
    num_rays = height * width
    batch = {k: v.reshape((num_rays, -1)) for k, v in batch.items() if v is not None}

    global_rank = accelerator.process_index
    chunks = []
    idx0s = tqdm(range(0, num_rays, config.render_chunk_size),
                 desc="Rendering chunk", leave=False,
                 disable=not (accelerator.is_main_process and verbose))

    for i_chunk, idx0 in enumerate(idx0s):
        chunk_batch = tree_map(lambda r: r[idx0:idx0 + config.render_chunk_size], batch)
        actual_chunk_size = chunk_batch['origins'].shape[0]
        rays_remaining = actual_chunk_size % accelerator.num_processes
        if rays_remaining != 0:
            padding = accelerator.num_processes - rays_remaining
            chunk_batch = tree_map(lambda v: torch.cat([v, torch.zeros_like(v[-padding:])], dim=0), chunk_batch)
        else:
            padding = 0
        # After padding the number of chunk_rays is always divisible by host_count.
        rays_per_host = chunk_batch['origins'].shape[0] // accelerator.num_processes
        start, stop = global_rank * rays_per_host, (global_rank + 1) * rays_per_host
        chunk_batch = tree_map(lambda r: r[start:stop], chunk_batch)

        with accelerator.autocast():
            chunk_renderings, ray_history = model(rand,
                                                  chunk_batch,
                                                  train_frac=train_frac,
                                                  compute_extras=True,
                                                  zero_glo=True)

        gather = lambda v: accelerator.gather(v.contiguous())[:-padding] \
            if padding > 0 else accelerator.gather(v.contiguous())
        # Unshard the renderings.
        chunk_renderings = tree_map(gather, chunk_renderings)

        # Gather the final pass for 2D buffers and all passes for ray bundles.
        chunk_rendering = chunk_renderings[-1]
        for k in chunk_renderings[0]:
            if k.startswith('ray_'):
                chunk_rendering[k] = [r[k] for r in chunk_renderings]

        if return_weights:
            chunk_rendering['weights'] = gather(ray_history[-1]['weights'])
            chunk_rendering['coord'] = gather(ray_history[-1]['coord'])
        chunks.append(chunk_rendering)

    # Concatenate all chunks within each leaf of a single pytree.
    rendering = {}
    for k in chunks[0].keys():
        if isinstance(chunks[0][k], list):
            rendering[k] = []
            for i in range(len(chunks[0][k])):
                rendering[k].append(torch.cat([item[k][i] for item in chunks]))
        else:
            rendering[k] = torch.cat([item[k] for item in chunks])

    for k, z in rendering.items():
        if not k.startswith('ray_'):
            # Reshape 2D buffers into original image shape.
            rendering[k] = z.reshape((height, width) + z.shape[1:])

    # After all of the ray bundles have been concatenated together, extract a
    # new random bundle (deterministically) from the concatenation that is the
    # same size as one of the individual bundles.
    keys = [k for k in rendering if k.startswith('ray_')]
    if keys:
        num_rays = rendering[keys[0]][0].shape[0]
        ray_idx = torch.randperm(num_rays)
        ray_idx = ray_idx[:config.vis_num_rays]
        for k in keys:
            rendering[k] = [r[ray_idx] for r in rendering[k]]
    model.train()
    return rendering
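The chunk loop above pads each chunk so it splits evenly across processes, renders one shard per process, then gathers all shards and drops the padding. A minimal sketch of just that padding/sharding arithmetic (hypothetical helper name, plain PyTorch, no Accelerate):

import torch

def pad_and_shard(chunk, num_processes, rank):
    # Pad with zero rays so every process receives the same number of rays.
    remainder = chunk.shape[0] % num_processes
    padding = (num_processes - remainder) % num_processes
    if padding:
        chunk = torch.cat([chunk, torch.zeros_like(chunk[-padding:])], dim=0)
    rays_per_host = chunk.shape[0] // num_processes
    shard = chunk[rank * rays_per_host:(rank + 1) * rays_per_host]
    return shard, padding

# Example: 10 rays over 4 processes -> 2 padded rays, 3 rays per process;
# after gathering, the final `padding` rows are sliced off again.
chunk = torch.arange(10, dtype=torch.float32).unsqueeze(-1)
shard, padding = pad_and_shard(chunk, num_processes=4, rank=3)
assert padding == 2 and shard.shape[0] == 3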
internal/pycolmap/.gitignore
ADDED
@@ -0,0 +1,2 @@
*.pyc
*.sw*
internal/pycolmap/LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 True Price, UNC Chapel Hill

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
internal/pycolmap/README.md
ADDED
@@ -0,0 +1,4 @@
# pycolmap
Python interface for COLMAP reconstructions, plus some convenient scripts for loading/modifying/converting reconstructions.

This code does not, however, run reconstruction -- it only provides a convenient interface for handling COLMAP's output.
internal/pycolmap/pycolmap/__init__.py
ADDED
@@ -0,0 +1,5 @@
from camera import Camera
from database import COLMAPDatabase
from image import Image
from scene_manager import SceneManager
from rotation import Quaternion, DualQuaternion
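A minimal usage sketch of the vendored package (the imports above are top-level, so this assumes the pycolmap/ directory itself is on sys.path; the model folder below is hypothetical):

from scene_manager import SceneManager

sm = SceneManager('/path/to/colmap/sparse/0/')  # hypothetical COLMAP model folder
sm.load_cameras()
sm.load_images()
sm.load_points3D()

cam = next(iter(sm.cameras.values()))   # a pycolmap Camera
img = next(iter(sm.images.values()))    # a pycolmap Image
print(cam.K)                            # 3x3 intrinsic matrix
print(img.R(), img.C())                 # rotation matrix and camera center
print(sm.points3D.shape)                # (N, 3) array of 3D points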
internal/pycolmap/pycolmap/camera.py
ADDED
@@ -0,0 +1,259 @@
# Author: True Price <jtprice at cs.unc.edu>

import numpy as np

from scipy.optimize import root


#-------------------------------------------------------------------------------
#
# camera distortion functions for arrays of size (..., 2)
#
#-------------------------------------------------------------------------------

def simple_radial_distortion(camera, x):
    return x * (1. + camera.k1 * np.square(x).sum(axis=-1, keepdims=True))

def radial_distortion(camera, x):
    r_sq = np.square(x).sum(axis=-1, keepdims=True)
    return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq))

def opencv_distortion(camera, x):
    x_sq = np.square(x)
    xy = np.prod(x, axis=-1, keepdims=True)
    r_sq = x_sq.sum(axis=-1, keepdims=True)

    # Tangential terms: x^2 enters the first component, y^2 the second
    # (the original line referenced an undefined y_sq).
    return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq)) + np.concatenate((
        2. * camera.p1 * xy + camera.p2 * (r_sq + 2. * x_sq[..., :1]),
        camera.p1 * (r_sq + 2. * x_sq[..., 1:]) + 2. * camera.p2 * xy),
        axis=-1)

#-------------------------------------------------------------------------------
#
# Camera
#
#-------------------------------------------------------------------------------

class Camera:
    @staticmethod
    def GetNumParams(type_):
        if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
            return 3
        if type_ == 1 or type_ == 'PINHOLE':
            return 4
        if type_ == 2 or type_ == 'SIMPLE_RADIAL':
            return 4
        if type_ == 3 or type_ == 'RADIAL':
            return 5
        if type_ == 4 or type_ == 'OPENCV':
            return 8
        #if type_ == 5 or type_ == 'OPENCV_FISHEYE':
        #    return 8
        #if type_ == 6 or type_ == 'FULL_OPENCV':
        #    return 12
        #if type_ == 7 or type_ == 'FOV':
        #    return 5
        #if type_ == 8 or type_ == 'SIMPLE_RADIAL_FISHEYE':
        #    return 4
        #if type_ == 9 or type_ == 'RADIAL_FISHEYE':
        #    return 5
        #if type_ == 10 or type_ == 'THIN_PRISM_FISHEYE':
        #    return 12

        # TODO: not supporting other camera types, currently
        raise Exception('Camera type not supported')


    #---------------------------------------------------------------------------

    @staticmethod
    def GetNameFromType(type_):
        if type_ == 0: return 'SIMPLE_PINHOLE'
        if type_ == 1: return 'PINHOLE'
        if type_ == 2: return 'SIMPLE_RADIAL'
        if type_ == 3: return 'RADIAL'
        if type_ == 4: return 'OPENCV'
        #if type_ == 5: return 'OPENCV_FISHEYE'
        #if type_ == 6: return 'FULL_OPENCV'
        #if type_ == 7: return 'FOV'
        #if type_ == 8: return 'SIMPLE_RADIAL_FISHEYE'
        #if type_ == 9: return 'RADIAL_FISHEYE'
        #if type_ == 10: return 'THIN_PRISM_FISHEYE'

        raise Exception('Camera type not supported')


    #---------------------------------------------------------------------------

    def __init__(self, type_, width_, height_, params):
        self.width = width_
        self.height = height_

        if type_ == 0 or type_ == 'SIMPLE_PINHOLE':
            self.fx, self.cx, self.cy = params
            self.fy = self.fx
            self.distortion_func = None
            self.camera_type = 0

        elif type_ == 1 or type_ == 'PINHOLE':
            self.fx, self.fy, self.cx, self.cy = params
            self.distortion_func = None
            self.camera_type = 1

        elif type_ == 2 or type_ == 'SIMPLE_RADIAL':
            self.fx, self.cx, self.cy, self.k1 = params
            self.fy = self.fx
            self.distortion_func = simple_radial_distortion
            self.camera_type = 2

        elif type_ == 3 or type_ == 'RADIAL':
            self.fx, self.cx, self.cy, self.k1, self.k2 = params
            self.fy = self.fx
            self.distortion_func = radial_distortion
            self.camera_type = 3

        elif type_ == 4 or type_ == 'OPENCV':
            self.fx, self.fy, self.cx, self.cy = params[:4]
            self.k1, self.k2, self.p1, self.p2 = params[4:]
            self.distortion_func = opencv_distortion
            self.camera_type = 4

        else:
            raise Exception('Camera type not supported')


    #---------------------------------------------------------------------------

    def __str__(self):
        s = (self.GetNameFromType(self.camera_type) +
             ' {} {} {}'.format(self.width, self.height, self.fx))

        if self.camera_type in (1, 4): # PINHOLE, OPENCV
            s += ' {}'.format(self.fy)

        s += ' {} {}'.format(self.cx, self.cy)

        if self.camera_type == 2: # SIMPLE_RADIAL
            s += ' {}'.format(self.k1)

        elif self.camera_type == 3: # RADIAL
            s += ' {} {}'.format(self.k1, self.k2)

        elif self.camera_type == 4: # OPENCV
            s += ' {} {} {} {}'.format(self.k1, self.k2, self.p1, self.p2)

        return s


    #---------------------------------------------------------------------------

    # return the camera parameters in the same order as the colmap output format
    def get_params(self):
        if self.camera_type == 0:
            return np.array((self.fx, self.cx, self.cy))
        if self.camera_type == 1:
            return np.array((self.fx, self.fy, self.cx, self.cy))
        if self.camera_type == 2:
            return np.array((self.fx, self.cx, self.cy, self.k1))
        if self.camera_type == 3:
            return np.array((self.fx, self.cx, self.cy, self.k1, self.k2))
        if self.camera_type == 4:
            return np.array((self.fx, self.fy, self.cx, self.cy, self.k1,
                             self.k2, self.p1, self.p2))


    #---------------------------------------------------------------------------

    def get_camera_matrix(self):
        return np.array(
            ((self.fx, 0, self.cx), (0, self.fy, self.cy), (0, 0, 1)))

    def get_inverse_camera_matrix(self):
        return np.array(
            ((1. / self.fx, 0, -self.cx / self.fx),
             (0, 1. / self.fy, -self.cy / self.fy),
             (0, 0, 1)))

    @property
    def K(self):
        return self.get_camera_matrix()

    @property
    def K_inv(self):
        return self.get_inverse_camera_matrix()

    #---------------------------------------------------------------------------

    # return the inverse camera matrix
    def get_inv_camera_matrix(self):
        inv_fx, inv_fy = 1. / self.fx, 1. / self.fy
        return np.array(((inv_fx, 0, -inv_fx * self.cx),
                         (0, inv_fy, -inv_fy * self.cy),
                         (0, 0, 1)))


    #---------------------------------------------------------------------------

    # return an (x, y) pixel coordinate grid for this camera
    def get_image_grid(self):
        xmin = (0.5 - self.cx) / self.fx
        xmax = (self.width - 0.5 - self.cx) / self.fx
        ymin = (0.5 - self.cy) / self.fy
        ymax = (self.height - 0.5 - self.cy) / self.fy
        return np.meshgrid(np.linspace(xmin, xmax, self.width),
                           np.linspace(ymin, ymax, self.height))


    #---------------------------------------------------------------------------

    # x: array of shape (N,2) or (2,)
    # normalized: False if the input points are in pixel coordinates
    # denormalize: True if the points should be put back into pixel coordinates
    def distort_points(self, x, normalized=True, denormalize=True):
        x = np.atleast_2d(x)

        # put the points into normalized camera coordinates
        if not normalized:
            x -= np.array([[self.cx, self.cy]])
            x /= np.array([[self.fx, self.fy]])

        # distort, if necessary
        if self.distortion_func is not None:
            x = self.distortion_func(self, x)

        if denormalize:
            x *= np.array([[self.fx, self.fy]])
            x += np.array([[self.cx, self.cy]])

        return x


    #---------------------------------------------------------------------------

    # x: array of shape (N1,N2,...,2), (N,2), or (2,)
    # normalized: False if the input points are in pixel coordinates
    # denormalize: True if the points should be put back into pixel coordinates
    def undistort_points(self, x, normalized=False, denormalize=True):
        x = np.atleast_2d(x)

        # put the points into normalized camera coordinates
        if not normalized:
            x = x - np.array([self.cx, self.cy]) # creates a copy
            x /= np.array([self.fx, self.fy])

        # undistort, if necessary
        if self.distortion_func is not None:
            def objective(xu):
                return (x - self.distortion_func(self, xu.reshape(*x.shape))
                        ).ravel()

            xu = root(objective, x).x.reshape(*x.shape)
        else:
            xu = x

        if denormalize:
            xu *= np.array([[self.fx, self.fy]])
            xu += np.array([[self.cx, self.cy]])

        return xu
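For reference, distort_points and undistort_points invert each other when both work in pixel coordinates (undistort_points solves the distortion model numerically via scipy's root finder). A small round-trip check under the same sys.path assumption, with a hypothetical SIMPLE_RADIAL camera:

import numpy as np
from camera import Camera

# fx = 500, cx = 320, cy = 240, k1 = -0.05 (hypothetical values)
cam = Camera('SIMPLE_RADIAL', 640, 480, (500., 320., 240., -0.05))

pixels = np.array([[100., 200.], [320., 240.], [600., 50.]])
undistorted = cam.undistort_points(pixels)                       # pixels in, pixels out
roundtrip = cam.distort_points(undistorted, normalized=False)    # re-apply the model
assert np.allclose(roundtrip, pixels, atol=1e-4)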
internal/pycolmap/pycolmap/database.py
ADDED
@@ -0,0 +1,340 @@
import numpy as np
import os
import sqlite3


#-------------------------------------------------------------------------------
# convert SQLite BLOBs to/from numpy arrays

def array_to_blob(arr):
    return arr.tobytes()  # np.getbuffer() only exists on Python 2

def blob_to_array(blob, dtype, shape=(-1,)):
    return np.frombuffer(blob, dtype).reshape(*shape)


#-------------------------------------------------------------------------------
# convert to/from image pair ids

MAX_IMAGE_ID = 2**31 - 1

def get_pair_id(image_id1, image_id2):
    if image_id1 > image_id2:
        image_id1, image_id2 = image_id2, image_id1
    return image_id1 * MAX_IMAGE_ID + image_id2


def get_image_ids_from_pair_id(pair_id):
    image_id2 = pair_id % MAX_IMAGE_ID
    # integer division keeps the recovered ids integral
    return (pair_id - image_id2) // MAX_IMAGE_ID, image_id2


#-------------------------------------------------------------------------------
# create table commands

CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
    camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    model INTEGER NOT NULL,
    width INTEGER NOT NULL,
    height INTEGER NOT NULL,
    params BLOB,
    prior_focal_length INTEGER NOT NULL)"""

CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""

CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
    image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    name TEXT NOT NULL UNIQUE,
    camera_id INTEGER NOT NULL,
    prior_qw REAL,
    prior_qx REAL,
    prior_qy REAL,
    prior_qz REAL,
    prior_tx REAL,
    prior_ty REAL,
    prior_tz REAL,
    CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647),
    FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))"""

CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    config INTEGER NOT NULL,
    F BLOB,
    E BLOB,
    H BLOB)"""

CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""

CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB)"""

CREATE_NAME_INDEX = \
    "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"

CREATE_ALL = "; ".join([CREATE_CAMERAS_TABLE, CREATE_DESCRIPTORS_TABLE,
    CREATE_IMAGES_TABLE, CREATE_INLIER_MATCHES_TABLE, CREATE_KEYPOINTS_TABLE,
    CREATE_MATCHES_TABLE, CREATE_NAME_INDEX])


#-------------------------------------------------------------------------------
# functional interface for adding objects

def add_camera(db, model, width, height, params, prior_focal_length=False,
               camera_id=None):
    # TODO: Parameter count checks
    params = np.asarray(params, np.float64)
    db.execute("INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)",
               (camera_id, model, width, height, array_to_blob(params),
                prior_focal_length))


def add_descriptors(db, image_id, descriptors):
    descriptors = np.ascontiguousarray(descriptors, np.uint8)
    db.execute("INSERT INTO descriptors VALUES (?, ?, ?, ?)",
               (image_id,) + descriptors.shape + (array_to_blob(descriptors),))


def add_image(db, name, camera_id, prior_q=np.zeros(4), prior_t=np.zeros(3),
              image_id=None):
    db.execute("INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
               (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2],
                prior_q[3], prior_t[0], prior_t[1], prior_t[2]))


# config: defaults to fundamental matrix
def add_inlier_matches(db, image_id1, image_id2, matches, config=2, F=None,
                       E=None, H=None):
    assert(len(matches.shape) == 2)
    assert(matches.shape[1] == 2)

    if image_id1 > image_id2:
        matches = matches[:,::-1]

    if F is not None:
        F = np.asarray(F, np.float64)
    if E is not None:
        E = np.asarray(E, np.float64)
    if H is not None:
        H = np.asarray(H, np.float64)

    pair_id = get_pair_id(image_id1, image_id2)
    matches = np.asarray(matches, np.uint32)
    # the table created above is named two_view_geometries
    db.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
               (pair_id,) + matches.shape + (array_to_blob(matches), config, F, E, H))


def add_keypoints(db, image_id, keypoints):
    assert(len(keypoints.shape) == 2)
    assert(keypoints.shape[1] in [2, 4, 6])

    keypoints = np.asarray(keypoints, np.float32)
    db.execute("INSERT INTO keypoints VALUES (?, ?, ?, ?)",
               (image_id,) + keypoints.shape + (array_to_blob(keypoints),))


# config: defaults to fundamental matrix
def add_matches(db, image_id1, image_id2, matches):
    assert(len(matches.shape) == 2)
    assert(matches.shape[1] == 2)

    if image_id1 > image_id2:
        matches = matches[:,::-1]

    pair_id = get_pair_id(image_id1, image_id2)
    matches = np.asarray(matches, np.uint32)
    db.execute("INSERT INTO matches VALUES (?, ?, ?, ?)",
               (pair_id,) + matches.shape + (array_to_blob(matches),))


#-------------------------------------------------------------------------------
# simple functional interface

class COLMAPDatabase(sqlite3.Connection):
    @staticmethod
    def connect(database_path):
        return sqlite3.connect(database_path, factory=COLMAPDatabase)


    def __init__(self, *args, **kwargs):
        super(COLMAPDatabase, self).__init__(*args, **kwargs)

        self.initialize_tables = lambda: self.executescript(CREATE_ALL)

        self.initialize_cameras = \
            lambda: self.executescript(CREATE_CAMERAS_TABLE)
        self.initialize_descriptors = \
            lambda: self.executescript(CREATE_DESCRIPTORS_TABLE)
        self.initialize_images = \
            lambda: self.executescript(CREATE_IMAGES_TABLE)
        self.initialize_inlier_matches = \
            lambda: self.executescript(CREATE_INLIER_MATCHES_TABLE)
        self.initialize_keypoints = \
            lambda: self.executescript(CREATE_KEYPOINTS_TABLE)
        self.initialize_matches = \
            lambda: self.executescript(CREATE_MATCHES_TABLE)

        self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)


    add_camera = add_camera
    add_descriptors = add_descriptors
    add_image = add_image
    add_inlier_matches = add_inlier_matches
    add_keypoints = add_keypoints
    add_matches = add_matches


#-------------------------------------------------------------------------------

def main(args):
    import os

    if os.path.exists(args.database_path):
        print("Error: database path already exists -- will not modify it.")
        exit()

    db = COLMAPDatabase.connect(args.database_path)

    #
    # for convenience, try creating all the tables upfront
    #

    db.initialize_tables()


    #
    # create dummy cameras
    #

    model1, w1, h1, params1 = 0, 1024, 768, np.array((1024., 512., 384.))
    model2, w2, h2, params2 = 2, 1024, 768, np.array((1024., 512., 384., 0.1))

    db.add_camera(model1, w1, h1, params1)
    db.add_camera(model2, w2, h2, params2)


    #
    # create dummy images
    #

    db.add_image("image1.png", 0)
    db.add_image("image2.png", 0)
    db.add_image("image3.png", 2)
    db.add_image("image4.png", 2)


    #
    # create dummy keypoints; note that COLMAP supports 2D keypoints (x, y),
    # 4D keypoints (x, y, theta, scale), and 6D affine keypoints
    # (x, y, a_11, a_12, a_21, a_22)
    #

    N = 1000
    kp1 = np.random.rand(N, 2) * (1024., 768.)
    kp2 = np.random.rand(N, 2) * (1024., 768.)
    kp3 = np.random.rand(N, 2) * (1024., 768.)
    kp4 = np.random.rand(N, 2) * (1024., 768.)

    db.add_keypoints(1, kp1)
    db.add_keypoints(2, kp2)
    db.add_keypoints(3, kp3)
    db.add_keypoints(4, kp4)


    #
    # create dummy matches
    #

    M = 50
    m12 = np.random.randint(N, size=(M, 2))
    m23 = np.random.randint(N, size=(M, 2))
    m34 = np.random.randint(N, size=(M, 2))

    db.add_matches(1, 2, m12)
    db.add_matches(2, 3, m23)
    db.add_matches(3, 4, m34)


    #
    # check cameras
    #

    rows = db.execute("SELECT * FROM cameras")

    camera_id, model, width, height, params, prior = next(rows)
    params = blob_to_array(params, np.float64)  # params are stored as float64
    assert model == model1 and width == w1 and height == h1
    assert np.allclose(params, params1)

    camera_id, model, width, height, params, prior = next(rows)
    params = blob_to_array(params, np.float64)
    assert model == model2 and width == w2 and height == h2
    assert np.allclose(params, params2)


    #
    # check keypoints
    #

    kps = dict(
        (image_id, blob_to_array(data, np.float32, (-1, 2)))
        for image_id, data in db.execute(
            "SELECT image_id, data FROM keypoints"))

    assert np.allclose(kps[1], kp1)
    assert np.allclose(kps[2], kp2)
    assert np.allclose(kps[3], kp3)
    assert np.allclose(kps[4], kp4)


    #
    # check matches
    #

    pair_ids = [get_pair_id(*pair) for pair in [(1, 2), (2, 3), (3, 4)]]

    matches = dict(
        (get_image_ids_from_pair_id(pair_id),
         blob_to_array(data, np.uint32, (-1, 2)))
        for pair_id, data in db.execute("SELECT pair_id, data FROM matches"))

    assert np.all(matches[(1, 2)] == m12)
    assert np.all(matches[(2, 3)] == m23)
    assert np.all(matches[(3, 4)] == m34)

    #
    # clean up
    #

    db.close()
    os.remove(args.database_path)

#-------------------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--database_path", type=str, default="database.db")

    args = parser.parse_args()

    main(args)
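The pair-id helpers above pack an unordered image pair into a single integer key (pair_id = image_id1 * MAX_IMAGE_ID + image_id2 with image_id1 <= image_id2), which is how COLMAP indexes its matches tables. A tiny round-trip check, assuming the module is importable as `database`:

from database import get_pair_id, get_image_ids_from_pair_id

pair_id = get_pair_id(17, 5)                       # order-insensitive encoding
assert get_image_ids_from_pair_id(pair_id) == (5, 17)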
internal/pycolmap/pycolmap/image.py
ADDED
@@ -0,0 +1,35 @@
# Author: True Price <jtprice at cs.unc.edu>

import numpy as np

#-------------------------------------------------------------------------------
#
# Image
#
#-------------------------------------------------------------------------------

class Image:
    def __init__(self, name_, camera_id_, q_, tvec_):
        self.name = name_
        self.camera_id = camera_id_
        self.q = q_
        self.tvec = tvec_

        self.points2D = np.empty((0, 2), dtype=np.float64)
        self.point3D_ids = np.empty((0,), dtype=np.uint64)

    #---------------------------------------------------------------------------

    def R(self):
        return self.q.ToR()

    #---------------------------------------------------------------------------

    def C(self):
        return -self.R().T.dot(self.tvec)

    #---------------------------------------------------------------------------

    @property
    def t(self):
        return self.tvec
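Image stores a world-to-camera pose in COLMAP's convention (x_cam = R · X_world + tvec), so C() = -Rᵀ · tvec is the camera center in world coordinates. A small check with a hypothetical pose, under the same import assumption:

import numpy as np
from image import Image
from rotation import Quaternion

q = Quaternion.FromAxisAngle(np.array([0., 1., 0.]), 0.3)   # hypothetical rotation
t = np.array([0.1, -0.2, 2.0])                              # hypothetical translation
img = Image('view0.png', 1, q, t)

# The camera center maps to the origin of the camera frame.
x_cam = img.R().dot(img.C()) + img.t
assert np.allclose(x_cam, 0.)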
internal/pycolmap/pycolmap/rotation.py
ADDED
@@ -0,0 +1,324 @@
# Author: True Price <jtprice at cs.unc.edu>

import numpy as np

#-------------------------------------------------------------------------------
#
# Axis-Angle Functions
#
#-------------------------------------------------------------------------------

# returns the cross product matrix representation of a 3-vector v
def cross_prod_matrix(v):
    return np.array(((0., -v[2], v[1]), (v[2], 0., -v[0]), (-v[1], v[0], 0.)))

#-------------------------------------------------------------------------------

# www.euclideanspace.com/maths/geometry/rotations/conversions/angleToMatrix/
# if angle is None, assume ||axis|| == angle, in radians
# if angle is not None, assume that axis is a unit vector
def axis_angle_to_rotation_matrix(axis, angle=None):
    if angle is None:
        angle = np.linalg.norm(axis)
        if np.abs(angle) > np.finfo('float').eps:
            axis = axis / angle

    cp_axis = cross_prod_matrix(axis)
    return np.eye(3) + (
        np.sin(angle) * cp_axis + (1. - np.cos(angle)) * cp_axis.dot(cp_axis))

#-------------------------------------------------------------------------------

# after some deliberation, I've decided the easiest way to do this is to use
# quaternions as an intermediary
def rotation_matrix_to_axis_angle(R):
    return Quaternion.FromR(R).ToAxisAngle()

#-------------------------------------------------------------------------------
#
# Quaternion
#
#-------------------------------------------------------------------------------

class Quaternion:
    # create a quaternion from an existing rotation matrix
    # euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/
    @staticmethod
    def FromR(R):
        trace = np.trace(R)

        if trace > 0:
            qw = 0.5 * np.sqrt(1. + trace)
            qx = (R[2,1] - R[1,2]) * 0.25 / qw
            qy = (R[0,2] - R[2,0]) * 0.25 / qw
            qz = (R[1,0] - R[0,1]) * 0.25 / qw
        elif R[0,0] > R[1,1] and R[0,0] > R[2,2]:
            s = 2. * np.sqrt(1. + R[0,0] - R[1,1] - R[2,2])
            qw = (R[2,1] - R[1,2]) / s
            qx = 0.25 * s
            qy = (R[0,1] + R[1,0]) / s
            qz = (R[0,2] + R[2,0]) / s
        elif R[1,1] > R[2,2]:
            s = 2. * np.sqrt(1. + R[1,1] - R[0,0] - R[2,2])
            qw = (R[0,2] - R[2,0]) / s
            qx = (R[0,1] + R[1,0]) / s
            qy = 0.25 * s
            qz = (R[1,2] + R[2,1]) / s
        else:
            s = 2. * np.sqrt(1. + R[2,2] - R[0,0] - R[1,1])
            qw = (R[1,0] - R[0,1]) / s
            qx = (R[0,2] + R[2,0]) / s
            qy = (R[1,2] + R[2,1]) / s
            qz = 0.25 * s

        return Quaternion(np.array((qw, qx, qy, qz)))

    # if angle is None, assume ||axis|| == angle, in radians
    # if angle is not None, assume that axis is a unit vector
    @staticmethod
    def FromAxisAngle(axis, angle=None):
        if angle is None:
            angle = np.linalg.norm(axis)
            if np.abs(angle) > np.finfo('float').eps:
                axis = axis / angle

        qw = np.cos(0.5 * angle)
        axis = axis * np.sin(0.5 * angle)

        return Quaternion(np.array((qw, axis[0], axis[1], axis[2])))

    #---------------------------------------------------------------------------

    def __init__(self, q=np.array((1., 0., 0., 0.))):
        if isinstance(q, Quaternion):
            self.q = q.q.copy()
        else:
            q = np.asarray(q)
            if q.size == 4:
                self.q = q.copy()
            elif q.size == 3: # convert from a 3-vector to a quaternion
                self.q = np.empty(4)
                self.q[0], self.q[1:] = 0., q.ravel()
            else:
                raise Exception('Input quaternion should be a 3- or 4-vector')

    def __add__(self, other):
        return Quaternion(self.q + other.q)

    def __iadd__(self, other):
        self.q += other.q
        return self

    # conjugation via the ~ operator
    def __invert__(self):
        return Quaternion(
            np.array((self.q[0], -self.q[1], -self.q[2], -self.q[3])))

    # returns: self.q * other.q if other is a Quaternion; otherwise performs
    # scalar multiplication
    def __mul__(self, other):
        if isinstance(other, Quaternion): # quaternion multiplication
            return Quaternion(np.array((
                self.q[0] * other.q[0] - self.q[1] * other.q[1] -
                self.q[2] * other.q[2] - self.q[3] * other.q[3],
                self.q[0] * other.q[1] + self.q[1] * other.q[0] +
                self.q[2] * other.q[3] - self.q[3] * other.q[2],
                self.q[0] * other.q[2] - self.q[1] * other.q[3] +
                self.q[2] * other.q[0] + self.q[3] * other.q[1],
                self.q[0] * other.q[3] + self.q[1] * other.q[2] -
                self.q[2] * other.q[1] + self.q[3] * other.q[0])))
        else: # scalar multiplication (assumed)
            return Quaternion(other * self.q)

    def __rmul__(self, other):
        return self * other

    def __imul__(self, other):
        self.q[:] = (self * other).q
        return self

    def __irmul__(self, other):
        self.q[:] = (self * other).q
        return self

    def __neg__(self):
        return Quaternion(-self.q)

    def __sub__(self, other):
        return Quaternion(self.q - other.q)

    def __isub__(self, other):
        self.q -= other.q
        return self

    def __str__(self):
        return str(self.q)

    def copy(self):
        return Quaternion(self)

    def dot(self, other):
        return self.q.dot(other.q)

    # assume the quaternion is nonzero!
    def inverse(self):
        return Quaternion((~self).q / self.q.dot(self.q))

    def norm(self):
        return np.linalg.norm(self.q)

    def normalize(self):
        self.q /= np.linalg.norm(self.q)
        return self

    # assume x is a Nx3 numpy array or a numpy 3-vector
    def rotate_points(self, x):
        x = np.atleast_2d(x)
        return x.dot(self.ToR().T)

    # convert to a rotation matrix
    def ToR(self):
        return np.eye(3) + 2 * np.array((
            (-self.q[2] * self.q[2] - self.q[3] * self.q[3],
              self.q[1] * self.q[2] - self.q[3] * self.q[0],
              self.q[1] * self.q[3] + self.q[2] * self.q[0]),
            ( self.q[1] * self.q[2] + self.q[3] * self.q[0],
             -self.q[1] * self.q[1] - self.q[3] * self.q[3],
              self.q[2] * self.q[3] - self.q[1] * self.q[0]),
            ( self.q[1] * self.q[3] - self.q[2] * self.q[0],
              self.q[2] * self.q[3] + self.q[1] * self.q[0],
             -self.q[1] * self.q[1] - self.q[2] * self.q[2])))

    # convert to axis-angle representation, with angle encoded by the length
    def ToAxisAngle(self):
        # recall that for axis-angle representation (a, angle), with "a" unit:
        # q = (cos(angle/2), a * sin(angle/2))
        # below, for readability, "theta" actually means half of the angle

        sin_sq_theta = self.q[1:].dot(self.q[1:])

        # if theta is non-zero, then we can compute a unique rotation
        if np.abs(sin_sq_theta) > np.finfo('float').eps:
            sin_theta = np.sqrt(sin_sq_theta)
            cos_theta = self.q[0]

            # atan2 is more stable, so we use it to compute theta
            # note that we multiply by 2 to get the actual angle
            angle = 2. * (
                np.arctan2(-sin_theta, -cos_theta) if cos_theta < 0. else
                np.arctan2(sin_theta, cos_theta))

            return self.q[1:] * (angle / sin_theta)

        # otherwise, the result is singular, and we avoid dividing by
        # sin(angle/2) = 0
        return np.zeros(3)

    # euclideanspace.com/maths/geometry/rotations/conversions/quaternionToEuler
    # this assumes the quaternion is non-zero
    # returns yaw, pitch, roll, with application in that order
    def ToEulerAngles(self):
        qsq = self.q**2
        k = 2. * (self.q[0] * self.q[3] + self.q[1] * self.q[2]) / qsq.sum()

        if (1. - k) < np.finfo('float').eps: # north pole singularity
            return 2. * np.arctan2(self.q[1], self.q[0]), 0.5 * np.pi, 0.
        if (1. + k) < np.finfo('float').eps: # south pole singularity
            return -2. * np.arctan2(self.q[1], self.q[0]), -0.5 * np.pi, 0.

        yaw = np.arctan2(2. * (self.q[0] * self.q[2] - self.q[1] * self.q[3]),
                         qsq[0] + qsq[1] - qsq[2] - qsq[3])
        pitch = np.arcsin(k)
        roll = np.arctan2(2. * (self.q[0] * self.q[1] - self.q[2] * self.q[3]),
                          qsq[0] - qsq[1] + qsq[2] - qsq[3])

        return yaw, pitch, roll

#-------------------------------------------------------------------------------
#
# DualQuaternion
#
#-------------------------------------------------------------------------------

class DualQuaternion:
    # DualQuaternion from an existing rotation + translation
    @staticmethod
    def FromQT(q, t):
        return DualQuaternion(qe=(0.5 * np.asarray(t))) * DualQuaternion(q)

    def __init__(self, q0=np.array((1., 0., 0., 0.)), qe=np.zeros(4)):
        self.q0, self.qe = Quaternion(q0), Quaternion(qe)

    def __add__(self, other):
        return DualQuaternion(self.q0 + other.q0, self.qe + other.qe)

    def __iadd__(self, other):
        self.q0 += other.q0
        self.qe += other.qe
        return self

    # conjugation via the ~ operator
    def __invert__(self):
        return DualQuaternion(~self.q0, ~self.qe)

    def __mul__(self, other):
        if isinstance(other, DualQuaternion):
            return DualQuaternion(
                self.q0 * other.q0,
                self.q0 * other.qe + self.qe * other.q0)
        elif isinstance(other, complex): # multiplication by a dual number
            return DualQuaternion(
                self.q0 * other.real,
                self.q0 * other.imag + self.qe * other.real)
        else: # scalar multiplication (assumed)
            return DualQuaternion(other * self.q0, other * self.qe)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __imul__(self, other):
        tmp = self * other
        self.q0, self.qe = tmp.q0, tmp.qe
        return self

    def __neg__(self):
        return DualQuaternion(-self.q0, -self.qe)

    def __sub__(self, other):
        return DualQuaternion(self.q0 - other.q0, self.qe - other.qe)

    def __isub__(self, other):
        self.q0 -= other.q0
        self.qe -= other.qe
        return self

    # q^-1 = q* / ||q||^2
    # assume that q0 is nonzero!
    def inverse(self):
        # (was `q0.dot(q0)`, which referenced an undefined name)
        normsq = complex(self.q0.dot(self.q0), 2. * self.q0.q.dot(self.qe.q))
        inv_len_real = 1. / normsq.real
        return ~self * complex(
            inv_len_real, -normsq.imag * inv_len_real * inv_len_real)

    # returns a complex representation of the real and imaginary parts of the norm
    # assume that q0 is nonzero!
    def norm(self):
        q0_norm = self.q0.norm()
        return complex(q0_norm, self.q0.dot(self.qe) / q0_norm)

    # assume that q0 is nonzero!
    def normalize(self):
        # current length is ||q0|| + eps * (<q0, qe> / ||q0||)
        # writing this as a + eps * b, the inverse is
        # 1/||q|| = 1/a - eps * b / a^2
        norm = self.norm()
        inv_len_real = 1. / norm.real
        self *= complex(inv_len_real, -norm.imag * inv_len_real * inv_len_real)
        return self

    # return the translation vector for this dual quaternion
    def getT(self):
        return 2 * (self.qe * ~self.q0).q[1:]

    def ToQT(self):
        return self.q0, self.getT()
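A quick consistency check of the conversions above (axis-angle -> rotation matrix -> quaternion -> axis-angle), under the same import assumption:

import numpy as np
from rotation import Quaternion, axis_angle_to_rotation_matrix, rotation_matrix_to_axis_angle

aa = np.array([0.1, -0.4, 0.25])               # ||aa|| encodes the rotation angle
R = axis_angle_to_rotation_matrix(aa)
assert np.allclose(R.dot(R.T), np.eye(3))      # R is orthonormal
assert np.isclose(np.linalg.det(R), 1.0)       # and a proper rotation
assert np.allclose(rotation_matrix_to_axis_angle(R), aa)

q = Quaternion.FromR(R)
assert np.allclose(q.ToR(), R)                 # quaternion round trip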
internal/pycolmap/pycolmap/scene_manager.py
ADDED
@@ -0,0 +1,670 @@
# Author: True Price <jtprice at cs.unc.edu>

import array
import numpy as np
import os
import struct

from collections import OrderedDict
from itertools import combinations

from camera import Camera
from image import Image
from rotation import Quaternion

#-------------------------------------------------------------------------------
#
# SceneManager
#
#-------------------------------------------------------------------------------

class SceneManager:
    INVALID_POINT3D = np.uint64(-1)

    def __init__(self, colmap_results_folder, image_path=None):
        self.folder = colmap_results_folder
        if not self.folder.endswith('/'):
            self.folder += '/'

        self.image_path = None
        self.load_colmap_project_file(image_path=image_path)

        self.cameras = OrderedDict()
        self.images = OrderedDict()
        self.name_to_image_id = dict()

        self.last_camera_id = 0
        self.last_image_id = 0

        # Nx3 array of point3D xyz's
        self.points3D = np.zeros((0, 3))

        # for each element in points3D, stores the id of the point
        self.point3D_ids = np.empty(0)

        # point3D_id => index in self.points3D
        self.point3D_id_to_point3D_idx = dict()

        # point3D_id => [(image_id, point2D idx in image)]
        self.point3D_id_to_images = dict()

        self.point3D_colors = np.zeros((0, 3), dtype=np.uint8)
        self.point3D_errors = np.zeros(0)

    #---------------------------------------------------------------------------

    def load_colmap_project_file(self, project_file=None, image_path=None):
        if project_file is None:
            project_file = self.folder + 'project.ini'

        self.image_path = image_path

        if self.image_path is None:
            try:
                with open(project_file, 'r') as f:
                    for line in iter(f.readline, ''):
                        if line.startswith('image_path'):
                            self.image_path = line[11:].strip()
                            break
            except:
                pass

        if self.image_path is None:
            print('Warning: image_path not found for reconstruction')
        elif not self.image_path.endswith('/'):
            self.image_path += '/'

    #---------------------------------------------------------------------------

    def load(self):
        self.load_cameras()
        self.load_images()
        self.load_points3D()

    #---------------------------------------------------------------------------

    def load_cameras(self, input_file=None):
        if input_file is None:
            input_file = self.folder + 'cameras.bin'
            if os.path.exists(input_file):
                self._load_cameras_bin(input_file)
            else:
                input_file = self.folder + 'cameras.txt'
                if os.path.exists(input_file):
                    self._load_cameras_txt(input_file)
                else:
                    raise IOError('no cameras file found')

    def _load_cameras_bin(self, input_file):
        self.cameras = OrderedDict()

        with open(input_file, 'rb') as f:
            num_cameras = struct.unpack('L', f.read(8))[0]

            for _ in range(num_cameras):
                camera_id, camera_type, w, h = struct.unpack('IiLL', f.read(24))
                num_params = Camera.GetNumParams(camera_type)
                params = struct.unpack('d' * num_params, f.read(8 * num_params))
                self.cameras[camera_id] = Camera(camera_type, w, h, params)
                self.last_camera_id = max(self.last_camera_id, camera_id)

    def _load_cameras_txt(self, input_file):
        self.cameras = OrderedDict()

        with open(input_file, 'r') as f:
            for line in iter(lambda: f.readline().strip(), ''):
                if not line or line.startswith('#'):
                    continue

                data = line.split()
                camera_id = int(data[0])
                self.cameras[camera_id] = Camera(
                    data[1], int(data[2]), int(data[3]), map(float, data[4:]))
                self.last_camera_id = max(self.last_camera_id, camera_id)

    #---------------------------------------------------------------------------

    def load_images(self, input_file=None):
        if input_file is None:
            input_file = self.folder + 'images.bin'
            if os.path.exists(input_file):
                self._load_images_bin(input_file)
            else:
                input_file = self.folder + 'images.txt'
                if os.path.exists(input_file):
                    self._load_images_txt(input_file)
                else:
                    raise IOError('no images file found')

    def _load_images_bin(self, input_file):
        self.images = OrderedDict()

        with open(input_file, 'rb') as f:
            num_images = struct.unpack('L', f.read(8))[0]
            image_struct = struct.Struct('<I 4d 3d I')
            for _ in range(num_images):
                data = image_struct.unpack(f.read(image_struct.size))
                image_id = data[0]
                q = Quaternion(np.array(data[1:5]))
                t = np.array(data[5:8])
                camera_id = data[8]
                name = b''.join(c for c in iter(lambda: f.read(1), b'\x00')).decode()

                image = Image(name, camera_id, q, t)
                num_points2D = struct.unpack('Q', f.read(8))[0]

                # Optimized code below.
                # Read all elements as double first, then convert to array, slice it
                # into points2d and ids, and convert ids back to unsigned long longs
                # ('Q'). This is significantly faster than using O(num_points2D) f.read
                # calls, experiments show >7x improvements in 60 image model, 23s -> 3s.
                points_array = array.array('d')
                points_array.fromfile(f, 3 * num_points2D)
                points_elements = np.array(points_array).reshape((num_points2D, 3))
                image.points2D = points_elements[:, :2]

                ids_array = array.array('Q')
                ids_array.frombytes(points_elements[:, 2].tobytes())
                image.point3D_ids = np.array(ids_array, dtype=np.uint64).reshape(
                    (num_points2D,))

                # automatically remove points without an associated 3D point
                #mask = (image.point3D_ids != SceneManager.INVALID_POINT3D)
                #image.points2D = image.points2D[mask]
                #image.point3D_ids = image.point3D_ids[mask]

                self.images[image_id] = image
                self.name_to_image_id[image.name] = image_id

                self.last_image_id = max(self.last_image_id, image_id)

    def _load_images_txt(self, input_file):
        self.images = OrderedDict()

        with open(input_file, 'r') as f:
            is_camera_description_line = False

            for line in iter(lambda: f.readline().strip(), ''):
                if not line or line.startswith('#'):
                    continue

                is_camera_description_line = not is_camera_description_line

                data = line.split()

                if is_camera_description_line:
                    image_id = int(data[0])
                    image = Image(data[-1], int(data[-2]),
                                  Quaternion(np.array(map(float, data[1:5]))),
                                  np.array(map(float, data[5:8])))
                else:
                    image.points2D = np.array(
                        [map(float, data[::3]), map(float, data[1::3])]).T
                    image.point3D_ids = np.array(map(np.uint64, data[2::3]))

                    # automatically remove points without an associated 3D point
                    #mask = (image.point3D_ids != SceneManager.INVALID_POINT3D)
                    #image.points2D = image.points2D[mask]
                    #image.point3D_ids = image.point3D_ids[mask]

                    self.images[image_id] = image
                    self.name_to_image_id[image.name] = image_id

                    self.last_image_id = max(self.last_image_id, image_id)

    #---------------------------------------------------------------------------

    def load_points3D(self, input_file=None):
        if input_file is None:
            input_file = self.folder + 'points3D.bin'
            if os.path.exists(input_file):
                self._load_points3D_bin(input_file)
            else:
                input_file = self.folder + 'points3D.txt'
                if os.path.exists(input_file):
                    self._load_points3D_txt(input_file)
                else:
                    raise IOError('no points3D file found')

    def _load_points3D_bin(self, input_file):
        with open(input_file, 'rb') as f:
            num_points3D = struct.unpack('L', f.read(8))[0]

            self.points3D = np.empty((num_points3D, 3))
            self.point3D_ids = np.empty(num_points3D, dtype=np.uint64)
            self.point3D_colors = np.empty((num_points3D, 3), dtype=np.uint8)
            self.point3D_id_to_point3D_idx = dict()
            self.point3D_id_to_images = dict()
            self.point3D_errors = np.empty(num_points3D)

            data_struct = struct.Struct('<Q 3d 3B d Q')

            for i in range(num_points3D):
                data = data_struct.unpack(f.read(data_struct.size))
                self.point3D_ids[i] = data[0]
                self.points3D[i] = data[1:4]
                self.point3D_colors[i] = data[4:7]
                self.point3D_errors[i] = data[7]
                track_len = data[8]

                self.point3D_id_to_point3D_idx[self.point3D_ids[i]] = i

                data = struct.unpack(f'{2*track_len}I', f.read(2 * track_len * 4))

                self.point3D_id_to_images[self.point3D_ids[i]] = \
                    np.array(data, dtype=np.uint32).reshape(track_len, 2)

    def _load_points3D_txt(self, input_file):
        self.points3D = []
        self.point3D_ids = []
        self.point3D_colors = []
        self.point3D_id_to_point3D_idx = dict()
        self.point3D_id_to_images = dict()
        self.point3D_errors = []

        with open(input_file, 'r') as f:
            for line in iter(lambda: f.readline().strip(), ''):
                if not line or line.startswith('#'):
                    continue

                data = line.split()
                point3D_id = np.uint64(data[0])

                self.point3D_ids.append(point3D_id)
                self.point3D_id_to_point3D_idx[point3D_id] = len(self.points3D)
                self.points3D.append(map(np.float64, data[1:4]))
                self.point3D_colors.append(map(np.uint8, data[4:7]))
                self.point3D_errors.append(np.float64(data[7]))

                # load (image id, point2D idx) pairs
                self.point3D_id_to_images[point3D_id] = \
                    np.array(map(np.uint32, data[8:])).reshape(-1, 2)

        self.points3D = np.array(self.points3D)
        self.point3D_ids = np.array(self.point3D_ids)
        self.point3D_colors = np.array(self.point3D_colors)
        self.point3D_errors = np.array(self.point3D_errors)

    #---------------------------------------------------------------------------

    def save(self, output_folder, binary=True):
        self.save_cameras(output_folder, binary=binary)
        self.save_images(output_folder, binary=binary)
        self.save_points3D(output_folder, binary=binary)

    #---------------------------------------------------------------------------

    def save_cameras(self, output_folder, output_file=None, binary=True):
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)

        if output_file is None:
            output_file = 'cameras.bin' if binary else 'cameras.txt'

        output_file = os.path.join(output_folder, output_file)

        if binary:
            self._save_cameras_bin(output_file)
        else:
            self._save_cameras_txt(output_file)

    # (Python 2 iteritems()/print>> calls in the save helpers below are written
    # here in Python 3 form.)
    def _save_cameras_bin(self, output_file):
        with open(output_file, 'wb') as fid:
            fid.write(struct.pack('L', len(self.cameras)))

            camera_struct = struct.Struct('IiLL')

            for camera_id, camera in sorted(self.cameras.items()):
                fid.write(camera_struct.pack(
                    camera_id, camera.camera_type, camera.width, camera.height))
                # TODO (True): should move this into the Camera class
                fid.write(camera.get_params().tobytes())

    def _save_cameras_txt(self, output_file):
        with open(output_file, 'w') as fid:
            print('# Camera list with one line of data per camera:', file=fid)
            print('# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]', file=fid)
            print('# Number of cameras:', len(self.cameras), file=fid)

            for camera_id, camera in sorted(self.cameras.items()):
                print(camera_id, camera, file=fid)

    #---------------------------------------------------------------------------

    def save_images(self, output_folder, output_file=None, binary=True):
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)

        if output_file is None:
            output_file = 'images.bin' if binary else 'images.txt'

        output_file = os.path.join(output_folder, output_file)

        if binary:
            self._save_images_bin(output_file)
        else:
            self._save_images_txt(output_file)

    def _save_images_bin(self, output_file):
        with open(output_file, 'wb') as fid:
            fid.write(struct.pack('L', len(self.images)))

            for image_id, image in self.images.items():
                fid.write(struct.pack('I', image_id))
                fid.write(image.q.q.tobytes())
                fid.write(image.tvec.tobytes())
                fid.write(struct.pack('I', image.camera_id))
                fid.write(image.name.encode() + b'\x00')
                fid.write(struct.pack('L', len(image.points2D)))
                data = np.rec.fromarrays(
                    (image.points2D[:,0], image.points2D[:,1], image.point3D_ids))
                fid.write(data.tobytes())

    def _save_images_txt(self, output_file):
        with open(output_file, 'w') as fid:
            print('# Image list with two lines of data per image:', file=fid)
            print('# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME', file=fid)
            print('# POINTS2D[] as (X, Y, POINT3D_ID)', file=fid)
            print('# Number of images: {},'.format(len(self.images)), end=' ', file=fid)
            print('mean observations per image: unknown', file=fid)

            for image_id, image in self.images.items():
                print(image_id, end=' ', file=fid)
                print(' '.join(str(qi) for qi in image.q.q), end=' ', file=fid)
                print(' '.join(str(ti) for ti in image.tvec), end=' ', file=fid)
                print(image.camera_id, image.name, file=fid)

                data = np.rec.fromarrays(
                    (image.points2D[:,0], image.points2D[:,1],
                     image.point3D_ids.astype(np.int64)))
                if len(data) > 0:
                    np.savetxt(fid, data, '%.2f %.2f %d', newline=' ')
                    fid.seek(-1, os.SEEK_CUR)
                fid.write('\n')

    #---------------------------------------------------------------------------
|
386 |
+
|
387 |
+
def save_points3D(self, output_folder, output_file=None, binary=True):
|
388 |
+
if not os.path.exists(output_folder):
|
389 |
+
os.makedirs(output_folder)
|
390 |
+
|
391 |
+
if output_file is None:
|
392 |
+
output_file = 'points3D.bin' if binary else 'points3D.txt'
|
393 |
+
|
394 |
+
output_file = os.path.join(output_folder, output_file)
|
395 |
+
|
396 |
+
if binary:
|
397 |
+
self._save_points3D_bin(output_file)
|
398 |
+
else:
|
399 |
+
self._save_points3D_txt(output_file)
|
400 |
+
|
401 |
+
def _save_points3D_bin(self, output_file):
|
402 |
+
num_valid_points3D = sum(
|
403 |
+
1 for point3D_idx in self.point3D_id_to_point3D_idx.itervalues()
|
404 |
+
if point3D_idx != SceneManager.INVALID_POINT3D)
|
405 |
+
|
406 |
+
iter_point3D_id_to_point3D_idx = \
|
407 |
+
self.point3D_id_to_point3D_idx.iteritems()
|
408 |
+
|
409 |
+
with open(output_file, 'wb') as fid:
|
410 |
+
fid.write(struct.pack('L', num_valid_points3D))
|
411 |
+
|
412 |
+
for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx:
|
413 |
+
if point3D_idx == SceneManager.INVALID_POINT3D:
|
414 |
+
continue
|
415 |
+
|
416 |
+
fid.write(struct.pack('L', point3D_id))
|
417 |
+
fid.write(self.points3D[point3D_idx].tobytes())
|
418 |
+
fid.write(self.point3D_colors[point3D_idx].tobytes())
|
419 |
+
fid.write(self.point3D_errors[point3D_idx].tobytes())
|
420 |
+
fid.write(
|
421 |
+
struct.pack('L', len(self.point3D_id_to_images[point3D_id])))
|
422 |
+
fid.write(self.point3D_id_to_images[point3D_id].tobytes())
|
423 |
+
|
424 |
+
def _save_points3D_txt(self, output_file):
|
425 |
+
num_valid_points3D = sum(
|
426 |
+
1 for point3D_idx in self.point3D_id_to_point3D_idx.itervalues()
|
427 |
+
if point3D_idx != SceneManager.INVALID_POINT3D)
|
428 |
+
|
429 |
+
array_to_string = lambda arr: ' '.join(str(x) for x in arr)
|
430 |
+
|
431 |
+
iter_point3D_id_to_point3D_idx = \
|
432 |
+
self.point3D_id_to_point3D_idx.iteritems()
|
433 |
+
|
434 |
+
with open(output_file, 'w') as fid:
|
435 |
+
print>>fid, '# 3D point list with one line of data per point:'
|
436 |
+
print>>fid, '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as ',
|
437 |
+
print>>fid, '(IMAGE_ID, POINT2D_IDX)'
|
438 |
+
print>>fid, '# Number of points: {},'.format(num_valid_points3D),
|
439 |
+
print>>fid, 'mean track length: unknown'
|
440 |
+
|
441 |
+
for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx:
|
442 |
+
if point3D_idx == SceneManager.INVALID_POINT3D:
|
443 |
+
continue
|
444 |
+
|
445 |
+
print>>fid, point3D_id,
|
446 |
+
print>>fid, array_to_string(self.points3D[point3D_idx]),
|
447 |
+
print>>fid, array_to_string(self.point3D_colors[point3D_idx]),
|
448 |
+
print>>fid, self.point3D_errors[point3D_idx],
|
449 |
+
print>>fid, array_to_string(
|
450 |
+
self.point3D_id_to_images[point3D_id].flat)
|
451 |
+
|
452 |
+
#---------------------------------------------------------------------------
|
453 |
+
|
454 |
+
# return the image id and Image object associated with a given image filename
|
455 |
+
def get_image_from_name(self, image_name):
|
456 |
+
image_id = self.name_to_image_id[image_name]
|
457 |
+
return image_id, self.images[image_id]
|
458 |
+
|
459 |
+
#---------------------------------------------------------------------------
|
460 |
+
|
461 |
+
def get_camera(self, camera_id):
|
462 |
+
return self.cameras[camera_id]
|
463 |
+
|
464 |
+
#---------------------------------------------------------------------------
|
465 |
+
|
466 |
+
def get_points3D(self, image_id, return_points2D=True, return_colors=False):
|
467 |
+
image = self.images[image_id]
|
468 |
+
|
469 |
+
mask = (image.point3D_ids != SceneManager.INVALID_POINT3D)
|
470 |
+
|
471 |
+
point3D_idxs = np.array([
|
472 |
+
self.point3D_id_to_point3D_idx[point3D_id]
|
473 |
+
for point3D_id in image.point3D_ids[mask]])
|
474 |
+
# detect filtered points
|
475 |
+
filter_mask = (point3D_idxs != SceneManager.INVALID_POINT3D)
|
476 |
+
point3D_idxs = point3D_idxs[filter_mask]
|
477 |
+
result = [self.points3D[point3D_idxs,:]]
|
478 |
+
|
479 |
+
if return_points2D:
|
480 |
+
mask[mask] &= filter_mask
|
481 |
+
result += [image.points2D[mask]]
|
482 |
+
if return_colors:
|
483 |
+
result += [self.point3D_colors[point3D_idxs,:]]
|
484 |
+
|
485 |
+
return result if len(result) > 1 else result[0]
|
486 |
+
|
487 |
+
#---------------------------------------------------------------------------
|
488 |
+
|
489 |
+
def point3D_valid(self, point3D_id):
|
490 |
+
return (self.point3D_id_to_point3D_idx[point3D_id] !=
|
491 |
+
SceneManager.INVALID_POINT3D)
|
492 |
+
|
493 |
+
#---------------------------------------------------------------------------
|
494 |
+
|
495 |
+
def get_filtered_points3D(self, return_colors=False):
|
496 |
+
point3D_idxs = [
|
497 |
+
idx for idx in self.point3D_id_to_point3D_idx.values()
|
498 |
+
if idx != SceneManager.INVALID_POINT3D]
|
499 |
+
result = [self.points3D[point3D_idxs,:]]
|
500 |
+
|
501 |
+
if return_colors:
|
502 |
+
result += [self.point3D_colors[point3D_idxs,:]]
|
503 |
+
|
504 |
+
return result if len(result) > 1 else result[0]
|
505 |
+
|
506 |
+
#---------------------------------------------------------------------------
|
507 |
+
|
508 |
+
# return 3D points shared by two images
|
509 |
+
def get_shared_points3D(self, image_id1, image_id2):
|
510 |
+
point3D_ids = (
|
511 |
+
set(self.images[image_id1].point3D_ids) &
|
512 |
+
set(self.images[image_id2].point3D_ids))
|
513 |
+
point3D_ids.discard(SceneManager.INVALID_POINT3D)
|
514 |
+
|
515 |
+
point3D_idxs = np.array([self.point3D_id_to_point3D_idx[point3D_id]
|
516 |
+
for point3D_id in point3D_ids])
|
517 |
+
|
518 |
+
return self.points3D[point3D_idxs,:]
|
519 |
+
|
520 |
+
#---------------------------------------------------------------------------
|
521 |
+
|
522 |
+
# project *all* 3D points into image, return their projection coordinates,
|
523 |
+
# as well as their 3D positions
|
524 |
+
def get_viewed_points(self, image_id):
|
525 |
+
image = self.images[image_id]
|
526 |
+
|
527 |
+
# get unfiltered points
|
528 |
+
point3D_idxs = set(self.point3D_id_to_point3D_idx.itervalues())
|
529 |
+
point3D_idxs.discard(SceneManager.INVALID_POINT3D)
|
530 |
+
point3D_idxs = list(point3D_idxs)
|
531 |
+
points3D = self.points3D[point3D_idxs,:]
|
532 |
+
|
533 |
+
# orient points relative to camera
|
534 |
+
R = image.q.ToR()
|
535 |
+
points3D = points3D.dot(R.T) + image.tvec[np.newaxis,:]
|
536 |
+
points3D = points3D[points3D[:,2] > 0,:] # keep points with positive z
|
537 |
+
|
538 |
+
# put points into image coordinates
|
539 |
+
camera = self.cameras[image.camera_id]
|
540 |
+
points2D = points3D.dot(camera.get_camera_matrix().T)
|
541 |
+
points2D = points2D[:,:2] / points2D[:,2][:,np.newaxis]
|
542 |
+
|
543 |
+
# keep points that are within the image
|
544 |
+
mask = (
|
545 |
+
(points2D[:,0] >= 0) &
|
546 |
+
(points2D[:,1] >= 0) &
|
547 |
+
(points2D[:,0] < camera.width - 1) &
|
548 |
+
(points2D[:,1] < camera.height - 1))
|
549 |
+
|
550 |
+
return points2D[mask,:], points3D[mask,:]
|
551 |
+
|
552 |
+
#---------------------------------------------------------------------------
|
553 |
+
|
554 |
+
def add_camera(self, camera):
|
555 |
+
self.last_camera_id += 1
|
556 |
+
self.cameras[self.last_camera_id] = camera
|
557 |
+
return self.last_camera_id
|
558 |
+
|
559 |
+
#---------------------------------------------------------------------------
|
560 |
+
|
561 |
+
def add_image(self, image):
|
562 |
+
self.last_image_id += 1
|
563 |
+
self.images[self.last_image_id] = image
|
564 |
+
return self.last_image_id
|
565 |
+
|
566 |
+
#---------------------------------------------------------------------------
|
567 |
+
|
568 |
+
def delete_images(self, image_list):
|
569 |
+
# delete specified images
|
570 |
+
for image_id in image_list:
|
571 |
+
if image_id in self.images:
|
572 |
+
del self.images[image_id]
|
573 |
+
|
574 |
+
keep_set = set(self.images.iterkeys())
|
575 |
+
|
576 |
+
# delete references to specified images, and ignore any points that are
|
577 |
+
# invalidated
|
578 |
+
iter_point3D_id_to_point3D_idx = \
|
579 |
+
self.point3D_id_to_point3D_idx.iteritems()
|
580 |
+
|
581 |
+
for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx:
|
582 |
+
if point3D_idx == SceneManager.INVALID_POINT3D:
|
583 |
+
continue
|
584 |
+
|
585 |
+
mask = np.array([
|
586 |
+
image_id in keep_set
|
587 |
+
for image_id in self.point3D_id_to_images[point3D_id][:,0]])
|
588 |
+
if np.any(mask):
|
589 |
+
self.point3D_id_to_images[point3D_id] = \
|
590 |
+
self.point3D_id_to_images[point3D_id][mask]
|
591 |
+
else:
|
592 |
+
self.point3D_id_to_point3D_idx[point3D_id] = \
|
593 |
+
SceneManager.INVALID_POINT3D
|
594 |
+
|
595 |
+
#---------------------------------------------------------------------------
|
596 |
+
|
597 |
+
# image_set: set of images whose points we'd like to keep
|
598 |
+
# min/max triangulation angle: in degrees
|
599 |
+
def filter_points3D(self,
|
600 |
+
min_track_len=0, max_error=np.inf, min_tri_angle=0,
|
601 |
+
max_tri_angle=180, image_set=set()):
|
602 |
+
|
603 |
+
image_set = set(image_set)
|
604 |
+
|
605 |
+
check_triangulation_angles = (min_tri_angle > 0 or max_tri_angle < 180)
|
606 |
+
if check_triangulation_angles:
|
607 |
+
max_tri_prod = np.cos(np.radians(min_tri_angle))
|
608 |
+
min_tri_prod = np.cos(np.radians(max_tri_angle))
|
609 |
+
|
610 |
+
iter_point3D_id_to_point3D_idx = \
|
611 |
+
self.point3D_id_to_point3D_idx.iteritems()
|
612 |
+
|
613 |
+
image_ids = []
|
614 |
+
|
615 |
+
for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx:
|
616 |
+
if point3D_idx == SceneManager.INVALID_POINT3D:
|
617 |
+
continue
|
618 |
+
|
619 |
+
if image_set or min_track_len > 0:
|
620 |
+
image_ids = set(self.point3D_id_to_images[point3D_id][:,0])
|
621 |
+
|
622 |
+
# check if error and min track length are sufficient, or if none of
|
623 |
+
# the selected cameras see the point
|
624 |
+
if (len(image_ids) < min_track_len or
|
625 |
+
self.point3D_errors[point3D_idx] > max_error or
|
626 |
+
image_set and image_set.isdisjoint(image_ids)):
|
627 |
+
self.point3D_id_to_point3D_idx[point3D_id] = \
|
628 |
+
SceneManager.INVALID_POINT3D
|
629 |
+
|
630 |
+
# find dot product between all camera viewing rays
|
631 |
+
elif check_triangulation_angles:
|
632 |
+
xyz = self.points3D[point3D_idx,:]
|
633 |
+
tvecs = np.array(
|
634 |
+
[(self.images[image_id].tvec - xyz)
|
635 |
+
for image_id in image_ids])
|
636 |
+
tvecs /= np.linalg.norm(tvecs, axis=-1)[:,np.newaxis]
|
637 |
+
|
638 |
+
cos_theta = np.array(
|
639 |
+
[u.dot(v) for u,v in combinations(tvecs, 2)])
|
640 |
+
|
641 |
+
# min_prod = cos(maximum viewing angle), and vice versa
|
642 |
+
# if maximum viewing angle is too small or too large,
|
643 |
+
# don't add this point
|
644 |
+
if (np.min(cos_theta) > max_tri_prod or
|
645 |
+
np.max(cos_theta) < min_tri_prod):
|
646 |
+
self.point3D_id_to_point3D_idx[point3D_id] = \
|
647 |
+
SceneManager.INVALID_POINT3D
|
648 |
+
|
649 |
+
# apply the filters to the image point3D_ids
|
650 |
+
for image in self.images.itervalues():
|
651 |
+
mask = np.array([
|
652 |
+
self.point3D_id_to_point3D_idx.get(point3D_id, 0) \
|
653 |
+
== SceneManager.INVALID_POINT3D
|
654 |
+
for point3D_id in image.point3D_ids])
|
655 |
+
image.point3D_ids[mask] = SceneManager.INVALID_POINT3D
|
656 |
+
|
657 |
+
#---------------------------------------------------------------------------
|
658 |
+
|
659 |
+
# scene graph: {image_id: {image_id: number of shared points}}
|
660 |
+
def build_scene_graph(self):
|
661 |
+
self.scene_graph = defaultdict(lambda: defaultdict(int))
|
662 |
+
point3D_iter = self.point3D_id_to_images.iteritems()
|
663 |
+
|
664 |
+
for i, (point3D_id, images) in enumerate(point3D_iter):
|
665 |
+
if not self.point3D_valid(point3D_id):
|
666 |
+
continue
|
667 |
+
|
668 |
+
for image_id1, image_id2 in combinations(images[:,0], 2):
|
669 |
+
self.scene_graph[image_id1][image_id2] += 1
|
670 |
+
self.scene_graph[image_id2][image_id1] += 1
|
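The SceneManager class above is the read/write interface used throughout this repo for COLMAP sparse models. A minimal usage sketch follows; the model path is an example only, and the vendored code is written for Python 2 (iteritems, print >> statements), so it assumes a Python 2 interpreter or a ported copy.

from pycolmap import SceneManager

# Folder should end with '/' because file paths are built by string concatenation above.
sm = SceneManager('sparse/0/')
sm.load()

# Drop points seen by fewer than 3 images, with reprojection error above 2 px,
# or whose maximum triangulation angle is below 1.5 degrees.
sm.filter_points3D(min_track_len=3, max_error=2.0, min_tri_angle=1.5)

# Covisibility graph: scene_graph[i][j] = number of 3D points shared by images i and j.
sm.build_scene_graph()

xyz, rgb = sm.get_filtered_points3D(return_colors=True)
print(xyz.shape, rgb.shape)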
internal/pycolmap/tools/colmap_to_nvm.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
import itertools
|
2 |
+
import sys
|
3 |
+
sys.path.append("..")
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from pycolmap import Quaternion, SceneManager
|
8 |
+
|
9 |
+
|
10 |
+
#-------------------------------------------------------------------------------
|
11 |
+
|
12 |
+
def main(args):
|
13 |
+
scene_manager = SceneManager(args.input_folder)
|
14 |
+
scene_manager.load()
|
15 |
+
|
16 |
+
with open(args.output_file, "w") as fid:
|
17 |
+
fid.write("NVM_V3\n \n{:d}\n".format(len(scene_manager.images)))
|
18 |
+
|
19 |
+
image_fmt_str = " {:.3f} " + 7 * "{:.7f} "
|
20 |
+
for image_id, image in scene_manager.images.iteritems():
|
21 |
+
camera = scene_manager.cameras[image.camera_id]
|
22 |
+
f = 0.5 * (camera.fx + camera.fy)
|
23 |
+
fid.write(args.image_name_prefix + image.name)
|
24 |
+
fid.write(image_fmt_str.format(
|
25 |
+
*((f,) + tuple(image.q.q) + tuple(image.C()))))
|
26 |
+
if camera.distortion_func is None:
|
27 |
+
fid.write("0 0\n")
|
28 |
+
else:
|
29 |
+
fid.write("{:.7f} 0\n".format(-camera.k1))
|
30 |
+
|
31 |
+
image_id_to_idx = dict(
|
32 |
+
(image_id, i) for i, image_id in enumerate(scene_manager.images))
|
33 |
+
|
34 |
+
fid.write("{:d}\n".format(len(scene_manager.points3D)))
|
35 |
+
for i, point3D_id in enumerate(scene_manager.point3D_ids):
|
36 |
+
fid.write(
|
37 |
+
"{:.7f} {:.7f} {:.7f} ".format(*scene_manager.points3D[i]))
|
38 |
+
fid.write(
|
39 |
+
"{:d} {:d} {:d} ".format(*scene_manager.point3D_colors[i]))
|
40 |
+
keypoints = [
|
41 |
+
(image_id_to_idx[image_id], kp_idx) +
|
42 |
+
tuple(scene_manager.images[image_id].points2D[kp_idx])
|
43 |
+
for image_id, kp_idx in
|
44 |
+
scene_manager.point3D_id_to_images[point3D_id]]
|
45 |
+
fid.write("{:d}".format(len(keypoints)))
|
46 |
+
fid.write(
|
47 |
+
(len(keypoints) * " {:d} {:d} {:.3f} {:.3f}" + "\n").format(
|
48 |
+
*itertools.chain(*keypoints)))
|
49 |
+
|
50 |
+
|
51 |
+
#-------------------------------------------------------------------------------
|
52 |
+
|
53 |
+
if __name__ == "__main__":
|
54 |
+
import argparse
|
55 |
+
|
56 |
+
parser = argparse.ArgumentParser(
|
57 |
+
description="Save a COLMAP reconstruction in the NVM format "
|
58 |
+
"(http://ccwu.me/vsfm/doc.html#nvm).",
|
59 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
60 |
+
|
61 |
+
parser.add_argument("input_folder")
|
62 |
+
parser.add_argument("output_file")
|
63 |
+
|
64 |
+
parser.add_argument("--image_name_prefix", type=str, default="",
|
65 |
+
help="prefix image names with this string (e.g., 'images/')")
|
66 |
+
|
67 |
+
args = parser.parse_args()
|
68 |
+
|
69 |
+
main(args)
|
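The exporter above is typically invoked as `python colmap_to_nvm.py <sparse_folder> <output.nvm>`. As a sketch of a quick sanity check on its output (the file name is an example only), the NVM header it writes can be read back directly:

with open('model.nvm') as f:
    assert f.readline().strip() == 'NVM_V3'
    f.readline()                      # the single-space line written above
    num_cameras = int(f.readline())   # one camera line per image follows
    print('cameras in NVM file:', num_cameras)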
internal/pycolmap/tools/delete_images.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from pycolmap import DualQuaternion, Image, SceneManager
|
7 |
+
|
8 |
+
|
9 |
+
#-------------------------------------------------------------------------------
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
scene_manager = SceneManager(args.input_folder)
|
13 |
+
scene_manager.load()
|
14 |
+
|
15 |
+
image_ids = map(scene_manager.get_image_from_name,
|
16 |
+
iter(lambda: sys.stdin.readline().strip(), ""))
|
17 |
+
scene_manager.delete_images(image_ids)
|
18 |
+
|
19 |
+
scene_manager.save(args.output_folder)
|
20 |
+
|
21 |
+
|
22 |
+
#-------------------------------------------------------------------------------
|
23 |
+
|
24 |
+
if __name__ == "__main__":
|
25 |
+
import argparse
|
26 |
+
|
27 |
+
parser = argparse.ArgumentParser(
|
28 |
+
description="Deletes images (filenames read from stdin) from a model.",
|
29 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
30 |
+
|
31 |
+
parser.add_argument("input_folder")
|
32 |
+
parser.add_argument("output_folder")
|
33 |
+
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
main(args)
|
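The same deletion can be scripted directly against the SceneManager API; a sketch with hypothetical file names (note that get_image_from_name returns an (image_id, Image) tuple, so only the id should be handed to delete_images):

from pycolmap import SceneManager

sm = SceneManager('sparse/0/')
sm.load()

bad_names = ['0001.png', '0002.png']                               # hypothetical image names
bad_ids = [sm.get_image_from_name(name)[0] for name in bad_names]
sm.delete_images(bad_ids)          # also invalidates points that lose all observations
sm.save('sparse_cleaned/', binary=True)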
internal/pycolmap/tools/impute_missing_cameras.py
ADDED
@@ -0,0 +1,180 @@
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from pycolmap import DualQuaternion, Image, SceneManager
|
7 |
+
|
8 |
+
|
9 |
+
#-------------------------------------------------------------------------------
|
10 |
+
|
11 |
+
image_to_idx = lambda im: int(im.name[:im.name.rfind(".")])
|
12 |
+
|
13 |
+
|
14 |
+
#-------------------------------------------------------------------------------
|
15 |
+
|
16 |
+
def interpolate_linear(images, camera_id, file_format):
|
17 |
+
if len(images) < 2:
|
18 |
+
raise ValueError("Need at least two images for linear interpolation!")
|
19 |
+
|
20 |
+
prev_image = images[0]
|
21 |
+
prev_idx = image_to_idx(prev_image)
|
22 |
+
prev_dq = DualQuaternion.FromQT(prev_image.q, prev_image.t)
|
23 |
+
start = prev_idx
|
24 |
+
|
25 |
+
new_images = []
|
26 |
+
|
27 |
+
for image in images[1:]:
|
28 |
+
curr_idx = image_to_idx(image)
|
29 |
+
curr_dq = DualQuaternion.FromQT(image.q, image.t)
|
30 |
+
T = curr_idx - prev_idx
|
31 |
+
Tinv = 1. / T
|
32 |
+
|
33 |
+
# like quaternions, dq and -dq represent the same transform, so we'll need to pick the one more
|
34 |
+
# appropriate for interpolation by taking -dq if the dot product of the
|
35 |
+
# two q-vectors is negative
|
36 |
+
if prev_dq.q0.dot(curr_dq.q0) < 0:
|
37 |
+
curr_dq = -curr_dq
|
38 |
+
|
39 |
+
for i in xrange(1, T):
|
40 |
+
t = i * Tinv
|
41 |
+
dq = (1. - t) * prev_dq + t * curr_dq  # weight prev_dq more when closer to prev_idx
|
42 |
+
q, t = dq.ToQT()
|
43 |
+
new_images.append(
|
44 |
+
Image(file_format.format(prev_idx + i), args.camera_id, q, t))
|
45 |
+
|
46 |
+
prev_idx = curr_idx
|
47 |
+
prev_dq = curr_dq
|
48 |
+
|
49 |
+
return new_images
|
50 |
+
|
51 |
+
|
52 |
+
#-------------------------------------------------------------------------------
|
53 |
+
|
54 |
+
def interpolate_hermite(images, camera_id, file_format):
|
55 |
+
if len(images) < 4:
|
56 |
+
raise ValueError(
|
57 |
+
"Need at least four images for Hermite spline interpolation!")
|
58 |
+
|
59 |
+
new_images = []
|
60 |
+
|
61 |
+
# linear blending for the first frames
|
62 |
+
T0 = image_to_idx(images[0])
|
63 |
+
dq0 = DualQuaternion.FromQT(images[0].q, images[0].t)
|
64 |
+
T1 = image_to_idx(images[1])
|
65 |
+
dq1 = DualQuaternion.FromQT(images[1].q, images[1].t)
|
66 |
+
|
67 |
+
if dq0.q0.dot(dq1.q0) < 0:
|
68 |
+
dq1 = -dq1
|
69 |
+
dT = 1. / float(T1 - T0)
|
70 |
+
for j in xrange(1, T1 - T0):
|
71 |
+
t = j * dT
|
72 |
+
dq = ((1. - t) * dq0 + t * dq1).normalize()
|
73 |
+
new_images.append(
|
74 |
+
Image(file_format.format(T0 + j), camera_id, *dq.ToQT()))
|
75 |
+
|
76 |
+
T2 = image_to_idx(images[2])
|
77 |
+
dq2 = DualQuaternion.FromQT(images[2].q, images[2].t)
|
78 |
+
if dq1.q0.dot(dq2.q0) < 0:
|
79 |
+
dq2 = -dq2
|
80 |
+
|
81 |
+
# Hermite spline interpolation of dual quaternions
|
82 |
+
# pdfs.semanticscholar.org/05b1/8ede7f46c29c2722fed3376d277a1d286c55.pdf
|
83 |
+
for i in xrange(1, len(images) - 2):
|
84 |
+
T3 = image_to_idx(images[i + 2])
|
85 |
+
dq3 = DualQuaternion.FromQT(images[i + 2].q, images[i + 2].t)
|
86 |
+
if dq2.q0.dot(dq3.q0) < 0:
|
87 |
+
dq3 = -dq3
|
88 |
+
|
89 |
+
prev_duration = T1 - T0
|
90 |
+
current_duration = T2 - T1
|
91 |
+
next_duration = T3 - T2
|
92 |
+
|
93 |
+
# approximate the derivatives at dq1 and dq2 using weighted central
|
94 |
+
# differences
|
95 |
+
dt1 = 1. / float(T2 - T0)
|
96 |
+
dt2 = 1. / float(T3 - T1)
|
97 |
+
|
98 |
+
m1 = (current_duration * dt1) * (dq2 - dq1) + \
|
99 |
+
(prev_duration * dt1) * (dq1 - dq0)
|
100 |
+
m2 = (next_duration * dt2) * (dq3 - dq2) + \
|
101 |
+
(current_duration * dt2) * (dq2 - dq1)
|
102 |
+
|
103 |
+
dT = 1. / float(current_duration)
|
104 |
+
|
105 |
+
for j in xrange(1, current_duration):
|
106 |
+
t = j * dT # 0 to 1
|
107 |
+
t2 = t * t # t squared
|
108 |
+
t3 = t2 * t # t cubed
|
109 |
+
|
110 |
+
# coefficients of the Hermite spline (a=>dq and b=>m)
|
111 |
+
a1 = 2. * t3 - 3. * t2 + 1.
|
112 |
+
b1 = t3 - 2. * t2 + t
|
113 |
+
a2 = -2. * t3 + 3. * t2
|
114 |
+
b2 = t3 - t2
|
115 |
+
|
116 |
+
dq = (a1 * dq1 + b1 * m1 + a2 * dq2 + b2 * m2).normalize()
|
117 |
+
|
118 |
+
new_images.append(
|
119 |
+
Image(file_format.format(T1 + j), camera_id, *dq.ToQT()))
|
120 |
+
|
121 |
+
T0, T1, T2 = T1, T2, T3
|
122 |
+
dq0, dq1, dq2 = dq1, dq2, dq3
|
123 |
+
|
124 |
+
# linear blending for the last frames
|
125 |
+
dT = 1. / float(T2 - T1)
|
126 |
+
for j in xrange(1, T2 - T1):
|
127 |
+
t = j * dT # 0 to 1
|
128 |
+
dq = ((1. - t) * dq1 + t * dq2).normalize()
|
129 |
+
new_images.append(
|
130 |
+
Image(file_format.format(T1 + j), camera_id, *dq.ToQT()))
|
131 |
+
|
132 |
+
return new_images
|
133 |
+
|
134 |
+
|
135 |
+
#-------------------------------------------------------------------------------
|
136 |
+
|
137 |
+
def main(args):
|
138 |
+
scene_manager = SceneManager(args.input_folder)
|
139 |
+
scene_manager.load()
|
140 |
+
|
141 |
+
images = sorted(scene_manager.images.itervalues(), key=image_to_idx)
|
142 |
+
|
143 |
+
if args.method.lower() == "linear":
|
144 |
+
new_images = interpolate_linear(images, args.camera_id, args.format)
|
145 |
+
else:
|
146 |
+
new_images = interpolate_hermite(images, args.camera_id, args.format)
|
147 |
+
|
148 |
+
map(scene_manager.add_image, new_images)
|
149 |
+
|
150 |
+
scene_manager.save(args.output_folder)
|
151 |
+
|
152 |
+
|
153 |
+
#-------------------------------------------------------------------------------
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
import argparse
|
157 |
+
|
158 |
+
parser = argparse.ArgumentParser(
|
159 |
+
description="Given a reconstruction with ordered images *with integer "
|
160 |
+
"filenames* like '000100.png', fill in missing camera positions for "
|
161 |
+
"intermediate frames.",
|
162 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
163 |
+
|
164 |
+
parser.add_argument("input_folder")
|
165 |
+
parser.add_argument("output_folder")
|
166 |
+
|
167 |
+
parser.add_argument("--camera_id", type=int, default=1,
|
168 |
+
help="camera id to use for the missing images")
|
169 |
+
|
170 |
+
parser.add_argument("--format", type=str, default="{:06d}.png",
|
171 |
+
help="filename format to use for added images")
|
172 |
+
|
173 |
+
parser.add_argument(
|
174 |
+
"--method", type=str.lower, choices=("linear", "hermite"),
|
175 |
+
default="hermite",
|
176 |
+
help="Pose imputation method")
|
177 |
+
|
178 |
+
args = parser.parse_args()
|
179 |
+
|
180 |
+
main(args)
|
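The coefficients a1, b1, a2, b2 above are the standard cubic Hermite basis functions. A scalar sketch, independent of the DualQuaternion class, showing that the same blend reproduces the endpoint values exactly:

import numpy as np

def hermite(p1, m1, p2, m2, t):
    # Cubic Hermite blend with the same coefficients used above.
    t2, t3 = t * t, t * t * t
    a1 = 2. * t3 - 3. * t2 + 1.
    b1 = t3 - 2. * t2 + t
    a2 = -2. * t3 + 3. * t2
    b2 = t3 - t2
    return a1 * p1 + b1 * m1 + a2 * p2 + b2 * m2

# Value p1 is recovered at t=0 and p2 at t=1, regardless of the tangents m1, m2.
assert np.isclose(hermite(1.0, 0.5, 4.0, -2.0, 0.0), 1.0)
assert np.isclose(hermite(1.0, 0.5, 4.0, -2.0, 1.0), 4.0)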
internal/pycolmap/tools/save_cameras_as_ply.py
ADDED
@@ -0,0 +1,92 @@
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
|
7 |
+
from pycolmap import SceneManager
|
8 |
+
|
9 |
+
|
10 |
+
#-------------------------------------------------------------------------------
|
11 |
+
|
12 |
+
# Saves the cameras as a mesh
|
13 |
+
#
|
14 |
+
# inputs:
|
15 |
+
# - ply_file: output file
|
16 |
+
# - images: ordered array of pycolmap Image objects
|
17 |
+
# - colors: assigned automatically as a gradient across the image sequence (no color argument)
|
18 |
+
# - scale: amount to shrink/grow the camera model
|
19 |
+
def save_camera_ply(ply_file, images, scale):
|
20 |
+
points3D = scale * np.array((
|
21 |
+
(0., 0., 0.),
|
22 |
+
(-1., -1., 1.),
|
23 |
+
(-1., 1., 1.),
|
24 |
+
(1., -1., 1.),
|
25 |
+
(1., 1., 1.)))
|
26 |
+
|
27 |
+
faces = np.array(((0, 2, 1),
|
28 |
+
(0, 4, 2),
|
29 |
+
(0, 3, 4),
|
30 |
+
(0, 1, 3),
|
31 |
+
(1, 2, 4),
|
32 |
+
(1, 4, 3)))
|
33 |
+
|
34 |
+
r = np.linspace(0, 255, len(images), dtype=np.uint8)
|
35 |
+
g = 255 - r
|
36 |
+
b = r - np.linspace(0, 128, len(images), dtype=np.uint8)
|
37 |
+
color = np.column_stack((r, g, b))
|
38 |
+
|
39 |
+
with open(ply_file, "w") as fid:
|
40 |
+
print>>fid, "ply"
|
41 |
+
print>>fid, "format ascii 1.0"
|
42 |
+
print>>fid, "element vertex", len(points3D) * len(images)
|
43 |
+
print>>fid, "property float x"
|
44 |
+
print>>fid, "property float y"
|
45 |
+
print>>fid, "property float z"
|
46 |
+
print>>fid, "property uchar red"
|
47 |
+
print>>fid, "property uchar green"
|
48 |
+
print>>fid, "property uchar blue"
|
49 |
+
print>>fid, "element face", len(faces) * len(images)
|
50 |
+
print>>fid, "property list uchar int vertex_index"
|
51 |
+
print>>fid, "end_header"
|
52 |
+
|
53 |
+
for image, c in zip(images, color):
|
54 |
+
for p3D in (points3D.dot(image.R()) + image.C()):
|
55 |
+
print>>fid, p3D[0], p3D[1], p3D[2], c[0], c[1], c[2]
|
56 |
+
|
57 |
+
for i in xrange(len(images)):
|
58 |
+
for f in (faces + len(points3D) * i):
|
59 |
+
print>>fid, "3 {} {} {}".format(*f)
|
60 |
+
|
61 |
+
|
62 |
+
#-------------------------------------------------------------------------------
|
63 |
+
|
64 |
+
def main(args):
|
65 |
+
scene_manager = SceneManager(args.input_folder)
|
66 |
+
scene_manager.load_images()
|
67 |
+
|
68 |
+
images = sorted(scene_manager.images.itervalues(),
|
69 |
+
key=lambda image: image.name)
|
70 |
+
|
71 |
+
save_camera_ply(args.output_file, images, args.scale)
|
72 |
+
|
73 |
+
|
74 |
+
#-------------------------------------------------------------------------------
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
import argparse
|
78 |
+
|
79 |
+
parser = argparse.ArgumentParser(
|
80 |
+
description="Saves camera positions to a PLY for easy viewing outside "
|
81 |
+
"of COLMAP. Currently, camera FoV is not reflected in the output.",
|
82 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
83 |
+
|
84 |
+
parser.add_argument("input_folder")
|
85 |
+
parser.add_argument("output_file")
|
86 |
+
|
87 |
+
parser.add_argument("--scale", type=float, default=1.,
|
88 |
+
help="Scaling factor for the camera mesh.")
|
89 |
+
|
90 |
+
args = parser.parse_args()
|
91 |
+
|
92 |
+
main(args)
|
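For a quick check of the ASCII PLY produced above, the plyfile package (already used by write_depthmap_to_ply.py below) can read it back; the output path is an example only:

from plyfile import PlyData

ply = PlyData.read('cameras.ply')
# The writer above emits 5 vertices and 6 triangular faces per camera.
num_cameras = ply['vertex'].count // 5
assert ply['face'].count == 6 * num_cameras
print(num_cameras, 'cameras')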
internal/pycolmap/tools/transform_model.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from pycolmap import Quaternion, SceneManager
|
7 |
+
|
8 |
+
|
9 |
+
#-------------------------------------------------------------------------------
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
scene_manager = SceneManager(args.input_folder)
|
13 |
+
scene_manager.load()
|
14 |
+
|
15 |
+
# expect three lines of input, each containing one row of the 3x4 transform
|
16 |
+
P = np.array([
|
17 |
+
map(float, sys.stdin.readline().strip().split()) for _ in xrange(3)])
|
18 |
+
|
19 |
+
scene_manager.points3D[:] = scene_manager.points3D.dot(P[:,:3].T) + P[:,3]
|
20 |
+
|
21 |
+
# get rotation without any global scaling (assuming isotropic scaling)
|
22 |
+
scale = np.cbrt(np.linalg.det(P[:,:3]))
|
23 |
+
q_old_from_new = ~Quaternion.FromR(P[:,:3] / scale)
|
24 |
+
|
25 |
+
for image in scene_manager.images.itervalues():
|
26 |
+
image.q *= q_old_from_new
|
27 |
+
image.tvec = scale * image.tvec - image.R().dot(P[:,3])
|
28 |
+
|
29 |
+
scene_manager.save(args.output_folder)
|
30 |
+
|
31 |
+
|
32 |
+
#-------------------------------------------------------------------------------
|
33 |
+
|
34 |
+
if __name__ == "__main__":
|
35 |
+
import argparse
|
36 |
+
|
37 |
+
parser = argparse.ArgumentParser(
|
38 |
+
description="Apply a 3x4 transformation matrix to a COLMAP model and "
|
39 |
+
"save the result as a new model. Row-major input can be piped in from "
|
40 |
+
"a file or entered via the command line.",
|
41 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
42 |
+
|
43 |
+
parser.add_argument("input_folder")
|
44 |
+
parser.add_argument("output_folder")
|
45 |
+
|
46 |
+
args = parser.parse_args()
|
47 |
+
|
48 |
+
main(args)
|
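The pose update above assumes the 3x4 input is a similarity transform P = [s*R_p | b]. A numpy-only sketch checking that a world point and its camera-frame coordinates stay consistent (up to the global scale s) when both the point and the pose are remapped this way:

import numpy as np

rng = np.random.default_rng(0)

# Old camera pose (world -> camera): x_cam = R_old @ X + t_old.
theta = 0.3
R_old = np.array([[np.cos(theta), -np.sin(theta), 0.],
                  [np.sin(theta),  np.cos(theta), 0.],
                  [0., 0., 1.]])
t_old = rng.normal(size=3)

# Similarity transform applied to the model: X_new = (s * R_p) @ X + b.
s, phi = 2.0, -0.7
R_p = np.array([[1., 0., 0.],
                [0., np.cos(phi), -np.sin(phi)],
                [0., np.sin(phi),  np.cos(phi)]])
b = rng.normal(size=3)

# Pose update performed above: rotation first, then translation using the new R.
R_new = R_old @ R_p.T
t_new = s * t_old - R_new @ b

X = rng.normal(size=3)
X_new = s * (R_p @ X) + b
assert np.allclose(R_new @ X_new + t_new, s * (R_old @ X + t_old))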
internal/pycolmap/tools/write_camera_track_to_bundler.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from pycolmap import SceneManager
|
7 |
+
|
8 |
+
|
9 |
+
#-------------------------------------------------------------------------------
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
scene_manager = SceneManager(args.input_folder)
|
13 |
+
scene_manager.load_cameras()
|
14 |
+
scene_manager.load_images()
|
15 |
+
|
16 |
+
if args.sort:
|
17 |
+
images = sorted(
|
18 |
+
scene_manager.images.itervalues(), key=lambda im: im.name)
|
19 |
+
else:
|
20 |
+
images = scene_manager.images.values()
|
21 |
+
|
22 |
+
fid = open(args.output_file, "w")
|
23 |
+
fid_filenames = open(args.output_file + ".list.txt", "w")
|
24 |
+
|
25 |
+
print>>fid, "# Bundle file v0.3"
|
26 |
+
print>>fid, len(images), 0
|
27 |
+
|
28 |
+
for image in images:
|
29 |
+
print>>fid_filenames, image.name
|
30 |
+
camera = scene_manager.cameras[image.camera_id]
|
31 |
+
print>>fid, 0.5 * (camera.fx + camera.fy), 0, 0
|
32 |
+
R, t = image.R(), image.t
|
33 |
+
print>>fid, R[0, 0], R[0, 1], R[0, 2]
|
34 |
+
print>>fid, -R[1, 0], -R[1, 1], -R[1, 2]
|
35 |
+
print>>fid, -R[2, 0], -R[2, 1], -R[2, 2]
|
36 |
+
print>>fid, t[0], -t[1], -t[2]
|
37 |
+
|
38 |
+
fid.close()
|
39 |
+
fid_filenames.close()
|
40 |
+
|
41 |
+
|
42 |
+
#-------------------------------------------------------------------------------
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
import argparse
|
46 |
+
|
47 |
+
parser = argparse.ArgumentParser(
|
48 |
+
description="Saves the camera positions in the Bundler format. Note "
|
49 |
+
"that 3D points are not saved.",
|
50 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
51 |
+
|
52 |
+
parser.add_argument("input_folder")
|
53 |
+
parser.add_argument("output_file")
|
54 |
+
|
55 |
+
parser.add_argument("--sort", default=False, action="store_true",
|
56 |
+
help="sort the images by their filename")
|
57 |
+
|
58 |
+
args = parser.parse_args()
|
59 |
+
|
60 |
+
main(args)
|
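The sign flips above convert COLMAP's camera convention (looking down +z) to Bundler's (looking down -z with y up), which amounts to left-multiplying the pose by diag(1, -1, -1); a tiny numpy sketch of that equivalence:

import numpy as np

F = np.diag([1., -1., -1.])   # flips the camera y and z axes
R = np.eye(3)                 # any world-to-camera rotation
t = np.array([0.1, 0.2, 0.3])

R_bundler = F @ R             # rows 2 and 3 negated, as printed above
t_bundler = F @ t             # y and z components negated, as printed above
assert np.allclose(R_bundler[1:], -R[1:]) and np.allclose(t_bundler[1:], -t[1:])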
internal/pycolmap/tools/write_depthmap_to_ply.py
ADDED
@@ -0,0 +1,139 @@
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
|
4 |
+
import imageio
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
|
8 |
+
from plyfile import PlyData, PlyElement
|
9 |
+
from pycolmap import SceneManager
|
10 |
+
from scipy.ndimage.interpolation import zoom
|
11 |
+
|
12 |
+
|
13 |
+
#-------------------------------------------------------------------------------
|
14 |
+
|
15 |
+
def main(args):
|
16 |
+
suffix = ".photometric.bin" if args.photometric else ".geometric.bin"
|
17 |
+
|
18 |
+
image_file = os.path.join(args.dense_folder, "images", args.image_filename)
|
19 |
+
depth_file = os.path.join(
|
20 |
+
args.dense_folder, args.stereo_folder, "depth_maps",
|
21 |
+
args.image_filename + suffix)
|
22 |
+
if args.save_normals:
|
23 |
+
normals_file = os.path.join(
|
24 |
+
args.dense_folder, args.stereo_folder, "normal_maps",
|
25 |
+
args.image_filename + suffix)
|
26 |
+
|
27 |
+
# load camera intrinsics from the COLMAP reconstruction
|
28 |
+
scene_manager = SceneManager(os.path.join(args.dense_folder, "sparse"))
|
29 |
+
scene_manager.load_cameras()
|
30 |
+
scene_manager.load_images()
|
31 |
+
|
32 |
+
image_id, image = scene_manager.get_image_from_name(args.image_filename)
|
33 |
+
camera = scene_manager.cameras[image.camera_id]
|
34 |
+
rotation_camera_from_world = image.R()
|
35 |
+
camera_center = image.C()
|
36 |
+
|
37 |
+
# load image, depth map, and normal map
|
38 |
+
image = imageio.imread(image_file)
|
39 |
+
|
40 |
+
with open(depth_file, "rb") as fid:
|
41 |
+
w = int("".join(iter(lambda: fid.read(1), "&")))
|
42 |
+
h = int("".join(iter(lambda: fid.read(1), "&")))
|
43 |
+
c = int("".join(iter(lambda: fid.read(1), "&")))
|
44 |
+
depth_map = np.fromfile(fid, np.float32).reshape(h, w)
|
45 |
+
if (h, w) != image.shape[:2]:
|
46 |
+
depth_map = zoom(
|
47 |
+
depth_map,
|
48 |
+
(float(image.shape[0]) / h, float(image.shape[1]) / w),
|
49 |
+
order=0)
|
50 |
+
|
51 |
+
if args.save_normals:
|
52 |
+
with open(normals_file, "rb") as fid:
|
53 |
+
w = int("".join(iter(lambda: fid.read(1), "&")))
|
54 |
+
h = int("".join(iter(lambda: fid.read(1), "&")))
|
55 |
+
c = int("".join(iter(lambda: fid.read(1), "&")))
|
56 |
+
normals = np.fromfile(
|
57 |
+
fid, np.float32).reshape(c, h, w).transpose([1, 2, 0])
|
58 |
+
if (h, w) != image.shape[:2]:
|
59 |
+
normals = zoom(
|
60 |
+
normals,
|
61 |
+
(float(image.shape[0]) / h, float(image.shape[1]) / w, 1.),
|
62 |
+
order=0)
|
63 |
+
|
64 |
+
if args.min_depth is not None:
|
65 |
+
depth_map[depth_map < args.min_depth] = 0.
|
66 |
+
if args.max_depth is not None:
|
67 |
+
depth_map[depth_map > args.max_depth] = 0.
|
68 |
+
|
69 |
+
# create 3D points
|
70 |
+
#depth_map = np.minimum(depth_map, 100.)
|
71 |
+
points3D = np.dstack(camera.get_image_grid() + [depth_map])
|
72 |
+
points3D[:,:,:2] *= depth_map[:,:,np.newaxis]
|
73 |
+
|
74 |
+
# save
|
75 |
+
points3D = points3D.astype(np.float32).reshape(-1, 3)
|
76 |
+
if args.save_normals:
|
77 |
+
normals = normals.astype(np.float32).reshape(-1, 3)
|
78 |
+
image = image.reshape(-1, 3)
|
79 |
+
if image.dtype != np.uint8:
|
80 |
+
if image.max() <= 1:
|
81 |
+
image = (image * 255.).astype(np.uint8)
|
82 |
+
else:
|
83 |
+
image = image.astype(np.uint8)
|
84 |
+
|
85 |
+
if args.world_space:
|
86 |
+
points3D = points3D.dot(rotation_camera_from_world) + camera_center
|
87 |
+
if args.save_normals:
|
88 |
+
normals = normals.dot(rotation_camera_from_world)
|
89 |
+
|
90 |
+
if args.save_normals:
|
91 |
+
vertices = np.rec.fromarrays(
|
92 |
+
tuple(points3D.T) + tuple(normals.T) + tuple(image.T),
|
93 |
+
names="x,y,z,nx,ny,nz,red,green,blue")
|
94 |
+
else:
|
95 |
+
vertices = np.rec.fromarrays(
|
96 |
+
tuple(points3D.T) + tuple(image.T), names="x,y,z,red,green,blue")
|
97 |
+
vertices = PlyElement.describe(vertices, "vertex")
|
98 |
+
PlyData([vertices]).write(args.output_filename)
|
99 |
+
|
100 |
+
|
101 |
+
#-------------------------------------------------------------------------------
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
import argparse
|
105 |
+
|
106 |
+
parser = argparse.ArgumentParser(
|
107 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
108 |
+
|
109 |
+
parser.add_argument("dense_folder", type=str)
|
110 |
+
parser.add_argument("image_filename", type=str)
|
111 |
+
parser.add_argument("output_filename", type=str)
|
112 |
+
|
113 |
+
parser.add_argument(
|
114 |
+
"--photometric", default=False, action="store_true",
|
115 |
+
help="use photometric depthmap instead of geometric")
|
116 |
+
|
117 |
+
parser.add_argument(
|
118 |
+
"--world_space", default=False, action="store_true",
|
119 |
+
help="apply the camera->world extrinsic transformation to the result")
|
120 |
+
|
121 |
+
parser.add_argument(
|
122 |
+
"--save_normals", default=False, action="store_true",
|
123 |
+
help="load the estimated normal map and save as part of the PLY")
|
124 |
+
|
125 |
+
parser.add_argument(
|
126 |
+
"--stereo_folder", type=str, default="stereo",
|
127 |
+
help="folder in the dense workspace containing depth and normal maps")
|
128 |
+
|
129 |
+
parser.add_argument(
|
130 |
+
"--min_depth", type=float, default=None,
|
131 |
+
help="set pixels with depth less than this value to zero depth")
|
132 |
+
|
133 |
+
parser.add_argument(
|
134 |
+
"--max_depth", type=float, default=None,
|
135 |
+
help="set pixels with depth greater than this value to zero depth")
|
136 |
+
|
137 |
+
args = parser.parse_args()
|
138 |
+
|
139 |
+
main(args)
|
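The reads above rely on COLMAP's dense array container: a text header of three '&'-terminated integers (width, height, channels) followed by float32 samples. A standalone Python 3 sketch of the same parser (the tool above reads the header byte-by-byte as Python 2 str):

import numpy as np

def read_colmap_array(path):
    # Read a COLMAP .bin depth or normal map into a numpy array.
    with open(path, 'rb') as f:
        header = b''
        while header.count(b'&') < 3:           # 'W&H&C&' text header
            header += f.read(1)
        w, h, c = (int(x) for x in header.split(b'&')[:3])
        data = np.fromfile(f, np.float32)
    # Same layout assumption as above: single-channel maps reshape to (H, W),
    # multi-channel maps to (C, H, W) and are then transposed to (H, W, C).
    return data.reshape(h, w) if c == 1 else data.reshape(c, h, w).transpose(1, 2, 0)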
internal/raw_utils.py
ADDED
@@ -0,0 +1,360 @@
1 |
+
import glob
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from internal import image as lib_image
|
5 |
+
from internal import math
|
6 |
+
from internal import utils
|
7 |
+
import numpy as np
|
8 |
+
import rawpy
|
9 |
+
|
10 |
+
|
11 |
+
def postprocess_raw(raw, camtorgb, exposure=None):
|
12 |
+
"""Converts demosaicked raw to sRGB with a minimal postprocessing pipeline.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
raw: [H, W, 3], demosaicked raw camera image.
|
16 |
+
camtorgb: [3, 3], color correction transformation to apply to raw image.
|
17 |
+
exposure: color value to be scaled to pure white after color correction.
|
18 |
+
If None, "autoexposes" at the 97th percentile.
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
srgb: [H, W, 3], color corrected + exposed + gamma mapped image.
|
22 |
+
"""
|
23 |
+
if raw.shape[-1] != 3:
|
24 |
+
raise ValueError(f'raw.shape[-1] is {raw.shape[-1]}, expected 3')
|
25 |
+
if camtorgb.shape != (3, 3):
|
26 |
+
raise ValueError(f'camtorgb.shape is {camtorgb.shape}, expected (3, 3)')
|
27 |
+
# Convert from camera color space to standard linear RGB color space.
|
28 |
+
rgb_linear = np.matmul(raw, camtorgb.T)
|
29 |
+
if exposure is None:
|
30 |
+
exposure = np.percentile(rgb_linear, 97)
|
31 |
+
# "Expose" image by mapping the input exposure level to white and clipping.
|
32 |
+
rgb_linear_scaled = np.clip(rgb_linear / exposure, 0, 1)
|
33 |
+
# Apply sRGB gamma curve to serve as a simple tonemap.
|
34 |
+
srgb = lib_image.linear_to_srgb_np(rgb_linear_scaled)
|
35 |
+
return srgb
|
36 |
+
|
37 |
+
|
38 |
+
def pixels_to_bayer_mask(pix_x, pix_y):
|
39 |
+
"""Computes binary RGB Bayer mask values from integer pixel coordinates."""
|
40 |
+
# Red is top left (0, 0).
|
41 |
+
r = (pix_x % 2 == 0) * (pix_y % 2 == 0)
|
42 |
+
# Green is top right (0, 1) and bottom left (1, 0).
|
43 |
+
g = (pix_x % 2 == 1) * (pix_y % 2 == 0) + (pix_x % 2 == 0) * (pix_y % 2 == 1)
|
44 |
+
# Blue is bottom right (1, 1).
|
45 |
+
b = (pix_x % 2 == 1) * (pix_y % 2 == 1)
|
46 |
+
return np.stack([r, g, b], -1).astype(np.float32)
|
47 |
+
|
48 |
+
|
49 |
+
def bilinear_demosaic(bayer):
|
50 |
+
"""Converts Bayer data into a full RGB image using bilinear demosaicking.
|
51 |
+
|
52 |
+
Input data should be ndarray of shape [height, width] with 2x2 mosaic pattern:
|
53 |
+
-------------
|
54 |
+
|red |green|
|
55 |
+
-------------
|
56 |
+
|green|blue |
|
57 |
+
-------------
|
58 |
+
Red and blue channels are bilinearly upsampled 2x, missing green channel
|
59 |
+
elements are the average of the neighboring 4 values in a cross pattern.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
bayer: [H, W] array, Bayer mosaic pattern input image.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
rgb: [H, W, 3] array, full RGB image.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def reshape_quads(*planes):
|
69 |
+
"""Reshape pixels from four input images to make tiled 2x2 quads."""
|
70 |
+
planes = np.stack(planes, -1)
|
71 |
+
shape = planes.shape[:-1]
|
72 |
+
# Create [2, 2] arrays out of 4 channels.
|
73 |
+
zup = planes.reshape(shape + (2, 2,))
|
74 |
+
# Transpose so that x-axis dimensions come before y-axis dimensions.
|
75 |
+
zup = np.transpose(zup, (0, 2, 1, 3))
|
76 |
+
# Reshape to 2D.
|
77 |
+
zup = zup.reshape((shape[0] * 2, shape[1] * 2))
|
78 |
+
return zup
|
79 |
+
|
80 |
+
def bilinear_upsample(z):
|
81 |
+
"""2x bilinear image upsample."""
|
82 |
+
# Using np.roll makes the right and bottom edges wrap around. The raw image
|
83 |
+
# data has a few garbage columns/rows at the edges that must be discarded
|
84 |
+
# anyway, so this does not matter in practice.
|
85 |
+
# Horizontally interpolated values.
|
86 |
+
zx = .5 * (z + np.roll(z, -1, axis=-1))
|
87 |
+
# Vertically interpolated values.
|
88 |
+
zy = .5 * (z + np.roll(z, -1, axis=-2))
|
89 |
+
# Diagonally interpolated values.
|
90 |
+
zxy = .5 * (zx + np.roll(zx, -1, axis=-2))
|
91 |
+
return reshape_quads(z, zx, zy, zxy)
|
92 |
+
|
93 |
+
def upsample_green(g1, g2):
|
94 |
+
"""Special 2x upsample from the two green channels."""
|
95 |
+
z = np.zeros_like(g1)
|
96 |
+
z = reshape_quads(z, g1, g2, z)
|
97 |
+
alt = 0
|
98 |
+
# Grab the 4 directly adjacent neighbors in a "cross" pattern.
|
99 |
+
for i in range(4):
|
100 |
+
axis = -1 - (i // 2)
|
101 |
+
roll = -1 + 2 * (i % 2)
|
102 |
+
alt = alt + .25 * np.roll(z, roll, axis=axis)
|
103 |
+
# For observed pixels, alt = 0, and for unobserved pixels, alt = avg(cross),
|
104 |
+
# so alt + z will have every pixel filled in.
|
105 |
+
return alt + z
|
106 |
+
|
107 |
+
r, g1, g2, b = [bayer[(i // 2)::2, (i % 2)::2] for i in range(4)]
|
108 |
+
r = bilinear_upsample(r)
|
109 |
+
# Flip in x and y before and after calling upsample, as bilinear_upsample
|
110 |
+
# assumes that the samples are at the top-left corner of the 2x2 sample.
|
111 |
+
b = bilinear_upsample(b[::-1, ::-1])[::-1, ::-1]
|
112 |
+
g = upsample_green(g1, g2)
|
113 |
+
rgb = np.stack([r, g, b], -1)
|
114 |
+
return rgb
|
115 |
+
|
116 |
+
|
117 |
+
def load_raw_images(image_dir, image_names=None):
|
118 |
+
"""Loads raw images and their metadata from disk.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
image_dir: directory containing raw image and EXIF data.
|
122 |
+
image_names: files to load (ignores file extension), loads all DNGs if None.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
A tuple (images, exifs).
|
126 |
+
images: [N, height, width, 3] array of raw sensor data.
|
127 |
+
exifs: [N] list of dicts, one per image, containing the EXIF data.
|
128 |
+
Raises:
|
129 |
+
ValueError: The requested `image_dir` does not exist on disk.
|
130 |
+
"""
|
131 |
+
|
132 |
+
if not utils.file_exists(image_dir):
|
133 |
+
raise ValueError(f'Raw image folder {image_dir} does not exist.')
|
134 |
+
|
135 |
+
# Load raw images (dng files) and exif metadata (json files).
|
136 |
+
def load_raw_exif(image_name):
|
137 |
+
base = os.path.join(image_dir, os.path.splitext(image_name)[0])
|
138 |
+
with utils.open_file(base + '.dng', 'rb') as f:
|
139 |
+
raw = rawpy.imread(f).raw_image
|
140 |
+
with utils.open_file(base + '.json', 'rb') as f:
|
141 |
+
exif = json.load(f)[0]
|
142 |
+
return raw, exif
|
143 |
+
|
144 |
+
if image_names is None:
|
145 |
+
image_names = [
|
146 |
+
os.path.basename(f)
|
147 |
+
for f in sorted(glob.glob(os.path.join(image_dir, '*.dng')))
|
148 |
+
]
|
149 |
+
|
150 |
+
data = [load_raw_exif(x) for x in image_names]
|
151 |
+
raws, exifs = zip(*data)
|
152 |
+
raws = np.stack(raws, axis=0).astype(np.float32)
|
153 |
+
|
154 |
+
return raws, exifs
|
155 |
+
|
156 |
+
|
157 |
+
# Brightness percentiles to use for re-exposing and tonemapping raw images.
|
158 |
+
_PERCENTILE_LIST = (80, 90, 97, 99, 100)
|
159 |
+
|
160 |
+
# Relevant fields to extract from raw image EXIF metadata.
|
161 |
+
# For details regarding EXIF parameters, see:
|
162 |
+
# https://www.adobe.com/content/dam/acom/en/products/photoshop/pdfs/dng_spec_1.4.0.0.pdf.
|
163 |
+
_EXIF_KEYS = (
|
164 |
+
'BlackLevel', # Black level offset added to sensor measurements.
|
165 |
+
'WhiteLevel', # Maximum possible sensor measurement.
|
166 |
+
'AsShotNeutral', # RGB white balance coefficients.
|
167 |
+
'ColorMatrix2', # XYZ to camera color space conversion matrix.
|
168 |
+
'NoiseProfile', # Shot and read noise levels.
|
169 |
+
)
|
170 |
+
|
171 |
+
# Color conversion from reference illuminant XYZ to RGB color space.
|
172 |
+
# See http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html.
|
173 |
+
_RGB2XYZ = np.array([[0.4124564, 0.3575761, 0.1804375],
|
174 |
+
[0.2126729, 0.7151522, 0.0721750],
|
175 |
+
[0.0193339, 0.1191920, 0.9503041]])
|
176 |
+
|
177 |
+
|
178 |
+
def process_exif(exifs):
|
179 |
+
"""Processes list of raw image EXIF data into useful metadata dict.
|
180 |
+
|
181 |
+
Input should be a list of dictionaries loaded from JSON files.
|
182 |
+
These JSON files are produced by running
|
183 |
+
$ exiftool -json IMAGE.dng > IMAGE.json
|
184 |
+
for each input raw file.
|
185 |
+
|
186 |
+
We extract only the parameters relevant to
|
187 |
+
1. Rescaling the raw data to [0, 1],
|
188 |
+
2. White balance and color correction, and
|
189 |
+
3. Noise level estimation.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
exifs: a list of dicts containing EXIF data as loaded from JSON files.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
meta: a dict of the relevant metadata for running RawNeRF.
|
196 |
+
"""
|
197 |
+
meta = {}
|
198 |
+
exif = exifs[0]
|
199 |
+
# Convert from array of dicts (exifs) to dict of arrays (meta).
|
200 |
+
for key in _EXIF_KEYS:
|
201 |
+
exif_value = exif.get(key)
|
202 |
+
if exif_value is None:
|
203 |
+
continue
|
204 |
+
# Values can be a single int or float...
|
205 |
+
if isinstance(exif_value, int) or isinstance(exif_value, float):
|
206 |
+
vals = [x[key] for x in exifs]
|
207 |
+
# Or a string of numbers with ' ' between.
|
208 |
+
elif isinstance(exif_value, str):
|
209 |
+
vals = [[float(z) for z in x[key].split(' ')] for x in exifs]
|
210 |
+
meta[key] = np.squeeze(np.array(vals))
|
211 |
+
# Shutter speed is a special case, a string written like 1/N.
|
212 |
+
meta['ShutterSpeed'] = np.fromiter(
|
213 |
+
(1. / float(exif['ShutterSpeed'].split('/')[1]) for exif in exifs), float)
|
214 |
+
|
215 |
+
# Create raw-to-sRGB color transform matrices. Pipeline is:
|
216 |
+
# cam space -> white balanced cam space ("camwb") -> XYZ space -> RGB space.
|
217 |
+
# 'AsShotNeutral' is an RGB triplet representing how pure white would measure
|
218 |
+
# on the sensor, so dividing by these numbers corrects the white balance.
|
219 |
+
whitebalance = meta['AsShotNeutral'].reshape(-1, 3)
|
220 |
+
cam2camwb = np.array([np.diag(1. / x) for x in whitebalance])
|
221 |
+
# ColorMatrix2 converts from XYZ color space to "reference illuminant" (white
|
222 |
+
# balanced) camera space.
|
223 |
+
xyz2camwb = meta['ColorMatrix2'].reshape(-1, 3, 3)
|
224 |
+
rgb2camwb = xyz2camwb @ _RGB2XYZ
|
225 |
+
# We normalize the rows of the full color correction matrix, as is done in
|
226 |
+
# https://github.com/AbdoKamel/simple-camera-pipeline.
|
227 |
+
rgb2camwb /= rgb2camwb.sum(axis=-1, keepdims=True)
|
228 |
+
# Combining color correction with white balance gives the entire transform.
|
229 |
+
cam2rgb = np.linalg.inv(rgb2camwb) @ cam2camwb
|
230 |
+
meta['cam2rgb'] = cam2rgb
|
231 |
+
|
232 |
+
return meta
|
233 |
+
|
234 |
+
|
235 |
+
def load_raw_dataset(split, data_dir, image_names, exposure_percentile, n_downsample):
|
236 |
+
"""Loads and processes a set of RawNeRF input images.
|
237 |
+
|
238 |
+
    Includes logic necessary for special "test" scenes that include a noiseless
    ground truth frame, produced by HDR+ merge.

    Args:
        split: DataSplit.TRAIN or DataSplit.TEST, only used for test scene logic.
        data_dir: base directory for scene data.
        image_names: which images were successfully posed by COLMAP.
        exposure_percentile: what brightness percentile to expose to white.
        n_downsample: returned images are downsampled by a factor of n_downsample.

    Returns:
        A tuple (images, meta, testscene).
        images: [N, height // n_downsample, width // n_downsample, 3] array of
            demosaicked raw image data.
        meta: EXIF metadata and other useful processing parameters. Includes per
            image exposure information that can be passed into the NeRF model with
            each ray: the set of unique exposure times is determined and each image
            assigned a corresponding exposure index (mapping to an exposure value).
            These are keys 'unique_shutters', 'exposure_idx', and 'exposure_values'
            in the `meta` dictionary. We rescale so the maximum exposure value is 1
            for convenience.
        testscene: True when dataset includes ground truth test image, else False.
    """

    image_dir = os.path.join(data_dir, 'raw')

    testimg_file = os.path.join(data_dir, 'hdrplus_test/merged.dng')
    testscene = utils.file_exists(testimg_file)
    if testscene:
        # Test scenes have train/ and test/ split subdirectories inside raw/.
        image_dir = os.path.join(image_dir, split.value)
        if split == utils.DataSplit.TEST:
            # COLMAP image names not valid for test split of test scene.
            image_names = None
        else:
            # Discard the first COLMAP image name as it is a copy of the test image.
            image_names = image_names[1:]

    raws, exifs = load_raw_images(image_dir, image_names)
    meta = process_exif(exifs)

    if testscene and split == utils.DataSplit.TEST:
        # Test split for test scene must load the "ground truth" HDR+ merged image.
        with utils.open_file(testimg_file, 'rb') as imgin:
            testraw = rawpy.imread(imgin).raw_image
        # HDR+ output has 2 extra bits of fixed precision, need to divide by 4.
        testraw = testraw.astype(np.float32) / 4.
        # Need to rescale long exposure test image by fast:slow shutter speed ratio.
        fast_shutter = meta['ShutterSpeed'][0]
        slow_shutter = meta['ShutterSpeed'][-1]
        shutter_ratio = fast_shutter / slow_shutter
        # Replace loaded raws with the "ground truth" test image.
        raws = testraw[None]
        # Test image shares metadata with the first loaded image (fast exposure).
        meta = {k: meta[k][:1] for k in meta}
    else:
        shutter_ratio = 1.

    # Next we determine an index for each unique shutter speed in the data.
    shutter_speeds = meta['ShutterSpeed']
    # Sort the shutter speeds from slowest (largest) to fastest (smallest).
    # This way index 0 will always correspond to the brightest image.
    unique_shutters = np.sort(np.unique(shutter_speeds))[::-1]
    exposure_idx = np.zeros_like(shutter_speeds, dtype=np.int32)
    for i, shutter in enumerate(unique_shutters):
        # Assign index `i` to all images with shutter speed `shutter`.
        exposure_idx[shutter_speeds == shutter] = i
    meta['exposure_idx'] = exposure_idx
    meta['unique_shutters'] = unique_shutters
    # Rescale to use relative shutter speeds, where 1. is the brightest.
    # This way the NeRF output with exposure=1 will always be reasonable.
    meta['exposure_values'] = shutter_speeds / unique_shutters[0]

    # Rescale raw sensor measurements to [0, 1] (plus noise).
    blacklevel = meta['BlackLevel'].reshape(-1, 1, 1)
    whitelevel = meta['WhiteLevel'].reshape(-1, 1, 1)
    images = (raws - blacklevel) / (whitelevel - blacklevel) * shutter_ratio

    # Calculate value for exposure level when gamma mapping, defaults to 97%.
    # Always based on full resolution image 0 (for consistency).
    image0_raw_demosaic = np.array(bilinear_demosaic(images[0]))
    image0_rgb = image0_raw_demosaic @ meta['cam2rgb'][0].T
    exposure = np.percentile(image0_rgb, exposure_percentile)
    meta['exposure'] = exposure
    # Sweep over various exposure percentiles to visualize in training logs.
    exposure_levels = {p: np.percentile(image0_rgb, p) for p in _PERCENTILE_LIST}
    meta['exposure_levels'] = exposure_levels

    # Create postprocessing function mapping raw images to tonemapped sRGB space.
    cam2rgb0 = meta['cam2rgb'][0]
    meta['postprocess_fn'] = lambda z, x=exposure: postprocess_raw(z, cam2rgb0, x)

    def processing_fn(x):
        x_ = np.array(x)
        x_demosaic = bilinear_demosaic(x_)
        if n_downsample > 1:
            x_demosaic = lib_image.downsample(x_demosaic, n_downsample)
        return np.array(x_demosaic)

    images = np.stack([processing_fn(im) for im in images], axis=0)

    return images, meta, testscene


def best_fit_affine(x, y, axis):
    """Computes best fit a, b such that a * x + b = y, in a least-squares sense."""
    x_m = x.mean(axis=axis)
    y_m = y.mean(axis=axis)
    xy_m = (x * y).mean(axis=axis)
    xx_m = (x * x).mean(axis=axis)
    # Slope a = Cov(x, y) / Var(x).
    a = (xy_m - x_m * y_m) / (xx_m - x_m * x_m)
    b = y_m - a * x_m
    return a, b


def match_images_affine(est, gt, axis=(0, 1)):
    """Computes affine best fit of gt->est, then maps est back to match gt."""
    # Mapping is computed gt->est to be robust since `est` may be very noisy.
    a, b = best_fit_affine(gt, est, axis=axis)
    # Inverse mapping back to gt ensures we use a consistent space for metrics.
    est_matched = (est - b) / a
    return est_matched
internal/ref_utils.py
ADDED
@@ -0,0 +1,174 @@
from internal import math
import torch
import numpy as np


def reflect(viewdirs, normals):
    """Reflect view directions about normals.

    The reflection of a vector v about a unit vector n is a vector u such that
    dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two
    equations is u = 2 dot(n, v) n - v.

    Args:
        viewdirs: [..., 3] array of view directions.
        normals: [..., 3] array of normal directions (assumed to be unit vectors).

    Returns:
        [..., 3] array of reflection directions.
    """
    return 2.0 * torch.sum(normals * viewdirs, dim=-1, keepdim=True) * normals - viewdirs


def l2_normalize(x):
    """Normalize x to unit length along last axis."""
    return torch.nn.functional.normalize(x, dim=-1, eps=torch.finfo(x.dtype).eps)


def l2_normalize_np(x):
    """Normalize x to unit length along last axis."""
    return x / np.sqrt(np.maximum(np.sum(x ** 2, axis=-1, keepdims=True), np.finfo(x.dtype).eps))


def compute_weighted_mae(weights, normals, normals_gt):
    """Compute weighted mean angular error, assuming normals are unit length."""
    one_eps = 1 - torch.finfo(weights.dtype).eps
    return (weights * torch.arccos(torch.clip((normals * normals_gt).sum(-1),
                                              -one_eps, one_eps))).sum() / weights.sum() * 180.0 / torch.pi


def compute_weighted_mae_np(weights, normals, normals_gt):
    """Compute weighted mean angular error, assuming normals are unit length."""
    one_eps = 1 - np.finfo(weights.dtype).eps
    return (weights * np.arccos(np.clip((normals * normals_gt).sum(-1),
                                        -one_eps, one_eps))).sum() / weights.sum() * 180.0 / np.pi


def generalized_binomial_coeff(a, k):
    """Compute generalized binomial coefficients."""
    return np.prod(a - np.arange(k)) / np.math.factorial(k)


def assoc_legendre_coeff(l, m, k):
    """Compute associated Legendre polynomial coefficients.

    Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the
    (l, m)th associated Legendre polynomial, P_l^m(cos(theta)).

    Args:
        l: associated Legendre polynomial degree.
        m: associated Legendre polynomial order.
        k: power of cos(theta).

    Returns:
        A float, the coefficient of the term corresponding to the inputs.
    """
    return ((-1) ** m * 2 ** l * np.math.factorial(l) / np.math.factorial(k) /
            np.math.factorial(l - k - m) *
            generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l))


def sph_harm_coeff(l, m, k):
    """Compute spherical harmonic coefficients."""
    return (np.sqrt(
        (2.0 * l + 1.0) * np.math.factorial(l - m) /
        (4.0 * np.pi * np.math.factorial(l + m))) * assoc_legendre_coeff(l, m, k))


def get_ml_array(deg_view):
    """Create a list with all pairs of (l, m) values to use in the encoding."""
    ml_list = []
    for i in range(deg_view):
        l = 2 ** i
        # Only use nonnegative m values, later splitting real and imaginary parts.
        for m in range(l + 1):
            ml_list.append((m, l))

    # Convert list into a numpy array.
    ml_array = np.array(ml_list).T
    return ml_array


def generate_ide_fn(deg_view):
    """Generate integrated directional encoding (IDE) function.

    This function returns a function that computes the integrated directional
    encoding from Equations 6-8 of arxiv.org/abs/2112.03907.

    Args:
        deg_view: number of spherical harmonics degrees to use.

    Returns:
        A function for evaluating integrated directional encoding.

    Raises:
        ValueError: if deg_view is larger than 5.
    """
    if deg_view > 5:
        raise ValueError('Only deg_view of at most 5 is numerically stable.')

    ml_array = get_ml_array(deg_view)
    l_max = 2 ** (deg_view - 1)

    # Create a matrix corresponding to ml_array holding all coefficients, which,
    # when multiplied (from the right) by the z coordinate Vandermonde matrix,
    # results in the z component of the encoding.
    mat = np.zeros((l_max + 1, ml_array.shape[1]))
    for i, (m, l) in enumerate(ml_array.T):
        for k in range(l - m + 1):
            mat[k, i] = sph_harm_coeff(l, m, k)
    mat = torch.from_numpy(mat).float()
    ml_array = torch.from_numpy(ml_array).float()

    def integrated_dir_enc_fn(xyz, kappa_inv):
        """Function returning integrated directional encoding (IDE).

        Args:
            xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at.
            kappa_inv: [..., 1] reciprocal of the concentration parameter of the von
                Mises-Fisher distribution.

        Returns:
            An array with the resulting IDE.
        """
        x = xyz[..., 0:1]
        y = xyz[..., 1:2]
        z = xyz[..., 2:3]

        # Compute z Vandermonde matrix.
        vmz = torch.cat([z ** i for i in range(mat.shape[0])], dim=-1)

        # Compute x+iy Vandermonde matrix.
        vmxy = torch.cat([(x + 1j * y) ** m for m in ml_array[0, :]], dim=-1)

        # Get spherical harmonics.
        sph_harms = vmxy * math.matmul(vmz, mat.to(xyz.device))

        # Apply attenuation function using the von Mises-Fisher distribution
        # concentration parameter, kappa.
        sigma = 0.5 * ml_array[1, :] * (ml_array[1, :] + 1)
        sigma = sigma.to(sph_harms.device)
        ide = sph_harms * torch.exp(-sigma * kappa_inv)

        # Split into real and imaginary parts and return.
        return torch.cat([torch.real(ide), torch.imag(ide)], dim=-1)

    return integrated_dir_enc_fn


def generate_dir_enc_fn(deg_view):
    """Generate directional encoding (DE) function.

    Args:
        deg_view: number of spherical harmonics degrees to use.

    Returns:
        A function for evaluating directional encoding.
    """
    integrated_dir_enc_fn = generate_ide_fn(deg_view)

    def dir_enc_fn(xyz):
        """Function returning directional encoding (DE)."""
        return integrated_dir_enc_fn(xyz, torch.zeros_like(xyz[..., :1]))

    return dir_enc_fn
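As a rough usage sketch (the shapes and values below are illustrative assumptions, not part of this file): `generate_ide_fn` returns a closure that maps unit directions plus a per-sample reciprocal concentration 1/kappa to the integrated directional encoding.

import torch

ide_fn = generate_ide_fn(deg_view=4)

# A batch of directions, normalized to unit length, and a roughness-like
# reciprocal concentration parameter per direction.
dirs = l2_normalize(torch.randn(1024, 3))
kappa_inv = torch.rand(1024, 1) * 0.1

enc = ide_fn(dirs, kappa_inv)
# deg_view=4 uses l in {1, 2, 4, 8} with m in 0..l, i.e. 2 + 3 + 5 + 9 = 19
# (l, m) pairs; splitting real and imaginary parts gives 38 channels.
print(enc.shape)  # torch.Size([1024, 38])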
internal/render.py
ADDED
@@ -0,0 +1,242 @@
import os.path

from internal import stepfun
from internal import math
from internal import utils
import torch
import torch.nn.functional as F


def lift_gaussian(d, t_mean, t_var, r_var, diag):
    """Lift a Gaussian defined along a ray to 3D coordinates."""
    mean = d[..., None, :] * t_mean[..., None]
    eps = torch.finfo(d.dtype).eps
    # eps = 1e-3
    d_mag_sq = torch.sum(d ** 2, dim=-1, keepdim=True).clamp_min(eps)

    if diag:
        d_outer_diag = d ** 2
        null_outer_diag = 1 - d_outer_diag / d_mag_sq
        t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
        xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
        cov_diag = t_cov_diag + xy_cov_diag
        return mean, cov_diag
    else:
        d_outer = d[..., :, None] * d[..., None, :]
        eye = torch.eye(d.shape[-1], device=d.device)
        null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
        t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
        xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
        cov = t_cov + xy_cov
        return mean, cov


def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True):
    """Approximate a conical frustum as a Gaussian distribution (mean+cov).

    Assumes the ray is originating from the origin, and base_radius is the
    radius at dist=1. Doesn't assume `d` is normalized.

    Args:
        d: the axis of the cone.
        t0: the starting distance of the frustum.
        t1: the ending distance of the frustum.
        base_radius: the scale of the radius as a function of distance.
        diag: whether or not the Gaussian will be diagonal or full-covariance.
        stable: whether or not to use the stable computation described in
            the paper (setting this to False will cause catastrophic failure).

    Returns:
        a Gaussian (mean and covariance).
    """
    if stable:
        # Equation 7 in the paper (https://arxiv.org/abs/2103.13415).
        mu = (t0 + t1) / 2  # The average of the two `t` values.
        hw = (t1 - t0) / 2  # The half-width of the two `t` values.
        eps = torch.finfo(d.dtype).eps
        # eps = 1e-3
        t_mean = mu + (2 * mu * hw ** 2) / (3 * mu ** 2 + hw ** 2).clamp_min(eps)
        denom = (3 * mu ** 2 + hw ** 2).clamp_min(eps)
        t_var = (hw ** 2) / 3 - (4 / 15) * hw ** 4 * (12 * mu ** 2 - hw ** 2) / denom ** 2
        r_var = (mu ** 2) / 4 + (5 / 12) * hw ** 2 - (4 / 15) * (hw ** 4) / denom
    else:
        # Equations 37-39 in the paper.
        t_mean = (3 * (t1 ** 4 - t0 ** 4)) / (4 * (t1 ** 3 - t0 ** 3))
        r_var = 3 / 20 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3)
        t_mosq = 3 / 5 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3)
        t_var = t_mosq - t_mean ** 2
    r_var *= base_radius ** 2
    return lift_gaussian(d, t_mean, t_var, r_var, diag)


def cylinder_to_gaussian(d, t0, t1, radius, diag):
    """Approximate a cylinder as a Gaussian distribution (mean+cov).

    Assumes the ray is originating from the origin, and radius is the
    radius. Does not renormalize `d`.

    Args:
        d: the axis of the cylinder.
        t0: the starting distance of the cylinder.
        t1: the ending distance of the cylinder.
        radius: the radius of the cylinder.
        diag: whether or not the Gaussian will be diagonal or full-covariance.

    Returns:
        a Gaussian (mean and covariance).
    """
    t_mean = (t0 + t1) / 2
    r_var = radius ** 2 / 4
    t_var = (t1 - t0) ** 2 / 12
    return lift_gaussian(d, t_mean, t_var, r_var, diag)


def cast_rays(tdist, origins, directions, cam_dirs, radii, rand=True, n=7, m=3, std_scale=0.5, **kwargs):
    """Cast rays (cone- or cylinder-shaped) and featurize sections of it.

    Args:
        tdist: float array, the "fencepost" distances along the ray.
        origins: float array, the ray origin coordinates.
        directions: float array, the ray direction vectors.
        cam_dirs: float array, camera viewing directions, used to build an
            orthonormal basis whose cross section is parallel to the image plane.
        radii: float array, the radii (base radii for cones) of the rays.
        rand: bool, if True, randomly rotate and flip the hexagonal sample pattern.
        std_scale: float, scale applied to the per-sample standard deviations.

    Returns:
        a tuple (means, stds, t) of per-sample Gaussian means, standard
        deviations, and distances along the ray.
    """
    t0 = tdist[..., :-1, None]
    t1 = tdist[..., 1:, None]
    radii = radii[..., None]

    t_m = (t0 + t1) / 2
    t_d = (t1 - t0) / 2

    j = torch.arange(6, device=tdist.device)
    t = t0 + t_d / (t_d ** 2 + 3 * t_m ** 2) * (t1 ** 2 + 2 * t_m ** 2 + 3 / 7 ** 0.5 * (2 * j / 5 - 1) * (
        (t_d ** 2 - t_m ** 2) ** 2 + 4 * t_m ** 4).sqrt())

    deg = torch.pi / 3 * torch.tensor([0, 2, 4, 3, 5, 1], device=tdist.device, dtype=torch.float)
    deg = torch.broadcast_to(deg, t.shape)
    if rand:
        # Randomly rotate and flip the pattern.
        mask = torch.rand_like(t0[..., 0]) > 0.5
        deg = deg + 2 * torch.pi * torch.rand_like(deg[..., 0])[..., None]
        deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg)
    else:
        # Rotate by 30 degrees and flip every other pattern.
        mask = torch.arange(t.shape[-2], device=tdist.device) % 2 == 0
        mask = torch.broadcast_to(mask, t.shape[:-1])
        deg = torch.where(mask[..., None], deg, deg + torch.pi / 6)
        deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg)
    means = torch.stack([
        radii * t * torch.cos(deg) / 2 ** 0.5,
        radii * t * torch.sin(deg) / 2 ** 0.5,
        t
    ], dim=-1)
    stds = std_scale * radii * t / 2 ** 0.5

    # Two basis vectors parallel to the image plane.
    rand_vec = torch.randn_like(cam_dirs)
    ortho1 = F.normalize(torch.cross(cam_dirs, rand_vec, dim=-1), dim=-1)
    ortho2 = F.normalize(torch.cross(cam_dirs, ortho1, dim=-1), dim=-1)

    # Use the ray directions as the third vector of the orthonormal basis,
    # while the cross section of the cone stays parallel to the image plane.
    basis_matrix = torch.stack([ortho1, ortho2, directions], dim=-1)
    means = math.matmul(means, basis_matrix[..., None, :, :].transpose(-1, -2))
    means = means + origins[..., None, None, :]
    # import trimesh
    # trimesh.Trimesh(means.reshape(-1, 3).detach().cpu().numpy()).export("test.ply", "ply")

    return means, stds, t


def compute_alpha_weights(density, tdist, dirs, opaque_background=False):
    """Helper function for computing alpha compositing weights."""
    t_delta = tdist[..., 1:] - tdist[..., :-1]
    delta = t_delta * torch.norm(dirs[..., None, :], dim=-1)
    density_delta = density * delta

    if opaque_background:
        # Equivalent to making the final t-interval infinitely wide.
        density_delta = torch.cat([
            density_delta[..., :-1],
            torch.full_like(density_delta[..., -1:], torch.inf)
        ], dim=-1)

    alpha = 1 - torch.exp(-density_delta)
    trans = torch.exp(-torch.cat([
        torch.zeros_like(density_delta[..., :1]),
        torch.cumsum(density_delta[..., :-1], dim=-1)
    ], dim=-1))
    weights = alpha * trans
    return weights, alpha, trans


def volumetric_rendering(rgbs,
                         weights,
                         tdist,
                         bg_rgbs,
                         t_far,
                         compute_extras,
                         extras=None):
    """Volumetric Rendering Function.

    Args:
        rgbs: color, [batch_size, num_samples, 3].
        weights: weights, [batch_size, num_samples].
        tdist: [batch_size, num_samples].
        bg_rgbs: the color(s) to use for the background.
        t_far: [batch_size, 1], the distance of the far plane.
        compute_extras: bool, if True, compute extra quantities besides color.
        extras: dict, a set of values along rays to render by alpha compositing.

    Returns:
        rendering: a dict containing an rgb image of size [batch_size, 3], and
            other visualizations if compute_extras=True.
    """
    eps = torch.finfo(rgbs.dtype).eps
    # eps = 1e-3
    rendering = {}

    acc = weights.sum(dim=-1)
    bg_w = (1 - acc[..., None]).clamp_min(0.)  # The weight of the background.
    rgb = (weights[..., None] * rgbs).sum(dim=-2) + bg_w * bg_rgbs
    t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])
    depth = (
        torch.clip(
            torch.nan_to_num((weights * t_mids).sum(dim=-1) / acc.clamp_min(eps), torch.inf),
            tdist[..., 0], tdist[..., -1]))

    rendering['rgb'] = rgb
    rendering['depth'] = depth
    rendering['acc'] = acc

    if compute_extras:
        if extras is not None:
            for k, v in extras.items():
                if v is not None:
                    rendering[k] = (weights[..., None] * v).sum(dim=-2)

        expectation = lambda x: (weights * x).sum(dim=-1) / acc.clamp_min(eps)
        # For numerical stability this expectation is computed using log-distance.
        rendering['distance_mean'] = (
            torch.clip(
                torch.nan_to_num(torch.exp(expectation(torch.log(t_mids))), torch.inf),
                tdist[..., 0], tdist[..., -1]))

        # Add an extra fencepost with the far distance at the end of each ray, with
        # whatever weight is needed to make the new weight vector sum to exactly 1
        # (`weights` is only guaranteed to sum to <= 1, not == 1).
        t_aug = torch.cat([tdist, t_far], dim=-1)
        weights_aug = torch.cat([weights, bg_w], dim=-1)

        ps = [5, 50, 95]
        distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)

        for i, p in enumerate(ps):
            s = 'median' if p == 50 else 'percentile_' + str(p)
            rendering['distance_' + s] = distance_percentiles[..., i]

    return rendering
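A small self-contained sketch of how the two helpers above compose; the shapes, densities, and colors below are made-up test values, not part of this file.

import torch

batch, n = 2, 32
tdist = torch.sort(torch.rand(batch, n + 1), dim=-1).values * 5.0     # fenceposts
dirs = torch.nn.functional.normalize(torch.randn(batch, 3), dim=-1)   # ray directions
density = torch.rand(batch, n) * 2.0                                  # per-sample sigma
rgbs = torch.rand(batch, n, 3)                                        # per-sample color

weights, alpha, trans = compute_alpha_weights(density, tdist, dirs)
rendering = volumetric_rendering(
    rgbs, weights, tdist,
    bg_rgbs=torch.zeros(3),             # composite onto a black background
    t_far=torch.full((batch, 1), 5.0),
    compute_extras=False)
print(rendering['rgb'].shape, rendering['depth'].shape)  # (2, 3) and (2,)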
internal/stepfun.py
ADDED
@@ -0,0 +1,403 @@
from internal import math
import numpy as np
import torch


def searchsorted(a, v):
    """Find indices where v should be inserted into a to maintain order.

    Args:
        a: tensor, the sorted reference points that we are scanning to see where v
            should lie.
        v: tensor, the query points that we are pretending to insert into a. Does
            not need to be sorted. All but the last dimensions should match or
            expand to those of a, the last dimension can differ.

    Returns:
        (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the
        range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or
        last index of a.
    """
    i = torch.arange(a.shape[-1], device=a.device)
    v_ge_a = v[..., None, :] >= a[..., :, None]
    idx_lo = torch.max(torch.where(v_ge_a, i[..., :, None], i[..., :1, None]), -2).values
    idx_hi = torch.min(torch.where(~v_ge_a, i[..., :, None], i[..., -1:, None]), -2).values
    return idx_lo, idx_hi


def query(tq, t, y, outside_value=0):
    """Look up the values of the step function (t, y) at locations tq."""
    idx_lo, idx_hi = searchsorted(t, tq)
    yq = torch.where(idx_lo == idx_hi, torch.full_like(idx_hi, outside_value),
                     torch.take_along_dim(y, idx_lo, dim=-1))
    return yq


def inner_outer(t0, t1, y1):
    """Construct inner and outer measures on (t1, y1) for t0."""
    cy1 = torch.cat([torch.zeros_like(y1[..., :1]),
                     torch.cumsum(y1, dim=-1)],
                    dim=-1)
    idx_lo, idx_hi = searchsorted(t1, t0)

    cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1)
    cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1)

    y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]
    y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:],
                           cy1_lo[..., 1:] - cy1_hi[..., :-1], torch.zeros_like(idx_lo[..., 1:]))
    return y0_inner, y0_outer


def lossfun_outer(t, w, t_env, w_env):
    """The proposal weight should be an upper envelope on the nerf weight."""
    eps = torch.finfo(t.dtype).eps
    # eps = 1e-3

    _, w_outer = inner_outer(t, t_env, w_env)
    # We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's
    # more effective to pull w_outer up than it is to push w_inner down.
    # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0.
    return (w - w_outer).clamp_min(0) ** 2 / (w + eps)


def weight_to_pdf(t, w):
    """Turn a vector of weights that sums to 1 into a PDF that integrates to 1."""
    eps = torch.finfo(t.dtype).eps
    return w / (t[..., 1:] - t[..., :-1]).clamp_min(eps)


def pdf_to_weight(t, p):
    """Turn a PDF that integrates to 1 into a vector of weights that sums to 1."""
    return p * (t[..., 1:] - t[..., :-1])


def max_dilate(t, w, dilation, domain=(-torch.inf, torch.inf)):
    """Dilate (via max-pooling) a non-negative step function."""
    t0 = t[..., :-1] - dilation
    t1 = t[..., 1:] + dilation
    t_dilate, _ = torch.sort(torch.cat([t, t0, t1], dim=-1), dim=-1)
    t_dilate = torch.clip(t_dilate, *domain)
    w_dilate = torch.max(
        torch.where(
            (t0[..., None, :] <= t_dilate[..., None])
            & (t1[..., None, :] > t_dilate[..., None]),
            w[..., None, :],
            torch.zeros_like(w[..., None, :]),
        ), dim=-1).values[..., :-1]
    return t_dilate, w_dilate


def max_dilate_weights(t,
                       w,
                       dilation,
                       domain=(-torch.inf, torch.inf),
                       renormalize=False):
    """Dilate (via max-pooling) a set of weights."""
    eps = torch.finfo(w.dtype).eps
    # eps = 1e-3

    p = weight_to_pdf(t, w)
    t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain)
    w_dilate = pdf_to_weight(t_dilate, p_dilate)
    if renormalize:
        w_dilate /= torch.sum(w_dilate, dim=-1, keepdim=True).clamp_min(eps)
    return t_dilate, w_dilate


def integrate_weights(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    The output's size on the last dimension is one greater than that of the input,
    because we're computing the integral corresponding to the endpoints of a step
    function, not the integral of the interior/bin values.

    Args:
        w: Tensor, which will be integrated along the last axis. This is assumed
            to sum to 1 along the last axis, and this function will (silently)
            break if that is not the case.

    Returns:
        cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
    """
    cw = torch.cumsum(w[..., :-1], dim=-1).clamp_max(1)
    shape = cw.shape[:-1] + (1,)
    # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
    cw0 = torch.cat([torch.zeros(shape, device=cw.device), cw,
                     torch.ones(shape, device=cw.device)], dim=-1)
    return cw0


def integrate_weights_np(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    The output's size on the last dimension is one greater than that of the input,
    because we're computing the integral corresponding to the endpoints of a step
    function, not the integral of the interior/bin values.

    Args:
        w: Tensor, which will be integrated along the last axis. This is assumed
            to sum to 1 along the last axis, and this function will (silently)
            break if that is not the case.

    Returns:
        cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
    """
    cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1))
    shape = cw.shape[:-1] + (1,)
    # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
    cw0 = np.concatenate([np.zeros(shape), cw,
                          np.ones(shape)], axis=-1)
    return cw0


def invert_cdf(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
    # Compute the PDF and CDF for each weight vector.
    w = torch.softmax(w_logits, dim=-1)
    cw = integrate_weights(w)
    # Interpolate into the inverse CDF.
    t_new = math.sorted_interp(u, cw, t)
    return t_new


def invert_cdf_np(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
    # Compute the PDF and CDF for each weight vector.
    w = np.exp(w_logits) / np.exp(w_logits).sum(axis=-1, keepdims=True)
    cw = integrate_weights_np(w)
    # Interpolate into the inverse CDF.
    interp_fn = np.interp
    t_new = interp_fn(u, cw, t)
    return t_new


def sample(rand,
           t,
           w_logits,
           num_samples,
           single_jitter=False,
           deterministic_center=False):
    """Piecewise-Constant PDF sampling from a step function.

    Args:
        rand: random number generator (or None for `linspace` sampling).
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted).
        w_logits: [..., num_bins], logits corresponding to bin weights.
        num_samples: int, the number of samples.
        single_jitter: bool, if True, jitter every sample along each ray by the
            same amount in the inverse CDF. Otherwise, jitter each sample
            independently.
        deterministic_center: bool, if False, when `rand` is None return samples
            that linspace the entire PDF. If True, skip the front and back of the
            linspace so that the centers of each PDF interval are returned.

    Returns:
        t_samples: [batch_size, num_samples].
    """
    eps = torch.finfo(t.dtype).eps
    # eps = 1e-3

    device = t.device

    # Draw uniform samples.
    if not rand:
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = torch.linspace(pad, 1. - pad - eps, num_samples, device=device)
        else:
            u = torch.linspace(0, 1. - eps, num_samples, device=device)
        u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
    else:
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u_max = eps + (1 - eps) / num_samples
        max_jitter = (1 - u_max) / (num_samples - 1) - eps
        d = 1 if single_jitter else num_samples
        u = torch.linspace(0, 1 - u_max, num_samples, device=device) + \
            torch.rand(t.shape[:-1] + (d,), device=device) * max_jitter

    return invert_cdf(u, t, w_logits)


def sample_np(rand,
              t,
              w_logits,
              num_samples,
              single_jitter=False,
              deterministic_center=False):
    """Numpy version of sample()."""
    eps = np.finfo(np.float32).eps

    # Draw uniform samples.
    if not rand:
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = np.linspace(pad, 1. - pad - eps, num_samples)
        else:
            u = np.linspace(0, 1. - eps, num_samples)
        u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
    else:
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u_max = eps + (1 - eps) / num_samples
        max_jitter = (1 - u_max) / (num_samples - 1) - eps
        d = 1 if single_jitter else num_samples
        u = np.linspace(0, 1 - u_max, num_samples) + \
            np.random.rand(*t.shape[:-1], d) * max_jitter

    return invert_cdf_np(u, t, w_logits)


def sample_intervals(rand,
                     t,
                     w_logits,
                     num_samples,
                     single_jitter=False,
                     domain=(-torch.inf, torch.inf)):
    """Sample *intervals* (rather than points) from a step function.

    Args:
        rand: random number generator (or None for `linspace` sampling).
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted).
        w_logits: [..., num_bins], logits corresponding to bin weights.
        num_samples: int, the number of intervals to sample.
        single_jitter: bool, if True, jitter every sample along each ray by the
            same amount in the inverse CDF. Otherwise, jitter each sample
            independently.
        domain: (minval, maxval), the range of valid values for `t`.

    Returns:
        t_samples: [batch_size, num_samples].
    """
    if num_samples <= 1:
        raise ValueError(f'num_samples must be > 1, is {num_samples}.')

    # Sample a set of points from the step function.
    centers = sample(
        rand,
        t,
        w_logits,
        num_samples,
        single_jitter,
        deterministic_center=True)

    # The intervals we return will span the midpoints of each adjacent sample.
    mid = (centers[..., 1:] + centers[..., :-1]) / 2

    # Each first/last fencepost is the reflection of the first/last midpoint
    # around the first/last sampled center. We clamp to the limits of the input
    # domain, provided by the caller.
    minval, maxval = domain
    first = (2 * centers[..., :1] - mid[..., :1]).clamp_min(minval)
    last = (2 * centers[..., -1:] - mid[..., -1:]).clamp_max(maxval)

    t_samples = torch.cat([first, mid, last], dim=-1)
    return t_samples


def lossfun_distortion(t, w):
    """Compute iint w[i] w[j] |t[i] - t[j]| di dj."""
    # The loss incurred between all pairs of intervals.
    ut = (t[..., 1:] + t[..., :-1]) / 2
    dut = torch.abs(ut[..., :, None] - ut[..., None, :])
    loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1)

    # The loss incurred within each individual interval with itself.
    loss_intra = torch.sum(w ** 2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3

    return loss_inter + loss_intra


def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):
    """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi])."""
    # Distortion when the intervals do not overlap.
    d_disjoint = torch.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2)

    # Distortion when the intervals overlap.
    d_overlap = (2 *
                 (torch.minimum(t0_hi, t1_hi) ** 3 - torch.maximum(t0_lo, t1_lo) ** 3) +
                 3 * (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) +
                      t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) + t1_hi * t0_lo *
                      (t0_lo - t1_hi) + t1_lo * t0_hi *
                      (t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo))

    # Are the two intervals not overlapping?
    are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)

    return torch.where(are_disjoint, d_disjoint, d_overlap)


def weighted_percentile(t, w, ps):
    """Compute the weighted percentiles of a step function. w's must sum to 1."""
    cw = integrate_weights(w)
    # We want to interpolate into the integrated weights according to `ps`.
    fn = lambda cw_i, t_i: math.sorted_interp(torch.tensor(ps, device=t.device) / 100, cw_i, t_i)
    # Flatten leading dimensions so fn applies over a single batch dimension.
    cw_mat = cw.reshape([-1, cw.shape[-1]])
    t_mat = t.reshape([-1, t.shape[-1]])
    wprctile_mat = fn(cw_mat, t_mat)
    wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),))
    return wprctile


def resample(t, tp, vp, use_avg=False):
    """Resample a step function defined by (tp, vp) into intervals t.

    Args:
        t: tensor with shape (..., n+1), the endpoints to resample into.
        tp: tensor with shape (..., m+1), the endpoints of the step function being
            resampled.
        vp: tensor with shape (..., m), the values of the step function being
            resampled.
        use_avg: bool, if False, return the sum of the step function for each
            interval in `t`. If True, return the average, weighted by the width of
            each interval in `t`.

    Returns:
        v: tensor with shape (..., n), the values of the resampled step function.
    """
    eps = torch.finfo(t.dtype).eps
    # eps = 1e-3

    if use_avg:
        wp = torch.diff(tp, dim=-1)
        v_numer = resample(t, tp, vp * wp, use_avg=False)
        v_denom = resample(t, tp, wp, use_avg=False)
        v = v_numer / v_denom.clamp_min(eps)
        return v

    acc = torch.cumsum(vp, dim=-1)
    acc0 = torch.cat([torch.zeros(acc.shape[:-1] + (1,), device=acc.device), acc], dim=-1)
    acc0_resampled = math.sorted_interp(t, tp, acc0)
    v = torch.diff(acc0_resampled, dim=-1)
    return v


def resample_np(t, tp, vp, use_avg=False):
    """Numpy version of resample()."""
    eps = np.finfo(t.dtype).eps
    if use_avg:
        wp = np.diff(tp, axis=-1)
        v_numer = resample_np(t, tp, vp * wp, use_avg=False)
        v_denom = resample_np(t, tp, wp, use_avg=False)
        v = v_numer / np.maximum(eps, v_denom)
        return v

    acc = np.cumsum(vp, axis=-1)
    acc0 = np.concatenate([np.zeros(acc.shape[:-1] + (1,)), acc], axis=-1)
    acc0_resampled = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0)
    v = np.diff(acc0_resampled, axis=-1)
    return v


def blur_stepfun(x, y, r):
    """Convolve a step function (x, y) with a box filter of half-width r."""
    xr, xr_idx = torch.sort(torch.cat([x - r, x + r], dim=-1))
    y1 = (torch.cat([y, torch.zeros_like(y[..., :1])], dim=-1) -
          torch.cat([torch.zeros_like(y[..., :1]), y], dim=-1)) / (2 * r)
    y2 = torch.cat([y1, -y1], dim=-1).take_along_dim(xr_idx[..., :-1], dim=-1)
    yr = torch.cumsum((xr[..., 1:] - xr[..., :-1]) *
                      torch.cumsum(y2, dim=-1), dim=-1).clamp_min(0)
    yr = torch.cat([torch.zeros_like(yr[..., :1]), yr], dim=-1)
    return xr, yr
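A toy sketch of the histogram resampling workflow these helpers implement; the bin counts, logits, and domain below are illustrative assumptions, not values used by this file.

import torch

batch, n_bins, n_samples = 4, 16, 8
t = torch.sort(torch.rand(batch, n_bins + 1), dim=-1).values   # sorted fenceposts
w_logits = torch.randn(batch, n_bins)                          # unnormalized bin weights

# Stratified draws through the inverse CDF of softmax(w_logits) over bins `t`.
t_rand = sample(True, t, w_logits, n_samples, single_jitter=True)

# Deterministic interval resampling, as used when refining proposal intervals.
t_intervals = sample_intervals(False, t, w_logits, n_samples, domain=(0., 1.))
print(t_rand.shape, t_intervals.shape)  # (4, 8) and (4, 9)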
internal/train_utils.py
ADDED
@@ -0,0 +1,263 @@
import collections
import functools

import torch.optim
from internal import camera_utils
from internal import configs
from internal import datasets
from internal import image
from internal import math
from internal import models
from internal import ref_utils
from internal import stepfun
from internal import utils
import numpy as np
from torch.utils._pytree import tree_map, tree_flatten
from torch_scatter import segment_coo


class GradientScaler(torch.autograd.Function):
    @staticmethod
    def forward(ctx, colors, sigmas, ray_dist):
        ctx.save_for_backward(ray_dist)
        return colors, sigmas

    @staticmethod
    def backward(ctx, grad_output_colors, grad_output_sigmas):
        (ray_dist,) = ctx.saved_tensors
        scaling = torch.square(ray_dist).clamp(0, 1)
        return grad_output_colors * scaling[..., None], grad_output_sigmas * scaling, None


def tree_reduce(fn, tree, initializer=0):
    return functools.reduce(fn, tree_flatten(tree)[0], initializer)


def tree_sum(tree):
    return tree_reduce(lambda x, y: x + y, tree, initializer=0)


def tree_norm_sq(tree):
    return tree_sum(tree_map(lambda x: torch.sum(x ** 2), tree))


def tree_norm(tree):
    return torch.sqrt(tree_norm_sq(tree))


def tree_abs_max(tree):
    return tree_reduce(
        lambda x, y: max(x, torch.abs(y).max().item()), tree, initializer=0)


def tree_len(tree):
    return tree_sum(tree_map(lambda z: np.prod(z.shape), tree))


def summarize_tree(tree, fn, ancestry=(), max_depth=3):
    """Flatten 'tree' while 'fn'-ing values and formatting keys like/this."""
    stats = {}
    for k, v in tree.items():
        name = ancestry + (k,)
        stats['/'.join(name)] = fn(v)
        if hasattr(v, 'items') and len(ancestry) < (max_depth - 1):
            stats.update(summarize_tree(v, fn, ancestry=name, max_depth=max_depth))
    return stats


def compute_data_loss(batch, renderings, config):
    """Computes data loss terms for RGB, normal, and depth outputs."""
    data_losses = []
    stats = collections.defaultdict(lambda: [])

    # lossmult can be used to apply a weight to each ray in the batch.
    # For example: masking out rays, applying the Bayer mosaic mask, upweighting
    # rays from lower resolution images and so on.
    lossmult = batch['lossmult']
    lossmult = torch.broadcast_to(lossmult, batch['rgb'][..., :3].shape)
    if config.disable_multiscale_loss:
        lossmult = torch.ones_like(lossmult)

    for rendering in renderings:
        resid_sq = (rendering['rgb'] - batch['rgb'][..., :3]) ** 2
        denom = lossmult.sum()
        stats['mses'].append(((lossmult * resid_sq).sum() / denom).item())

        if config.data_loss_type == 'mse':
            # Mean-squared error (L2) loss.
            data_loss = resid_sq
        elif config.data_loss_type == 'charb':
            # Charbonnier loss.
            data_loss = torch.sqrt(resid_sq + config.charb_padding ** 2)
        elif config.data_loss_type == 'rawnerf':
            # Clip raw values against 1 to match sensor overexposure behavior.
            rgb_render_clip = rendering['rgb'].clamp_max(1)
            resid_sq_clip = (rgb_render_clip - batch['rgb'][..., :3]) ** 2
            # Scale by gradient of log tonemapping curve.
            scaling_grad = 1. / (1e-3 + rgb_render_clip.detach())
            # Reweighted L2 loss.
            data_loss = resid_sq_clip * scaling_grad ** 2
        else:
            assert False
        data_losses.append((lossmult * data_loss).sum() / denom)

        if config.compute_disp_metrics:
            # Using mean to compute disparity, but other distance statistics can
            # be used instead.
            disp = 1 / (1 + rendering['distance_mean'])
            stats['disparity_mses'].append(((disp - batch['disps']) ** 2).mean().item())

        if config.compute_normal_metrics:
            if 'normals' in rendering:
                weights = rendering['acc'] * batch['alphas']
                normalized_normals_gt = ref_utils.l2_normalize(batch['normals'])
                normalized_normals = ref_utils.l2_normalize(rendering['normals'])
                normal_mae = ref_utils.compute_weighted_mae(weights, normalized_normals,
                                                            normalized_normals_gt)
            else:
                # If normals are not computed, set MAE to NaN.
                normal_mae = torch.nan
            stats['normal_maes'].append(normal_mae.item())

    loss = (
        config.data_coarse_loss_mult * sum(data_losses[:-1]) +
        config.data_loss_mult * data_losses[-1])

    stats = {k: np.array(stats[k]) for k in stats}
    return loss, stats


def interlevel_loss(ray_history, config):
    """Computes the interlevel loss defined in mip-NeRF 360."""
    # Stop the gradient from the interlevel loss onto the NeRF MLP.
    last_ray_results = ray_history[-1]
    c = last_ray_results['sdist'].detach()
    w = last_ray_results['weights'].detach()
    loss_interlevel = 0.
    for ray_results in ray_history[:-1]:
        cp = ray_results['sdist']
        wp = ray_results['weights']
        loss_interlevel += stepfun.lossfun_outer(c, w, cp, wp).mean()
    return config.interlevel_loss_mult * loss_interlevel


def anti_interlevel_loss(ray_history, config):
    """Computes the anti-aliased (blurred) interlevel loss used in Zip-NeRF."""
    last_ray_results = ray_history[-1]
    c = last_ray_results['sdist'].detach()
    w = last_ray_results['weights'].detach()
    w_normalize = w / (c[..., 1:] - c[..., :-1])
    loss_anti_interlevel = 0.
    for i, ray_results in enumerate(ray_history[:-1]):
        cp = ray_results['sdist']
        wp = ray_results['weights']
        c_, w_ = stepfun.blur_stepfun(c, w_normalize, config.pulse_width[i])

        # Piecewise linear PDF to piecewise quadratic CDF.
        area = 0.5 * (w_[..., 1:] + w_[..., :-1]) * (c_[..., 1:] - c_[..., :-1])

        cdf = torch.cat([torch.zeros_like(area[..., :1]), torch.cumsum(area, dim=-1)], dim=-1)

        # Query piecewise quadratic interpolation.
        cdf_interp = math.sorted_interp_quad(cp, c_, w_, cdf)
        # Difference between adjacent interpolated values.
        w_s = torch.diff(cdf_interp, dim=-1)

        loss_anti_interlevel += ((w_s - wp).clamp_min(0) ** 2 / (wp + 1e-5)).mean()
    return config.anti_interlevel_loss_mult * loss_anti_interlevel


def distortion_loss(ray_history, config):
    """Computes the distortion loss regularizer defined in mip-NeRF 360."""
    last_ray_results = ray_history[-1]
    c = last_ray_results['sdist']
    w = last_ray_results['weights']
    loss = stepfun.lossfun_distortion(c, w).mean()
    return config.distortion_loss_mult * loss


def orientation_loss(batch, model, ray_history, config):
    """Computes the orientation loss regularizer defined in ref-NeRF."""
    total_loss = 0.
    for i, ray_results in enumerate(ray_history):
        w = ray_results['weights']
        n = ray_results[config.orientation_loss_target]
        if n is None:
            raise ValueError('Normals cannot be None if orientation loss is on.')
        # Negate viewdirs to represent normalized vectors from point to camera.
        v = -1. * batch['viewdirs']
        n_dot_v = (n * v[..., None, :]).sum(dim=-1)
        loss = (w * n_dot_v.clamp_min(0) ** 2).sum(dim=-1).mean()
        if i < model.num_levels - 1:
            total_loss += config.orientation_coarse_loss_mult * loss
        else:
            total_loss += config.orientation_loss_mult * loss
    return total_loss


def hash_decay_loss(ray_history, config):
    """Accumulates the weight decay regularizer on the hash grid features."""
    total_loss = 0.
    for i, ray_results in enumerate(ray_history):
        total_loss += config.hash_decay_mults * ray_results['loss_hash_decay']
    return total_loss


def opacity_loss(renderings, config):
    """Entropy-style penalty pushing accumulated opacity toward 0 or 1."""
    total_loss = 0.
    for i, rendering in enumerate(renderings):
        o = rendering['acc']
        total_loss += config.opacity_loss_mult * (-o * torch.log(o + 1e-5)).mean()
    return total_loss


def predicted_normal_loss(model, ray_history, config):
    """Computes the predicted normal supervision loss defined in ref-NeRF."""
    total_loss = 0.
    for i, ray_results in enumerate(ray_history):
        w = ray_results['weights']
        n = ray_results['normals']
        n_pred = ray_results['normals_pred']
        if n is None or n_pred is None:
            raise ValueError(
                'Predicted normals and gradient normals cannot be None if '
                'predicted normal loss is on.')
        loss = torch.mean((w * (1.0 - torch.sum(n * n_pred, dim=-1))).sum(dim=-1))
        if i < model.num_levels - 1:
            total_loss += config.predicted_normal_coarse_loss_mult * loss
        else:
            total_loss += config.predicted_normal_loss_mult * loss
    return total_loss


def clip_gradients(model, accelerator, config):
    """Clips gradients of the model based on norm and max value."""
    if config.grad_max_norm > 0 and accelerator.sync_gradients:
        accelerator.clip_grad_norm_(model.parameters(), config.grad_max_norm)

    if config.grad_max_val > 0 and accelerator.sync_gradients:
        accelerator.clip_grad_value_(model.parameters(), config.grad_max_val)

    for param in model.parameters():
        param.grad.nan_to_num_()


def create_optimizer(config: configs.Config, model):
    """Creates the Adam optimizer and learning rate schedule for model training."""
    adam_kwargs = {
        'betas': [config.adam_beta1, config.adam_beta2],
        'eps': config.adam_eps,
    }
    lr_kwargs = {
        'max_steps': config.max_steps,
        'lr_delay_steps': config.lr_delay_steps,
        'lr_delay_mult': config.lr_delay_mult,
    }

    lr_fn_main = lambda step: math.learning_rate_decay(
        step,
        lr_init=config.lr_init,
        lr_final=config.lr_final,
        **lr_kwargs)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr_init, **adam_kwargs)

    return optimizer, lr_fn_main
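To illustrate how these loss helpers are typically combined in a training step, here is a hedged sketch; the names `renderings`, `ray_history`, `batch`, and the exact set of active terms are assumptions about the surrounding training loop, not code from this commit.

def total_loss_sketch(model, batch, renderings, ray_history, config):
    # Photometric data term plus the regularizers defined above. Which terms
    # contribute depends on the config multipliers; zero multipliers add nothing.
    loss, stats = compute_data_loss(batch, renderings, config)
    if config.interlevel_loss_mult > 0:
        loss = loss + interlevel_loss(ray_history, config)
    if config.anti_interlevel_loss_mult > 0:
        loss = loss + anti_interlevel_loss(ray_history, config)
    if config.distortion_loss_mult > 0:
        loss = loss + distortion_loss(ray_history, config)
    if config.opacity_loss_mult > 0:
        loss = loss + opacity_loss(renderings, config)
    if config.hash_decay_mults > 0:
        loss = loss + hash_decay_loss(ray_history, config)
    return loss, stats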
|