sseung0703 committed
Commit • e8c4ed3
Parent(s): 8efbff1
update
Browse files
- LICENSE +201 -0
- README.md +21 -40
- __init__.py +15 -0
- assets/chair.png +0 -0
- assets/drawing.png +0 -0
- assets/drum.png +0 -0
- assets/hotdog.png +0 -0
- configs/blender.yaml +9 -0
- configs/demo.yaml +10 -0
- configs/diet_nerf_tpu_vm_4shot.yaml +18 -0
- configs/diet_nerf_tpu_vm_few_shot.yaml +18 -0
- configs/diet_nerf_tpu_vm_test.yaml +18 -0
- configs/eval_diet_nerf_tpu_vm_few_shot.yaml +22 -0
- configs/nerf_tpu_vm_4shot.yaml +18 -0
- configs/nerf_tpu_vm_few_shot.yaml +18 -0
- configs/orig_nerf_tpu_vm_full.yaml +13 -0
- configs/orig_nerf_tpu_vm_test.yaml +13 -0
- eval.py +241 -0
- eval.sh +44 -0
- example_data/imgs/r_0.png +0 -0
- example_data/transforms_test.json +1 -0
- example_data/transforms_train.json +1 -0
- fork-of-first-touch-of-nerf-in-jax.ipynb +0 -0
- nerf/__init__.py +15 -0
- nerf/__pycache__/__init__.cpython-37.pyc +0 -0
- nerf/__pycache__/clip_utils.cpython-37.pyc +0 -0
- nerf/__pycache__/datasets.cpython-37.pyc +0 -0
- nerf/__pycache__/model_utils.cpython-37.pyc +0 -0
- nerf/__pycache__/models.cpython-37.pyc +0 -0
- nerf/__pycache__/utils.cpython-37.pyc +0 -0
- nerf/clip_utils.py +125 -0
- nerf/datasets.py +558 -0
- nerf/model_utils.py +334 -0
- nerf/models.py +264 -0
- nerf/utils.py +454 -0
- requirements.txt +14 -0
- run.sh +33 -0
- train.py +347 -0
- train.sh +34 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   [Standard Apache License 2.0 text added verbatim: Terms and Conditions
+    (Sections 1-9), the appendix on how to apply the license, and the
+    boilerplate copyright notice - 201 lines in total; see
+    http://www.apache.org/licenses/LICENSE-2.0]
README.md
CHANGED
@@ -1,5 +1,5 @@
 # Putting NeRF on a Diet: Semantically Consistent Few-Shot View Synthesis Implementation
 
 <p align="center"><img width="450" alt="Screenshot 2021-07-04 4.11.51 PM" src="https://user-images.githubusercontent.com/77657524/126361638-4aad58e8-4efb-4fc5-bf78-f53d03799e1e.png"></p>
 
 The PyTorch and JAX/Flax based implementation of the paper [Putting NeRF on a Diet: Ajay Jain, Matthew Tancik, Pieter Abbeel, arXiv: https://arxiv.org/abs/2104.00677]
@@ -94,57 +94,38 @@ python -m train \
 ```
 You can toggle the semantic loss by "use_semantic_loss" in the configuration files.
 
-## 💎
-
-### ❗️ Performance Tables
-#### 4 Shot Blender Dataset PSNR Result
-
-| Scene | Chair | Drums | Ficus | Hotdog | Lego | Materials | Mic | Ship | Mean |
-|---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
-| NeRF | 33.00 | 25.01 | 30.13 | 36.18 | 32.54 | 29.62 | 32.91 | 28.65 | 31.01 |
-| DietNeRF | **34.08** | **25.03** | **30.43** | **36.92** | **33.28** | **29.91** | **34.53** | **29.36** | **31.69** |
-
-#### Loss Graph Comparison btw NeRF vs DietNeRF in Drum Scene
-
-### ❗ Rendering GIF images by
+## 💎 Experiment Results
+
+### ❗ Rendered images by 8-shot learned Diet-NeRF (200000 iter)
+### CHAIR / HOTDOG / DRUM
+
+<p align="center">
+<table>
+<tr>
+<td><img alt="" src="./assets/chair.png" width="300"/></td><td><img alt="" src="./assets/hotdog.png" width="300"/></td><td><img alt="" src="./assets/drum.png" width="300"/></td>
+</tr>
+</table></p>
+
+### ❗ Rendering GIF images by 4-shot learned Diet-NeRF and NeRF (50000 iter)
 
 DietNeRF has a strong capacity to generalise on novel and challenging views with EXTREMELY SMALL TRAINING SAMPLES!
 The animations below show the performance difference between DietNeRF (left) and NeRF (right) with only 4 training images:
 
-#### SHIP
-![Text](./assets/ship-dietnerf.gif) ![Alt Text](./assets/ship-nerf.gif)
-
-#### LEGO
-![Text](./assets/ship-dietnerf.gif) ![Alt Text](./assets/ship-nerf.gif)
-
-#### HOTDOG
-![Text](./assets/ship-dietnerf.gif) ![Alt Text](./assets/ship-nerf.gif)
-
-### ❗ Rendered
-
-#### SHIP
-
-@ will be filled
-
-### ❗ Rendered examples by occluded 14-shot learned NeRF and Diet-NeRF
-This result is on the quite initial state and expected to be improved.
-
-#### Training poses
-<img width="1400" src="https://user-images.githubusercontent.com/26036843/126111980-4f332c87-a7f0-42e0-a355-8e77621bbca4.png">
-
-#### Rendered novel poses
-<img width="800" src="https://user-images.githubusercontent.com/26036843/126113080-a6a48f3d-2629-4efc-a740-fe908ca6b5c3.png">
+### ❗ Rendered GIF by occluded 14-shot learned NeRF and Diet-NeRF (100000 iter)
+We made an artificial occlusion on the right side of each training image.
+The reconstruction quality can be compared with this experiment.
+Diet-NeRF shows better quality than the original NeRF when the view is occluded.
+
+#### SHIP
+<p align="center">
+<table>
+<tr>
+<td><img alt="" src="./assets/ship-dietnerf.gif" width="300"/></td><td><img alt="" src="./assets/ship-nerf.gif" width="300"/></td>
+</tr>
+</table></p>
 
 ## 🤩 Demo
 
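Editor's note: the semantic loss toggled by "use_semantic_loss" above is defined in nerf/clip_utils.py later in this commit. As a minimal, purely illustrative sketch (not part of the commit), the loss is one minus the inner product of two unit-normalized CLIP image embeddings; the embeddings below are stand-ins rather than real CLIP features.

import jax.numpy as jnp

def sc_loss(src_embedding, target_embedding):
    # Both embeddings are unit-normalized, so this equals 1 - cosine similarity.
    src = src_embedding / jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
    tgt = target_embedding / jnp.linalg.norm(target_embedding, axis=-1, keepdims=True)
    return 1.0 - jnp.sum(src * tgt)

print(sc_loss(jnp.ones((1, 512)), jnp.ones((1, 512))))  # identical embeddings -> 0.0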
__init__.py
ADDED
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
assets/chair.png
ADDED
assets/drawing.png
ADDED
assets/drum.png
ADDED
assets/hotdog.png
ADDED
configs/blender.yaml
ADDED
@@ -0,0 +1,9 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 4096
+randomized: true
configs/demo.yaml
ADDED
@@ -0,0 +1,10 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 50000
configs/diet_nerf_tpu_vm_4shot.yaml
ADDED
@@ -0,0 +1,18 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 200000
+print_every: 100
+render_every: 500
+save_every: 5000
+use_semantic_loss: true
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float16
+sc_loss_every: 16
+few_shot: 4
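Editor's note: a minimal sketch (not part of this commit) of inspecting one of these configuration files with PyYAML. PyYAML itself is an assumption here; the training and evaluation scripts in this commit consume the files through utils.update_flags(FLAGS), whose internals are not shown in this diff.

import yaml

with open("configs/diet_nerf_tpu_vm_4shot.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["use_semantic_loss"] is True   # toggles the CLIP consistency loss
assert cfg["few_shot"] == 4               # number of training views kept
print(cfg)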
configs/diet_nerf_tpu_vm_few_shot.yaml
ADDED
@@ -0,0 +1,18 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 200000
+print_every: 100
+render_every: 500
+save_every: 5000
+use_semantic_loss: true
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float16
+sc_loss_every: 16
+few_shot: 8
configs/diet_nerf_tpu_vm_test.yaml
ADDED
@@ -0,0 +1,18 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 64
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1026
+randomized: true
+max_steps: 200000
+print_every: 100
+render_every: 1000
+save_every: 5000
+use_semantic_loss: true
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float16
+sc_loss_every: 16
+few_shot: -1
configs/eval_diet_nerf_tpu_vm_few_shot.yaml
ADDED
@@ -0,0 +1,22 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 200000
+print_every: 100
+render_every: 5000
+save_every: 5000
+use_semantic_loss: true
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float32
+sc_loss_factor: 4
+sc_loss_every: 16
+sc_loss_mult: 10
+few_shot: 8
+spherify: True
+lindisp: True
configs/nerf_tpu_vm_4shot.yaml
ADDED
@@ -0,0 +1,18 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 200000
+print_every: 100
+render_every: 500
+save_every: 5000
+use_semantic_loss: false
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float32
+sc_loss_every: 16
+few_shot: 4
configs/nerf_tpu_vm_few_shot.yaml
ADDED
@@ -0,0 +1,18 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 200000
+print_every: 100
+render_every: 500
+save_every: 5000
+use_semantic_loss: false
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float32
+sc_loss_every: 16
+few_shot: 8
configs/orig_nerf_tpu_vm_full.yaml
ADDED
@@ -0,0 +1,13 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 200000
+print_every: 1000
+render_every: 5000
+save_every: 5000
configs/orig_nerf_tpu_vm_test.yaml
ADDED
@@ -0,0 +1,13 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 200000
+print_every: 100
+render_every: 500
+save_every: 500
eval.py
ADDED
@@ -0,0 +1,241 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Evaluation script for Nerf."""
+import math
+import glob
+import os
+from os import path
+import functools
+
+from absl import app
+from absl import flags
+import flax
+from flax.metrics import tensorboard
+from flax.training import checkpoints
+import jax
+from jax import random
+import tensorflow as tf
+
+from tqdm import tqdm
+import cv2
+import numpy as np
+from PIL import Image
+
+from nerf import datasets
+from nerf import models
+from nerf import utils
+from nerf import clip_utils
+
+FLAGS = flags.FLAGS
+utils.define_flags()
+
+
+def compute_lpips(image1, image2, model):
+  """Compute the LPIPS metric."""
+  # The LPIPS model expects a batch dimension.
+  return model(
+      tf.convert_to_tensor(image1[None, Ellipsis]),
+      tf.convert_to_tensor(image2[None, Ellipsis]))[0]
+
+
+def predict_to_image(pred_out):
+  image_arr = np.array(np.clip(pred_out, 0., 1.) * 255.).astype(np.uint8)
+  return Image.fromarray(image_arr)
+
+
+def main(unused_argv):
+  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
+  # LPIPS computation or dataset loading.
+  tf.config.experimental.set_visible_devices([], "GPU")
+  tf.config.experimental.set_visible_devices([], "TPU")
+
+  #wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
+
+  rng = random.PRNGKey(20200823)
+
+  if FLAGS.config is not None:
+    utils.update_flags(FLAGS)
+  if FLAGS.train_dir is None:
+    raise ValueError("train_dir must be set. None set now.")
+  if FLAGS.data_dir is None:
+    raise ValueError("data_dir must be set. None set now.")
+
+  dataset = datasets.get_dataset("test", FLAGS)
+  rng, key = random.split(rng)
+  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
+  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
+  state = utils.TrainState(optimizer=optimizer)
+  del optimizer, init_variables
+  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
+
+  # Rendering is forced to be deterministic even if training was randomized, as
+  # this eliminates "speckle" artifacts.
+  def render_fn(variables, key_0, key_1, rays):
+    return model.apply(variables, key_0, key_1, rays, False)
+
+  # pmap over only the data input.
+  render_pfn = jax.pmap(
+      render_fn,
+      in_axes=(None, None, None, 0),
+      donate_argnums=3,
+      axis_name="batch",
+  )
+
+  # Compiling to the CPU because it's faster and more accurate.
+  ssim_fn = jax.jit(
+      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
+
+  last_step = 0
+  out_dir = path.join(FLAGS.train_dir, "path_renders" if FLAGS.render_path else "test_preds")
+  os.makedirs(out_dir, exist_ok=True)
+  if FLAGS.save_output:
+    print(f'eval output will be saved: {out_dir}')
+  else:
+    print(f'eval output will not be saved')
+
+  if not FLAGS.eval_once:
+    summary_writer = tensorboard.SummaryWriter(
+        path.join(FLAGS.train_dir, "eval"))
+
+  def generate_spinning_gif(radius, phi, gif_fn, frame_n):
+    _rng = random.PRNGKey(0)
+    partial_render_fn = functools.partial(render_pfn, state.optimizer.target)
+    gif_images = []
+    for theta in tqdm(np.linspace(-math.pi, math.pi, frame_n)):
+      camtoworld = np.array(clip_utils.pose_spherical(radius, theta, phi))
+      rays = dataset.camtoworld_matrix_to_rays(camtoworld, downsample=4)
+      _rng, key0, key1 = random.split(_rng, 3)
+      color, _, _ = utils.render_image(partial_render_fn, rays,
+                                       _rng, False, chunk=4096)
+      image = predict_to_image(color)
+      gif_images.append(image)
+    gif_images[0].save(gif_fn, save_all=True,
+                       append_images=gif_images,
+                       duration=100, loop=0)
+    return gif_images
+
+  if FLAGS.generate_gif_only:
+    print('generate GIF file only')
+    _radius = 4.
+    _phi = (30 * math.pi) / 180
+    _gif_fn = os.path.join(out_dir, 'spinning.gif')
+    generate_spinning_gif(_radius, _phi, _gif_fn, frame_n=30)
+    print(f'GIF file for spinning views written: {_gif_fn}')
+    return
+  else:
+    print('generate GIF file AND evaluate model performance')
+
+  is_gif_written = False
+  while True:
+    step = int(state.optimizer.state.step)
+    if step <= last_step:
+      continue
+    if FLAGS.save_output and (not utils.isdir(out_dir)):
+      utils.makedirs(out_dir)
+    psnr_values = []
+    ssim_values = []
+    #lpips_values = []
+    if not FLAGS.eval_once:
+      showcase_index = np.random.randint(0, dataset.size)
+    for idx in range(dataset.size):
+      print(f"Evaluating {idx + 1}/{dataset.size}")
+      batch = next(dataset)
+      pred_color, pred_disp, pred_acc = utils.render_image(
+          functools.partial(render_pfn, state.optimizer.target),
+          batch["rays"],
+          rng,
+          FLAGS.dataset == "llff",
+          chunk=FLAGS.chunk)
+      if jax.host_id() != 0:  # Only record via host 0.
+        continue
+      if not FLAGS.eval_once and idx == showcase_index:
+        showcase_color = pred_color
+        showcase_disp = pred_disp
+        showcase_acc = pred_acc
+        if not FLAGS.render_path:
+          showcase_gt = batch["pixels"]
+      if not FLAGS.render_path:
+        psnr = utils.compute_psnr(((pred_color - batch["pixels"]) ** 2).mean())
+        ssim = ssim_fn(pred_color, batch["pixels"])
+        #lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
+        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
+        psnr_values.append(float(psnr))
+        ssim_values.append(float(ssim))
+        #lpips_values.append(float(lpips))
+      if FLAGS.save_output:
+        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
+        utils.save_img(pred_disp[Ellipsis, 0],
+                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
+    if (not FLAGS.eval_once) and (jax.host_id() == 0):
+      summary_writer.image("pred_color", showcase_color, step)
+      summary_writer.image("pred_disp", showcase_disp, step)
+      summary_writer.image("pred_acc", showcase_acc, step)
+      if not FLAGS.render_path:
+        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
+        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
+        #summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step)
+        summary_writer.image("target", showcase_gt, step)
+
+    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
+      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
+        f.write(" ".join([str(v) for v in psnr_values]))
+      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
+        f.write(" ".join([str(v) for v in ssim_values]))
+      #with utils.open_file(path.join(out_dir, f"lpips_{step}.txt"), "w") as f:
+      #  f.write(" ".join([str(v) for v in lpips_values]))
+      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
+        f.write("{}".format(np.mean(np.array(psnr_values))))
+      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
+        f.write("{}".format(np.mean(np.array(ssim_values))))
+      #with utils.open_file(path.join(out_dir, "lpips.txt"), "w") as f:
+      #  f.write("{}".format(np.mean(np.array(lpips_values))))
+      print(f'performance metrics written as txt files: {out_dir}')
+
+      imglist = glob.glob(os.path.join(out_dir, "[0-9][0-9][0-9].png"))
+      sorted_files = sorted(imglist, key=lambda x: int(x.split('/')[-1].split('.')[0]))
+      fourcc = cv2.VideoWriter_fourcc(*'MP4V')
+      fps = 10.0
+      img = cv2.imread(sorted_files[0], cv2.IMREAD_COLOR)
+      video_fn = os.path.join(out_dir, "rendering_video.mp4")
+      out = cv2.VideoWriter(video_fn, fourcc, fps,
+                            (img.shape[1], img.shape[0]))
+
+      for i in range(len(sorted_files)):
+        img = cv2.imread(sorted_files[i], cv2.IMREAD_COLOR)
+        out.write(img)
+      out.release()
+      print(f'video file written: {video_fn}')
+
+      # write gif file for spinning views of a scene
+      if not is_gif_written:
+        _radius = 4.
+        _phi = (30 * math.pi) / 180
+        _gif_fn = os.path.join(out_dir, 'spinning.gif')
+        generate_spinning_gif(_radius, _phi, _gif_fn, frame_n=30)
+        print(f'GIF file for spinning views written: {_gif_fn}')
+        is_gif_written = True
+
+    if FLAGS.eval_once:
+      break
+    if int(step) >= FLAGS.max_steps:
+      break
+    last_step = step
+
+
+if __name__ == "__main__":
+  app.run(main)
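Editor's note: eval.py delegates the PSNR computation to utils.compute_psnr, which is not shown in this diff. For images in [0, 1], PSNR is -10 * log10(MSE); the short NumPy sketch below is only an illustration of that relationship, not the repository's implementation.

import numpy as np

def psnr_from_mse(mse: float) -> float:
    # PSNR in dB, assuming a peak signal value of 1.0.
    return -10.0 * np.log10(mse)

pred = np.full((4, 4, 3), 0.5)
gt = np.full((4, 4, 3), 0.6)
mse = float(((pred - gt) ** 2).mean())  # 0.01 for a uniform 0.1 error
print(psnr_from_mse(mse))               # ~20 dB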
eval.sh
ADDED
@@ -0,0 +1,44 @@
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#!/bin/bash
+CONFIG=$1
+DATA_ROOT=$2
+ROOT_DIR=/tmp/jaxnerf/"$CONFIG"
+if [ $CONFIG == "llff" ]
+then
+  SCENES="room fern leaves fortress orchids flower trex horns"
+  DATA_FOLDER="nerf_llff_data"
+else
+  SCENES="lego chair drums ficus hotdog materials mic ship"
+  DATA_FOLDER="nerf_synthetic"
+fi
+
+# launch evaluation jobs for all scenes.
+for scene in $SCENES; do
+  python -m jaxnerf.eval \
+    --data_dir="$DATA_ROOT"/"$DATA_FOLDER"/"$scene" \
+    --train_dir="$ROOT_DIR"/"$scene" \
+    --chunk=4096 \
+    --config=configs/"$CONFIG"
+done
+
+# collect PSNR of all scenes.
+touch "$ROOT_DIR"/psnr.txt
+for scene in $SCENES; do
+  printf "${scene}: " >> "$ROOT_DIR"/psnr.txt
+  cat "$ROOT_DIR"/"$scene"/test_preds/psnr.txt >> \
+    "$ROOT_DIR"/psnr.txt
+  printf $'\n' >> "$ROOT_DIR"/psnr.txt
+done
example_data/imgs/r_0.png
ADDED
example_data/transforms_test.json
ADDED
@@ -0,0 +1 @@
+{"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}
example_data/transforms_train.json
ADDED
@@ -0,0 +1 @@
+{"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}
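Editor's note: a small sketch (not part of this commit) of how the Blender loader added in nerf/datasets.py below turns the "camera_angle_x" field of these transforms_*.json files into a pinhole focal length. The 800-pixel image width is an assumption for illustration; the loader reads the real size from the PNGs.

import numpy as np

camera_angle_x = 0.6911112070083618  # from example_data/transforms_train.json
w = 800                              # assumed image width in pixels
focal = 0.5 * w / np.tan(0.5 * camera_angle_x)  # same formula as Blender._load_renderings
print(focal)  # ~1111.1 pixels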
fork-of-first-touch-of-nerf-in-jax.ipynb
ADDED
The diff for this file is too large to render.
nerf/__init__.py
ADDED
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
nerf/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (137 Bytes)
nerf/__pycache__/clip_utils.cpython-37.pyc
ADDED
Binary file (5.16 kB)
nerf/__pycache__/datasets.cpython-37.pyc
ADDED
Binary file (18.3 kB)
nerf/__pycache__/model_utils.cpython-37.pyc
ADDED
Binary file (10 kB)
nerf/__pycache__/models.cpython-37.pyc
ADDED
Binary file (5.08 kB)
nerf/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (15.8 kB)
nerf/clip_utils.py
ADDED
@@ -0,0 +1,125 @@
+import math
+from typing import Optional
+from absl import flags
+from functools import partial
+
+import jax
+from jax import random
+import jax.numpy as jnp
+import numpy as np
+from transformers import FlaxCLIPModel
+
+from nerf import utils
+
+FLAGS = flags.FLAGS
+
+@partial(jax.jit, static_argnums=[0])
+def semantic_loss(clip_model, src_image, target_embedding):
+    #c_image = utils.unshard(src_image[0])
+    f_image = utils.unshard(src_image[-1])
+
+    w = int(math.sqrt(src_image[-1].size//3))
+    #c_image = c_image.reshape([w, w, 3])
+    f_image = f_image.reshape([w, w, 3])
+
+    src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.expand_dims(f_image,0).transpose(0, 3, 1, 2)))
+    #src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
+    src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
+    sc_loss = 1 - jnp.sum(src_embedding * target_embedding)
+    return sc_loss, f_image
+
+def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
+    random_rays = jax.tree_map(lambda x: utils.shard(x).astype(jnp.float16), batch["random_rays"])
+    target_embedding = batch["embedding"].astype(jnp.float16)
+    rng, key_0, key_1 = random.split(rng, 3)
+
+    def loss_fn(variables):
+        src_image = render_pfn(variables, key_0, key_1, random_rays)
+        sc_loss, f_image = semantic_loss(clip_model, src_image, target_embedding)
+        return sc_loss * FLAGS.sc_loss_mult, f_image
+    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target)
+    return sc_loss, grad, src_image
+
+@partial(jax.jit, static_argnums=[0, 1])
+def semantic_step_single(model, clip_model, rng, state, batch, lr):
+    batch = jax.tree_map(lambda x: x.astype(jnp.float16), batch)
+    # the batch is without shard
+    random_rays = batch["random_rays"]
+    rng, key_0, key_1 = random.split(rng, 3)
+
+    def semantic_loss(variables):
+        c_image, f_image = model.apply(variables, key_0, key_1, random_rays, False, rgb_only=True)
+        # reshape flat pixel to an image (assume 3 channels & square shape)
+        w = int(math.sqrt(f_image.shape[0]))
+        # c_image = c_image.reshape([w, w, 3])
+        f_image = f_image.reshape([w, w, 3])
+
+        src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.expand_dims(f_image,0).transpose(0, 3, 1, 2)))
+        # src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
+        src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
+        target_embedding = batch["embedding"]
+        sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding)**2)
+        return sc_loss * FLAGS.sc_loss_mult, f_image
+    (sc_loss, src_image), grad = jax.value_and_grad(semantic_loss, has_aux=True)(jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target)
+    return sc_loss, grad, src_image
+
+def trans_t(t):
+    return jnp.array([
+        [1, 0, 0, 0],
+        [0, 1, 0, 0],
+        [0, 0, 1, t],
+        [0, 0, 0, 1]], dtype=jnp.float32)
+
+def rot_phi(phi):
+    return jnp.array([
+        [1, 0, 0, 0],
+        [0, jnp.cos(phi), jnp.sin(phi), 0],
+        [0, -jnp.sin(phi), jnp.cos(phi), 0],
+        [0, 0, 0, 1]], dtype=jnp.float32)
+
+def rot_theta(th):
+    return jnp.array([
+        [jnp.cos(th), 0, -jnp.sin(th), 0],
+        [0, 1, 0, 0],
+        [jnp.sin(th), 0, jnp.cos(th), 0],
+        [0, 0, 0, 1]], dtype=jnp.float32)
+
+def pose_spherical(radius, theta, phi):
+    c2w = trans_t(radius)
+    c2w = rot_phi(phi) @ c2w
+    c2w = rot_theta(theta) @ c2w
+    c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
+    return c2w
+
+def random_pose(rng, bds):
+    rng, *rng_inputs = jax.random.split(rng, 3)
+    radius = random.uniform(rng_inputs[1], minval=bds[0], maxval=bds[1])
+    theta = random.uniform(rng_inputs[1], minval=-jnp.pi, maxval=jnp.pi)
+    phi = random.uniform(rng_inputs[1], minval=0, maxval=jnp.pi/2)
+    return pose_spherical(radius, theta, phi)
+
+def preprocess_for_CLIP(image):
+    """
+    jax-based preprocessing for CLIP
+    image [B, 3, H, W]: batch image
+    return [B, 3, 224, 224]: pre-processed image for CLIP
+    """
+    B, D, H, W = image.shape
+    mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
+    std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
+    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')  # assume that images have rectangle shape.
+    image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
+    return image
+
+def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
+    if dtype == 'float16':
+        dtype = jnp.float16
+    elif dtype == 'float32':
+        dtype = jnp.float32
+    else:
+        raise ValueError
+
+    if model_name is None:
+        model_name = 'openai/clip-vit-base-patch32'
+
+    return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
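Editor's note: a usage sketch (not part of this commit) showing how the helpers above can be combined, mirroring what datasets.Blender._load_renderings below does when it pre-computes a unit-normalized CLIP target embedding for a training view. The random image is a stand-in for a loaded view, and downloading the CLIP weights is assumed to succeed.

import jax.numpy as jnp
import numpy as np
from nerf import clip_utils

clip_model = clip_utils.init_CLIP("float32", "openai/clip-vit-base-patch32")
image = np.random.uniform(size=(256, 256, 3)).astype(np.float32)            # stand-in for a loaded view
pixels = clip_utils.preprocess_for_CLIP(np.transpose(image, (2, 0, 1))[None])  # [1, 3, 224, 224]
emb = clip_model.get_image_features(pixel_values=pixels)
emb = emb / jnp.linalg.norm(emb, axis=-1, keepdims=True)                     # unit-norm target embedding
print(emb.shape)  # (1, 512) for the ViT-B/32 model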
nerf/datasets.py
ADDED
@@ -0,0 +1,558 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Different datasets implementation plus a general port for all the datasets."""
+INTERNAL = False  # pylint: disable=g-statement-before-imports
+import json
+import os, time
+from os import path
+import queue
+import threading
+
+if not INTERNAL:
+  import cv2  # pylint: disable=g-import-not-at-top
+import jax
+import numpy as np
+from PIL import Image
+
+from nerf import utils
+from nerf import clip_utils
+
+def get_dataset(split, args, clip_model=None):
+  return dataset_dict[args.dataset](split, args, clip_model)
+
+
+def convert_to_ndc(origins, directions, focal, w, h, near=1.):
+  """Convert a set of rays to NDC coordinates."""
+  # Shift ray origins to near plane
+  t = -(near + origins[..., 2]) / directions[..., 2]
+  origins = origins + t[..., None] * directions
+
+  dx, dy, dz = tuple(np.moveaxis(directions, -1, 0))
+  ox, oy, oz = tuple(np.moveaxis(origins, -1, 0))
+
+  # Projection
+  o0 = -((2 * focal) / w) * (ox / oz)
+  o1 = -((2 * focal) / h) * (oy / oz)
+  o2 = 1 + 2 * near / oz
+
+  d0 = -((2 * focal) / w) * (dx / dz - ox / oz)
+  d1 = -((2 * focal) / h) * (dy / dz - oy / oz)
+  d2 = -2 * near / oz
+
+  origins = np.stack([o0, o1, o2], -1)
+  directions = np.stack([d0, d1, d2], -1)
+  return origins, directions
+
+
+class Dataset(threading.Thread):
+  """Dataset Base Class."""
+
+  def __init__(self, split, flags, clip_model):
+    super(Dataset, self).__init__()
+    self.queue = queue.Queue(3)  # Set prefetch buffer to 3 batches.
+    self.daemon = True
+    self.use_pixel_centers = flags.use_pixel_centers
+    self.split = split
+
+    if split == "train":
+      self._train_init(flags, clip_model)
+    elif split == "test":
+      self._test_init(flags)
+    else:
+      raise ValueError(
+          "the split argument should be either \"train\" or \"test\", set"
+          "to {} here.".format(split))
+    self.batch_size = flags.batch_size // jax.process_count()
+    self.batching = flags.batching
+    self.render_path = flags.render_path
+    self.far = flags.far
+    self.near = flags.near
+    self.max_steps = flags.max_steps
+    self.start()
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):
+    """Get the next training batch or test example.
+
+    Returns:
+      batch: dict, has "pixels" and "rays".
+    """
+    x = self.queue.get()
+    if self.split == "train":
+      return utils.shard(x)
+    else:
+      return utils.to_device(x)
+
+  def peek(self):
+    """Peek at the next training batch or test example without dequeuing it.
+
+    Returns:
+      batch: dict, has "pixels" and "rays".
+    """
+    x = self.queue.queue[0].copy()  # Make a copy of the front of the queue.
+    if self.split == "train":
+      return utils.shard(x)
+    else:
+      return utils.to_device(x)
+
+  def run(self):
+    if self.split == "train":
+      next_func = self._next_train
+    else:
+      next_func = self._next_test
+    while True:
+      self.queue.put(next_func())
+
+  @property
+  def size(self):
+    return self.n_examples
+
+  def _train_init(self, flags, clip_model):
+    """Initialize training."""
+    self._load_renderings(flags, clip_model)
+    self._generate_rays()
+
+    if flags.batching == "all_images":
+      # flatten the ray and image dimension together.
+      self.images = self.images.reshape([-1, 3])
+      self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
+                                       self.rays)
+    elif flags.batching == "single_image":
+      self.images = self.images.reshape([-1, self.resolution, 3])
+      self.rays = utils.namedtuple_map(
+          lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
+    else:
+      raise NotImplementedError(
+          f"{flags.batching} batching strategy is not implemented.")
+
+  def _test_init(self, flags):
+    self._load_renderings(flags, clip_model=None)
+    self._generate_rays()
+    self.it = 0
+
+  def _next_train(self):
+    """Sample next training batch."""
+
+    if self.batching == "all_images":
+      ray_indices = np.random.randint(0, self.rays[0].shape[0],
+                                      (self.batch_size,))
+      batch_pixels = self.images[ray_indices]
+      batch_rays = utils.namedtuple_map(lambda r: r[ray_indices], self.rays)
+      raise NotImplementedError("image_index not implemented for batching=all_images")
+
+    elif self.batching == "single_image":
+      image_index = np.random.randint(0, self.n_examples, ())
+      ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
+                                      (self.batch_size,))
+      batch_pixels = self.images[image_index][ray_indices]
+      batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
+                                        self.rays)
+    else:
+      raise NotImplementedError(
+          f"{self.batching} batching strategy is not implemented.")
+    return {"pixels": batch_pixels, "rays": batch_rays, "image_index": image_index}
+
+  def _next_test(self):
+    """Sample next test example."""
+    idx = self.it
+    self.it = (self.it + 1) % self.n_examples
+
+    if self.render_path:
+      return {"rays": utils.namedtuple_map(lambda r: r[idx], self.render_rays)}
+    else:
+      return {"pixels": self.images[idx],
+              "rays": utils.namedtuple_map(lambda r: r[idx], self.rays),
+              "image_index": idx}
+
+  # TODO(bydeng): Swap this function with a more flexible camera model.
+  def _generate_rays(self):
+    """Generating rays for all images."""
+    pixel_center = 0.5 if self.use_pixel_centers else 0.0
+    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
+        np.arange(self.w, dtype=np.float32) + pixel_center,  # X-Axis (columns)
+        np.arange(self.h, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
+        indexing="xy")
+    camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
+                            -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
+                           axis=-1)
+    directions = ((camera_dirs[None, ..., None, :] *
+                   self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
+    origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
+                              directions.shape)
+    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
+    self.rays = utils.Rays(
+        origins=origins, directions=directions, viewdirs=viewdirs)
+
+  def camtoworld_matrix_to_rays(self, camtoworld, downsample=1):
+    """ render one instance of rays given a camera to world matrix (4, 4) """
+    pixel_center = 0.5 if self.use_pixel_centers else 0.0
+    # TODO @Alex: apply mesh downsampling here
+    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
+        np.arange(self.w, step=downsample, dtype=np.float32) + pixel_center,  # X-Axis (columns)
+        np.arange(self.h, step=downsample, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
+        indexing="xy")
+    camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
+                            -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
+                           axis=-1)
+    directions = (camera_dirs[..., None, :] * camtoworld[None, None, :3, :3]).sum(axis=-1)
+    origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
+    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
+    return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
+
+class Blender(Dataset):
+  """Blender Dataset."""
+
+  def _load_renderings(self, flags, clip_model=None):
+    """Load images from disk."""
+    if flags.render_path:
+      raise ValueError("render_path cannot be used for the blender dataset.")
+    cams, images, meta = self.load_files(flags.data_dir, self.split, flags.factor, flags.few_shot)
+
+    self.images = np.stack(images, axis=0)
+    if flags.white_bkgd:
+      self.images = (self.images[..., :3] * self.images[..., -1:] +
+                     (1. - self.images[..., -1:]))
+    else:
+      self.images = self.images[..., :3]
+    self.h, self.w = self.images.shape[1:3]
+    self.resolution = self.h * self.w
+    self.camtoworlds = np.stack(cams, axis=0)
+    camera_angle_x = float(meta["camera_angle_x"])
+    self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
+    self.n_examples = self.images.shape[0]
+
+    if flags.use_semantic_loss and clip_model is not None:
+      embs = []
+      for img in self.images:
+        img = np.expand_dims(np.transpose(img, [2, 0, 1]), 0)
+        emb = clip_model.get_image_features(pixel_values=clip_utils.preprocess_for_CLIP(img))
+        embs.append(emb / np.linalg.norm(emb))
+      self.embeddings = np.concatenate(embs, 0)
|
247 |
+
|
248 |
+
self.image_idx = np.arange(self.images.shape[0])
|
249 |
+
np.random.shuffle(self.image_idx)
|
250 |
+
self.image_idx = self.image_idx.tolist()
|
251 |
+
|
252 |
+
@staticmethod
|
253 |
+
def load_files(data_dir, split, factor, few_shot):
|
254 |
+
with utils.open_file(path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
|
255 |
+
meta = json.load(fp)
|
256 |
+
images = []
|
257 |
+
cams = []
|
258 |
+
|
259 |
+
frames = np.arange(len(meta["frames"]))
|
260 |
+
if few_shot > 0 and split == 'train':
|
261 |
+
np.random.seed(0)
|
262 |
+
np.random.shuffle(frames)
|
263 |
+
frames = frames[:few_shot]
|
264 |
+
|
265 |
+
# if split == 'train':
|
266 |
+
# frames = [2,5,10,40,52,53,69,78,83,85,90,94,96,97]
|
267 |
+
|
268 |
+
for i in frames:
|
269 |
+
frame = meta["frames"][i]
|
270 |
+
fname = os.path.join(data_dir, frame["file_path"] + ".png")
|
271 |
+
with utils.open_file(fname, "rb") as imgin:
|
272 |
+
image = np.array(Image.open(imgin)).astype(np.float32) / 255.
|
273 |
+
if factor == 2:
|
274 |
+
[halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]]
|
275 |
+
image = cv2.resize(image, (halfres_w, halfres_h),
|
276 |
+
interpolation=cv2.INTER_AREA)
|
277 |
+
elif factor == 4:
|
278 |
+
[halfres_h, halfres_w] = [hw // 4 for hw in image.shape[:2]]
|
279 |
+
image = cv2.resize(image, (halfres_w, halfres_h),
|
280 |
+
interpolation=cv2.INTER_AREA)
|
281 |
+
elif factor > 0:
|
282 |
+
raise ValueError("Blender dataset only supports factor=0 or 2 or 4, {} "
|
283 |
+
"set.".format(factor))
|
284 |
+
cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
|
285 |
+
images.append(image)
|
286 |
+
|
287 |
+
print(f'No. of samples: {len(frames)}')
|
288 |
+
return cams, images, meta
|
289 |
+
|
290 |
+
def _next_train(self):
|
291 |
+
batch_dict = super(Blender, self)._next_train()
|
292 |
+
if self.batching == "single_image":
|
293 |
+
image_index = batch_dict.pop("image_index")
|
294 |
+
else:
|
295 |
+
raise NotImplementedError
|
296 |
+
return batch_dict
|
297 |
+
|
298 |
+
def get_clip_data(self):
|
299 |
+
if len(self.image_idx) == 0:
|
300 |
+
self.image_idx = np.arange(self.images.shape[0])
|
301 |
+
np.random.shuffle(self.image_idx)
|
302 |
+
self.image_idx = self.image_idx.tolist()
|
303 |
+
image_index = self.image_idx.pop()
|
304 |
+
|
305 |
+
batch_dict = {}
|
306 |
+
batch_dict["embedding"] = self.embeddings[image_index]
|
307 |
+
|
308 |
+
src_seed = int(time.time())
|
309 |
+
src_rng = jax.random.PRNGKey(src_seed)
|
310 |
+
src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
|
311 |
+
random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 4)
|
312 |
+
cx = np.random.randint(80, 120)
|
313 |
+
cy = np.random.randint(80, 120)
|
314 |
+
d = 70
|
315 |
+
random_rays = jax.tree_map(lambda x: x[cy-d:cy+d,cx-d:cx+d], random_rays)
|
316 |
+
w = random_rays[0].shape[0] - random_rays[0].shape[0]%jax.local_device_count()
|
317 |
+
random_rays = jax.tree_map(lambda x: x[:w,:w].reshape(-1,3), random_rays)
|
318 |
+
batch_dict["random_rays"] = random_rays
|
319 |
+
return batch_dict
|
320 |
+
|
321 |
+
class LLFF(Dataset):
|
322 |
+
"""LLFF Dataset."""
|
323 |
+
|
324 |
+
def _load_renderings(self, flags):
|
325 |
+
"""Load images from disk."""
|
326 |
+
# Load images.
|
327 |
+
imgdir_suffix = ""
|
328 |
+
if flags.factor > 0:
|
329 |
+
imgdir_suffix = "_{}".format(flags.factor)
|
330 |
+
factor = flags.factor
|
331 |
+
else:
|
332 |
+
factor = 1
|
333 |
+
imgdir = path.join(flags.data_dir, "images" + imgdir_suffix)
|
334 |
+
if not utils.file_exists(imgdir):
|
335 |
+
raise ValueError("Image folder {} doesn't exist.".format(imgdir))
|
336 |
+
imgfiles = [
|
337 |
+
path.join(imgdir, f)
|
338 |
+
for f in sorted(utils.listdir(imgdir))
|
339 |
+
if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
|
340 |
+
]
|
341 |
+
images = []
|
342 |
+
for imgfile in imgfiles:
|
343 |
+
with utils.open_file(imgfile, "rb") as imgin:
|
344 |
+
image = np.array(Image.open(imgin), dtype=np.float32) / 255.
|
345 |
+
images.append(image)
|
346 |
+
images = np.stack(images, axis=-1)
|
347 |
+
|
348 |
+
# Load poses and bds.
|
349 |
+
with utils.open_file(path.join(flags.data_dir, "poses_bounds.npy"),
|
350 |
+
"rb") as fp:
|
351 |
+
poses_arr = np.load(fp)
|
352 |
+
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
|
353 |
+
bds = poses_arr[:, -2:].transpose([1, 0])
|
354 |
+
if poses.shape[-1] != images.shape[-1]:
|
355 |
+
raise RuntimeError("Mismatch between imgs {} and poses {}".format(
|
356 |
+
images.shape[-1], poses.shape[-1]))
|
357 |
+
|
358 |
+
# Update poses according to downsampling.
|
359 |
+
poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
|
360 |
+
poses[2, 4, :] = poses[2, 4, :] * 1. / factor
|
361 |
+
|
362 |
+
# Correct rotation matrix ordering and move variable dim to axis 0.
|
363 |
+
poses = np.concatenate(
|
364 |
+
[poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
|
365 |
+
poses = np.moveaxis(poses, -1, 0).astype(np.float32)
|
366 |
+
images = np.moveaxis(images, -1, 0)
|
367 |
+
bds = np.moveaxis(bds, -1, 0).astype(np.float32)
|
368 |
+
|
369 |
+
# Rescale according to a default bd factor.
|
370 |
+
scale = 1. / (bds.min() * .75)
|
371 |
+
poses[:, :3, 3] *= scale
|
372 |
+
bds *= scale
|
373 |
+
|
374 |
+
# Recenter poses.
|
375 |
+
poses = self._recenter_poses(poses)
|
376 |
+
|
377 |
+
# Generate a spiral/spherical ray path for rendering videos.
|
378 |
+
if flags.spherify:
|
379 |
+
poses = self._generate_spherical_poses(poses, bds)
|
380 |
+
self.spherify = True
|
381 |
+
else:
|
382 |
+
self.spherify = False
|
383 |
+
if not flags.spherify and self.split == "test":
|
384 |
+
self._generate_spiral_poses(poses, bds)
|
385 |
+
|
386 |
+
# Select the split.
|
387 |
+
i_test = np.arange(images.shape[0])[::flags.llffhold]
|
388 |
+
i_train = np.array(
|
389 |
+
[i for i in np.arange(int(images.shape[0])) if i not in i_test])
|
390 |
+
if self.split == "train":
|
391 |
+
indices = i_train
|
392 |
+
else:
|
393 |
+
indices = i_test
|
394 |
+
images = images[indices]
|
395 |
+
poses = poses[indices]
|
396 |
+
|
397 |
+
self.images = images
|
398 |
+
self.camtoworlds = poses[:, :3, :4]
|
399 |
+
self.focal = poses[0, -1, -1]
|
400 |
+
self.h, self.w = images.shape[1:3]
|
401 |
+
self.resolution = self.h * self.w
|
402 |
+
if flags.render_path:
|
403 |
+
self.n_examples = self.render_poses.shape[0]
|
404 |
+
else:
|
405 |
+
self.n_examples = images.shape[0]
|
406 |
+
|
407 |
+
def _generate_rays(self):
|
408 |
+
"""Generate normalized device coordinate rays for llff."""
|
409 |
+
if self.split == "test":
|
410 |
+
n_render_poses = self.render_poses.shape[0]
|
411 |
+
self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds],
|
412 |
+
axis=0)
|
413 |
+
|
414 |
+
super()._generate_rays()
|
415 |
+
|
416 |
+
if not self.spherify:
|
417 |
+
ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins,
|
418 |
+
self.rays.directions,
|
419 |
+
self.focal, self.w, self.h)
|
420 |
+
self.rays = utils.Rays(
|
421 |
+
origins=ndc_origins,
|
422 |
+
directions=ndc_directions,
|
423 |
+
viewdirs=self.rays.viewdirs)
|
424 |
+
|
425 |
+
# Split poses from the dataset and generated poses
|
426 |
+
if self.split == "test":
|
427 |
+
self.camtoworlds = self.camtoworlds[n_render_poses:]
|
428 |
+
split = [np.split(r, [n_render_poses], 0) for r in self.rays]
|
429 |
+
split0, split1 = zip(*split)
|
430 |
+
self.render_rays = utils.Rays(*split0)
|
431 |
+
self.rays = utils.Rays(*split1)
|
432 |
+
|
433 |
+
def _recenter_poses(self, poses):
|
434 |
+
"""Recenter poses according to the original NeRF code."""
|
435 |
+
poses_ = poses.copy()
|
436 |
+
bottom = np.reshape([0, 0, 0, 1.], [1, 4])
|
437 |
+
c2w = self._poses_avg(poses)
|
438 |
+
c2w = np.concatenate([c2w[:3, :4], bottom], -2)
|
439 |
+
bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
|
440 |
+
poses = np.concatenate([poses[:, :3, :4], bottom], -2)
|
441 |
+
poses = np.linalg.inv(c2w) @ poses
|
442 |
+
poses_[:, :3, :4] = poses[:, :3, :4]
|
443 |
+
poses = poses_
|
444 |
+
return poses
|
445 |
+
|
446 |
+
def _poses_avg(self, poses):
|
447 |
+
"""Average poses according to the original NeRF code."""
|
448 |
+
hwf = poses[0, :3, -1:]
|
449 |
+
center = poses[:, :3, 3].mean(0)
|
450 |
+
vec2 = self._normalize(poses[:, :3, 2].sum(0))
|
451 |
+
up = poses[:, :3, 1].sum(0)
|
452 |
+
c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
|
453 |
+
return c2w
|
454 |
+
|
455 |
+
def _viewmatrix(self, z, up, pos):
|
456 |
+
"""Construct lookat view matrix."""
|
457 |
+
vec2 = self._normalize(z)
|
458 |
+
vec1_avg = up
|
459 |
+
vec0 = self._normalize(np.cross(vec1_avg, vec2))
|
460 |
+
vec1 = self._normalize(np.cross(vec2, vec0))
|
461 |
+
m = np.stack([vec0, vec1, vec2, pos], 1)
|
462 |
+
return m
|
463 |
+
|
464 |
+
def _normalize(self, x):
|
465 |
+
"""Normalization helper function."""
|
466 |
+
return x / np.linalg.norm(x)
|
467 |
+
|
468 |
+
def _generate_spiral_poses(self, poses, bds):
|
469 |
+
"""Generate a spiral path for rendering."""
|
470 |
+
c2w = self._poses_avg(poses)
|
471 |
+
# Get average pose.
|
472 |
+
up = self._normalize(poses[:, :3, 1].sum(0))
|
473 |
+
# Find a reasonable "focus depth" for this dataset.
|
474 |
+
close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
|
475 |
+
dt = .75
|
476 |
+
mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
|
477 |
+
focal = mean_dz
|
478 |
+
# Get radii for spiral path.
|
479 |
+
tt = poses[:, :3, 3]
|
480 |
+
rads = np.percentile(np.abs(tt), 90, 0)
|
481 |
+
c2w_path = c2w
|
482 |
+
n_views = 120
|
483 |
+
n_rots = 2
|
484 |
+
# Generate poses for spiral path.
|
485 |
+
render_poses = []
|
486 |
+
rads = np.array(list(rads) + [1.])
|
487 |
+
hwf = c2w_path[:, 4:5]
|
488 |
+
zrate = .5
|
489 |
+
for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]:
|
490 |
+
c = np.dot(c2w[:3, :4], (np.array(
|
491 |
+
[np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads))
|
492 |
+
z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
|
493 |
+
render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
|
494 |
+
self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]
|
495 |
+
|
496 |
+
def _generate_spherical_poses(self, poses, bds):
|
497 |
+
"""Generate a 360 degree spherical path for rendering."""
|
498 |
+
# pylint: disable=g-long-lambda
|
499 |
+
p34_to_44 = lambda p: np.concatenate([
|
500 |
+
p,
|
501 |
+
np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
|
502 |
+
], 1)
|
503 |
+
rays_d = poses[:, :3, 2:3]
|
504 |
+
rays_o = poses[:, :3, 3:4]
|
505 |
+
|
506 |
+
def min_line_dist(rays_o, rays_d):
|
507 |
+
a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
|
508 |
+
b_i = -a_i @ rays_o
|
509 |
+
pt_mindist = np.squeeze(-np.linalg.inv(
|
510 |
+
(np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0))
|
511 |
+
return pt_mindist
|
512 |
+
|
513 |
+
pt_mindist = min_line_dist(rays_o, rays_d)
|
514 |
+
center = pt_mindist
|
515 |
+
up = (poses[:, :3, 3] - center).mean(0)
|
516 |
+
vec0 = self._normalize(up)
|
517 |
+
vec1 = self._normalize(np.cross([.1, .2, .3], vec0))
|
518 |
+
vec2 = self._normalize(np.cross(vec0, vec1))
|
519 |
+
pos = center
|
520 |
+
c2w = np.stack([vec1, vec2, vec0, pos], 1)
|
521 |
+
poses_reset = (
|
522 |
+
np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]))
|
523 |
+
rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
|
524 |
+
sc = 1. / rad
|
525 |
+
poses_reset[:, :3, 3] *= sc
|
526 |
+
bds *= sc
|
527 |
+
rad *= sc
|
528 |
+
centroid = np.mean(poses_reset[:, :3, 3], 0)
|
529 |
+
zh = centroid[2]
|
530 |
+
radcircle = np.sqrt(rad ** 2 - zh ** 2)
|
531 |
+
new_poses = []
|
532 |
+
|
533 |
+
for th in np.linspace(0., 2. * np.pi, 120):
|
534 |
+
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
|
535 |
+
up = np.array([0, 0, -1.])
|
536 |
+
vec2 = self._normalize(camorigin)
|
537 |
+
vec0 = self._normalize(np.cross(vec2, up))
|
538 |
+
vec1 = self._normalize(np.cross(vec2, vec0))
|
539 |
+
pos = camorigin
|
540 |
+
p = np.stack([vec0, vec1, vec2, pos], 1)
|
541 |
+
new_poses.append(p)
|
542 |
+
|
543 |
+
new_poses = np.stack(new_poses, 0)
|
544 |
+
new_poses = np.concatenate([
|
545 |
+
new_poses,
|
546 |
+
np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)
|
547 |
+
], -1)
|
548 |
+
poses_reset = np.concatenate([
|
549 |
+
poses_reset[:, :3, :4],
|
550 |
+
np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
|
551 |
+
], -1)
|
552 |
+
if self.split == "test":
|
553 |
+
self.render_poses = new_poses[:, :3, :4]
|
554 |
+
return poses_reset
|
555 |
+
|
556 |
+
|
557 |
+
dataset_dict = {"blender": Blender,
|
558 |
+
"llff": LLFF}
|
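Both dataset classes registered in `dataset_dict` follow the iterator protocol defined in `__iter__`/`__next__` above: training batches come back pre-sharded across local devices with "pixels" and "rays" keys, while test examples are placed on a single device. The sketch below (not part of the commit) shows one way such an iterator might be consumed; the constructor arguments are an assumption, since the `Dataset` base-class signature is defined earlier in this file, and only the batch structure is taken from `_next_train` above.

# Hypothetical usage sketch. The constructor arguments are assumed, not taken
# from this diff; the batch keys ("pixels", "rays") are the ones documented in
# __next__ above.
from nerf import datasets

def make_train_iterator(flags, clip_model=None):
    dataset_cls = datasets.dataset_dict[flags.dataset]   # "blender" or "llff"
    dataset = dataset_cls("train", flags, clip_model)    # assumed signature
    for _ in range(flags.max_steps):
        batch = next(dataset)        # sharded: leading axis = local devices
        yield batch["pixels"], batch["rays"]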
nerf/model_utils.py
ADDED
@@ -0,0 +1,334 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Lint as: python3
|
17 |
+
"""Helper functions/classes for model definition."""
|
18 |
+
|
19 |
+
import functools
|
20 |
+
from typing import Any, Callable
|
21 |
+
|
22 |
+
from flax import linen as nn
|
23 |
+
import jax
|
24 |
+
from jax import lax
|
25 |
+
from jax import random
|
26 |
+
import jax.numpy as jnp
|
27 |
+
|
28 |
+
|
29 |
+
class MLP(nn.Module):
|
30 |
+
"""A simple MLP."""
|
31 |
+
net_depth: int = 8 # The depth of the first part of MLP.
|
32 |
+
net_width: int = 256 # The width of the first part of MLP.
|
33 |
+
net_depth_condition: int = 1 # The depth of the second part of MLP.
|
34 |
+
net_width_condition: int = 128 # The width of the second part of MLP.
|
35 |
+
net_activation: Callable[..., Any] = nn.relu # The activation function.
|
36 |
+
skip_layer: int = 4 # The layer to add skip layers to.
|
37 |
+
num_rgb_channels: int = 3 # The number of RGB channels.
|
38 |
+
num_sigma_channels: int = 1 # The number of sigma channels.
|
39 |
+
|
40 |
+
@nn.compact
|
41 |
+
def __call__(self, x, condition=None):
|
42 |
+
"""
|
43 |
+
Evaluate the MLP.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
x: jnp.ndarray(float32), [batch, num_samples, feature], points.
|
47 |
+
condition: jnp.ndarray(float32), [batch, feature], if not None, this
|
48 |
+
variable will be part of the input to the second part of the MLP
|
49 |
+
concatenated with the output vector of the first part of the MLP. If
|
50 |
+
None, only the first part of the MLP will be used with input x. In the
|
51 |
+
original paper, this variable is the view direction.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
raw_rgb: jnp.ndarray(float32), with a shape of
|
55 |
+
[batch, num_samples, num_rgb_channels].
|
56 |
+
raw_sigma: jnp.ndarray(float32), with a shape of
|
57 |
+
[batch, num_samples, num_sigma_channels].
|
58 |
+
"""
|
59 |
+
feature_dim = x.shape[-1]
|
60 |
+
num_samples = x.shape[1]
|
61 |
+
x = x.reshape([-1, feature_dim])
|
62 |
+
dense_layer = functools.partial(
|
63 |
+
nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())
|
64 |
+
inputs = x
|
65 |
+
|
66 |
+
dtype = x.dtype
|
67 |
+
for i in range(self.net_depth):
|
68 |
+
x = dense_layer(self.net_width, dtype = dtype)(x)
|
69 |
+
x = self.net_activation(x)
|
70 |
+
if i % self.skip_layer == 0 and i > 0:
|
71 |
+
x = jnp.concatenate([x, inputs], axis=-1)
|
72 |
+
raw_sigma = dense_layer(self.num_sigma_channels, dtype = dtype)(x).reshape(
|
73 |
+
[-1, num_samples, self.num_sigma_channels])
|
74 |
+
|
75 |
+
if condition is not None:
|
76 |
+
# Output of the first part of MLP.
|
77 |
+
bottleneck = dense_layer(self.net_width, dtype = dtype)(x)
|
78 |
+
# Broadcast condition from [batch, feature] to
|
79 |
+
# [batch, num_samples, feature] since all the samples along the same ray
|
80 |
+
# have the same viewdir.
|
81 |
+
condition = jnp.tile(condition[:, None, :], (1, num_samples, 1))
|
82 |
+
# Collapse the [batch, num_samples, feature] tensor to
|
83 |
+
# [batch * num_samples, feature] so that it can be fed into nn.Dense.
|
84 |
+
condition = condition.reshape([-1, condition.shape[-1]])
|
85 |
+
x = jnp.concatenate([bottleneck, condition], axis=-1)
|
86 |
+
# Here use 1 extra layer to align with the original nerf model.
|
87 |
+
for i in range(self.net_depth_condition):
|
88 |
+
x = dense_layer(self.net_width_condition, dtype = dtype)(x)
|
89 |
+
x = self.net_activation(x)
|
90 |
+
raw_rgb = dense_layer(self.num_rgb_channels, dtype = dtype)(x).reshape(
|
91 |
+
[-1, num_samples, self.num_rgb_channels])
|
92 |
+
return raw_rgb, raw_sigma
|
93 |
+
|
94 |
+
|
95 |
+
def cast_rays(z_vals, origins, directions):
|
96 |
+
return origins[..., None, :] + z_vals[..., None] * directions[..., None, :]
|
97 |
+
|
98 |
+
|
99 |
+
def sample_along_rays(key, origins, directions, num_samples, near, far,
|
100 |
+
randomized, lindisp):
|
101 |
+
"""
|
102 |
+
Stratified sampling along the rays.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
key: jnp.ndarray, random generator key.
|
106 |
+
origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
|
107 |
+
directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
|
108 |
+
num_samples: int.
|
109 |
+
near: float, near clip.
|
110 |
+
far: float, far clip.
|
111 |
+
randomized: bool, use randomized stratified sampling.
|
112 |
+
lindisp: bool, sampling linearly in disparity rather than depth.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
z_vals: jnp.ndarray, [batch_size, num_samples], sampled z values.
|
116 |
+
points: jnp.ndarray, [batch_size, num_samples, 3], sampled points.
|
117 |
+
"""
|
118 |
+
batch_size = origins.shape[0]
|
119 |
+
|
120 |
+
dtype = origins.dtype
|
121 |
+
|
122 |
+
t_vals = jnp.linspace(0., 1., num_samples, dtype = dtype)
|
123 |
+
if lindisp:
|
124 |
+
z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
|
125 |
+
else:
|
126 |
+
z_vals = near * (1. - t_vals) + far * t_vals
|
127 |
+
|
128 |
+
if randomized:
|
129 |
+
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
130 |
+
upper = jnp.concatenate([mids, z_vals[..., -1:]], -1)
|
131 |
+
lower = jnp.concatenate([z_vals[..., :1], mids], -1)
|
132 |
+
t_rand = random.uniform(key, [batch_size, num_samples])
|
133 |
+
z_vals = lower + (upper - lower) * t_rand
|
134 |
+
else:
|
135 |
+
# Broadcast z_vals to make the returned shape consistent.
|
136 |
+
z_vals = jnp.broadcast_to(z_vals[None, ...], [batch_size, num_samples]).astype(dtype)
|
137 |
+
|
138 |
+
coords = cast_rays(z_vals, origins, directions)
|
139 |
+
return z_vals, coords
|
140 |
+
|
141 |
+
|
142 |
+
def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
|
143 |
+
"""
|
144 |
+
Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
|
145 |
+
|
146 |
+
Instead of computing [sin(x), cos(x)], we use the trig identity
|
147 |
+
cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
|
148 |
+
|
149 |
+
Args:
|
150 |
+
x: jnp.ndarray, variables to be encoded. Note that x should be in [-pi, pi].
|
151 |
+
min_deg: int, the minimum (inclusive) degree of the encoding.
|
152 |
+
max_deg: int, the maximum (exclusive) degree of the encoding.
|
153 |
+
legacy_posenc_order: bool, keep the same ordering as the original tf code.
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
encoded: jnp.ndarray, encoded variables.
|
157 |
+
"""
|
158 |
+
if min_deg == max_deg:
|
159 |
+
return x
|
160 |
+
|
161 |
+
dtype = x.dtype
|
162 |
+
|
163 |
+
scales = jnp.array([2 ** i for i in range(min_deg, max_deg)], dtype = dtype)
|
164 |
+
if legacy_posenc_order:
|
165 |
+
xb = x[..., None, :] * scales[:, None]
|
166 |
+
four_feat = jnp.reshape(
|
167 |
+
jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], -2)),
|
168 |
+
list(x.shape[:-1]) + [-1])
|
169 |
+
else:
|
170 |
+
xb = jnp.reshape((x[..., None, :] * scales[:, None]),
|
171 |
+
list(x.shape[:-1]) + [-1])
|
172 |
+
four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
|
173 |
+
return jnp.concatenate([x] + [four_feat], axis=-1)
|
174 |
+
|
175 |
+
|
176 |
+
def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd):
|
177 |
+
"""
|
178 |
+
Volumetric Rendering Function.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
|
182 |
+
sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1].
|
183 |
+
z_vals: jnp.ndarray(float32), [batch_size, num_samples].
|
184 |
+
dirs: jnp.ndarray(float32), [batch_size, 3].
|
185 |
+
white_bkgd: bool.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
comp_rgb: jnp.ndarray(float32), [batch_size, 3].
|
189 |
+
disp: jnp.ndarray(float32), [batch_size].
|
190 |
+
acc: jnp.ndarray(float32), [batch_size].
|
191 |
+
weights: jnp.ndarray(float32), [batch_size, num_samples]
|
192 |
+
"""
|
193 |
+
dtype = rgb.dtype
|
194 |
+
|
195 |
+
eps = jnp.array(1e-10, dtype = dtype)
|
196 |
+
dists = jnp.concatenate([
|
197 |
+
z_vals[..., 1:] - z_vals[..., :-1],
|
198 |
+
jnp.broadcast_to(jnp.array([1e10]),#, dtype = dtype),
|
199 |
+
z_vals[..., :1].shape)
|
200 |
+
], -1)
|
201 |
+
dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
|
202 |
+
# Note that we're quietly turning sigma from [..., 0] to [...].
|
203 |
+
alpha = 1.0 - jnp.exp(-sigma[..., 0] * dists)
|
204 |
+
accum_prod = jnp.concatenate([
|
205 |
+
jnp.ones_like(alpha[..., :1], alpha.dtype),
|
206 |
+
jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1)
|
207 |
+
],
|
208 |
+
axis=-1)
|
209 |
+
weights = alpha * accum_prod
|
210 |
+
weights = weights.astype(dtype)
|
211 |
+
|
212 |
+
comp_rgb = (weights[..., None] * rgb).sum(axis=-2)
|
213 |
+
depth = (weights * z_vals).sum(axis=-1)
|
214 |
+
acc = weights.sum(axis=-1)
|
215 |
+
# Equivalent to (but slightly more efficient and stable than):
|
216 |
+
# disp = 1 / max(eps, where(acc > eps, depth / acc, 0))
|
217 |
+
inv_eps = 1 / eps
|
218 |
+
disp = acc / depth
|
219 |
+
disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp, inv_eps)
|
220 |
+
if white_bkgd:
|
221 |
+
comp_rgb = comp_rgb + (1. - acc[..., None])
|
222 |
+
return comp_rgb, disp, acc, weights
|
223 |
+
|
224 |
+
|
225 |
+
def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
|
226 |
+
"""
|
227 |
+
Piecewise-Constant PDF sampling.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
key: jnp.ndarray(float32), [2,], random number generator.
|
231 |
+
bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
|
232 |
+
weights: jnp.ndarray(float32), [batch_size, num_bins].
|
233 |
+
num_samples: int, the number of samples.
|
234 |
+
randomized: bool, use randomized samples.
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
z_samples: jnp.ndarray(float32), [batch_size, num_samples].
|
238 |
+
"""
|
239 |
+
# Pad each weight vector (only if necessary) to bring its sum to `eps`. This
|
240 |
+
# avoids NaNs when the input is zeros or small, but has no effect otherwise.
|
241 |
+
dtype = bins.dtype
|
242 |
+
|
243 |
+
eps = 1e-5
|
244 |
+
weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
|
245 |
+
padding = jnp.maximum(0, eps - weight_sum)
|
246 |
+
weights += padding / weights.shape[-1]
|
247 |
+
weight_sum += padding
|
248 |
+
|
249 |
+
# Compute the PDF and CDF for each weight vector, while ensuring that the CDF
|
250 |
+
# starts with exactly 0 and ends with exactly 1.
|
251 |
+
pdf = weights / weight_sum
|
252 |
+
cdf = jnp.minimum(1, jnp.cumsum(pdf[..., :-1], axis=-1))
|
253 |
+
cdf = jnp.concatenate([
|
254 |
+
jnp.zeros(list(cdf.shape[:-1]) + [1], dtype = dtype), cdf,
|
255 |
+
jnp.ones(list(cdf.shape[:-1]) + [1], dtype = dtype)
|
256 |
+
],
|
257 |
+
axis=-1)
|
258 |
+
|
259 |
+
# Draw uniform samples.
|
260 |
+
if randomized:
|
261 |
+
# Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
|
262 |
+
u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
|
263 |
+
else:
|
264 |
+
# Match the behavior of random.uniform() by spanning [0, 1-eps].
|
265 |
+
u = jnp.linspace(0., 1. - jnp.finfo(dtype).eps, num_samples, dtype = dtype)
|
266 |
+
u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])
|
267 |
+
|
268 |
+
# Identify the location in `cdf` that corresponds to a random sample.
|
269 |
+
# The final `True` index in `mask` will be the start of the sampled interval.
|
270 |
+
mask = u[..., None, :] >= cdf[..., :, None]
|
271 |
+
|
272 |
+
def find_interval(x):
|
273 |
+
# Grab the value where `mask` switches from True to False, and vice versa.
|
274 |
+
# This approach takes advantage of the fact that `x` is sorted.
|
275 |
+
x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
|
276 |
+
x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
|
277 |
+
return x0, x1
|
278 |
+
|
279 |
+
bins_g0, bins_g1 = find_interval(bins)
|
280 |
+
cdf_g0, cdf_g1 = find_interval(cdf)
|
281 |
+
|
282 |
+
t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
|
283 |
+
samples = bins_g0 + t * (bins_g1 - bins_g0)
|
284 |
+
|
285 |
+
# Prevent gradient from backprop-ing through `samples`.
|
286 |
+
return lax.stop_gradient(samples)
|
287 |
+
|
288 |
+
|
289 |
+
def sample_pdf(key, bins, weights, origins, directions, z_vals, num_samples,
|
290 |
+
randomized):
|
291 |
+
"""
|
292 |
+
Hierarchical sampling.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
key: jnp.ndarray(float32), [2,], random number generator.
|
296 |
+
bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
|
297 |
+
weights: jnp.ndarray(float32), [batch_size, num_bins].
|
298 |
+
origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
|
299 |
+
directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
|
300 |
+
z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
|
301 |
+
num_samples: int, the number of samples.
|
302 |
+
randomized: bool, use randomized samples.
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
z_vals: jnp.ndarray(float32),
|
306 |
+
[batch_size, num_coarse_samples + num_fine_samples].
|
307 |
+
points: jnp.ndarray(float32),
|
308 |
+
[batch_size, num_coarse_samples + num_fine_samples, 3].
|
309 |
+
"""
|
310 |
+
z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,
|
311 |
+
randomized)
|
312 |
+
# Compute united z_vals and sample points
|
313 |
+
z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
|
314 |
+
coords = cast_rays(z_vals, origins, directions)
|
315 |
+
return z_vals, coords
|
316 |
+
|
317 |
+
|
318 |
+
def add_gaussian_noise(key, raw, noise_std, randomized):
|
319 |
+
"""
|
320 |
+
Adds Gaussian noise to `raw`, which can be used to regularize it.
|
321 |
+
|
322 |
+
Args:
|
323 |
+
key: jnp.ndarray(float32), [2,], random number generator.
|
324 |
+
raw: jnp.ndarray(float32), arbitrary shape.
|
325 |
+
noise_std: float, The standard deviation of the noise to be added.
|
326 |
+
randomized: bool, add noise if randomized is True.
|
327 |
+
|
328 |
+
Returns:
|
329 |
+
raw + noise: jnp.ndarray(float32), with the same shape as `raw`.
|
330 |
+
"""
|
331 |
+
if (noise_std is not None) and randomized:
|
332 |
+
return raw + random.normal(key, raw.shape, dtype=raw.dtype) * noise_std
|
333 |
+
else:
|
334 |
+
return raw
|
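As a quick sanity check of the helpers in this file, the sketch below (illustrative only, not part of the commit) runs `sample_along_rays` and `posenc` on a toy batch of rays; with `min_deg=0` and `max_deg=10` the encoding maps each 3-vector to 3 + 3 * 2 * 10 = 63 features.

# Shape-check sketch for sample_along_rays / posenc (illustrative only).
import jax
import jax.numpy as jnp
from nerf import model_utils

key = jax.random.PRNGKey(0)
origins = jnp.zeros((4, 3))                               # 4 rays at the origin
directions = jnp.tile(jnp.array([0., 0., -1.]), (4, 1))   # all looking down -z

z_vals, points = model_utils.sample_along_rays(
    key, origins, directions, num_samples=64, near=2., far=6.,
    randomized=True, lindisp=False)
print(z_vals.shape, points.shape)    # (4, 64) (4, 64, 3)

enc = model_utils.posenc(points, min_deg=0, max_deg=10)
print(enc.shape)                     # (4, 64, 63)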
nerf/models.py
ADDED
@@ -0,0 +1,264 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Lint as: python3
|
17 |
+
"""Different model implementation plus a general port for all the models."""
|
18 |
+
from typing import Any, Callable
|
19 |
+
from flax import linen as nn
|
20 |
+
from jax import random
|
21 |
+
import jax.numpy as jnp
|
22 |
+
|
23 |
+
from nerf import model_utils
|
24 |
+
from nerf import utils
|
25 |
+
|
26 |
+
|
27 |
+
def get_model(key, example_batch, args):
|
28 |
+
"""A helper function that wraps around a 'model zoo'."""
|
29 |
+
model_dict = {"nerf": construct_nerf}
|
30 |
+
return model_dict[args.model](key, example_batch, args)
|
31 |
+
|
32 |
+
|
33 |
+
class NerfModel(nn.Module):
|
34 |
+
"""Nerf NN Model with both coarse and fine MLPs."""
|
35 |
+
num_coarse_samples: int # The number of samples for the coarse nerf.
|
36 |
+
num_fine_samples: int # The number of samples for the fine nerf.
|
37 |
+
use_viewdirs: bool # If True, use viewdirs as an input.
|
38 |
+
near: float # The distance to the near plane
|
39 |
+
far: float # The distance to the far plane
|
40 |
+
noise_std: float # The std dev of noise added to raw sigma.
|
41 |
+
net_depth: int # The depth of the first part of MLP.
|
42 |
+
net_width: int # The width of the first part of MLP.
|
43 |
+
net_depth_condition: int # The depth of the second part of MLP.
|
44 |
+
net_width_condition: int # The width of the second part of MLP.
|
45 |
+
net_activation: Callable[..., Any] # MLP activation
|
46 |
+
skip_layer: int # How often to add skip connections.
|
47 |
+
num_rgb_channels: int # The number of RGB channels.
|
48 |
+
num_sigma_channels: int # The number of density channels.
|
49 |
+
white_bkgd: bool # If True, use a white background.
|
50 |
+
min_deg_point: int # The minimum degree of positional encoding for positions.
|
51 |
+
max_deg_point: int # The maximum degree of positional encoding for positions.
|
52 |
+
deg_view: int # The degree of positional encoding for viewdirs.
|
53 |
+
lindisp: bool # If True, sample linearly in disparity rather than in depth.
|
54 |
+
rgb_activation: Callable[..., Any] # Output RGB activation.
|
55 |
+
sigma_activation: Callable[..., Any] # Output sigma activation.
|
56 |
+
legacy_posenc_order: bool # Keep the same ordering as the original tf code.
|
57 |
+
|
58 |
+
@nn.compact
|
59 |
+
def __call__(self, rng_0, rng_1, rays, randomized, rgb_only = False):
|
60 |
+
"""Nerf Model.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
rng_0: jnp.ndarray, random number generator for coarse model sampling.
|
64 |
+
rng_1: jnp.ndarray, random number generator for fine model sampling.
|
65 |
+
rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
|
66 |
+
randomized: bool, use randomized stratified sampling.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]
|
70 |
+
"""
|
71 |
+
# Stratified sampling along rays
|
72 |
+
key, rng_0 = random.split(rng_0)
|
73 |
+
dtype = rays[0].dtype
|
74 |
+
|
75 |
+
z_vals, samples = model_utils.sample_along_rays(
|
76 |
+
key,
|
77 |
+
rays.origins,
|
78 |
+
rays.directions,
|
79 |
+
self.num_coarse_samples,
|
80 |
+
self.near,
|
81 |
+
self.far,
|
82 |
+
randomized,
|
83 |
+
self.lindisp,
|
84 |
+
)
|
85 |
+
|
86 |
+
samples_enc = model_utils.posenc(
|
87 |
+
samples,
|
88 |
+
self.min_deg_point,
|
89 |
+
self.max_deg_point,
|
90 |
+
self.legacy_posenc_order,
|
91 |
+
)
|
92 |
+
|
93 |
+
# Construct the "coarse" MLP.
|
94 |
+
coarse_mlp = model_utils.MLP(
|
95 |
+
net_depth=self.net_depth,
|
96 |
+
net_width=self.net_width,
|
97 |
+
net_depth_condition=self.net_depth_condition,
|
98 |
+
net_width_condition=self.net_width_condition,
|
99 |
+
net_activation=self.net_activation,
|
100 |
+
skip_layer=self.skip_layer,
|
101 |
+
num_rgb_channels=self.num_rgb_channels,
|
102 |
+
num_sigma_channels=self.num_sigma_channels)
|
103 |
+
|
104 |
+
# Point attribute predictions
|
105 |
+
if self.use_viewdirs:
|
106 |
+
viewdirs_enc = model_utils.posenc(
|
107 |
+
rays.viewdirs,
|
108 |
+
0,
|
109 |
+
self.deg_view,
|
110 |
+
self.legacy_posenc_order,
|
111 |
+
)
|
112 |
+
raw_rgb, raw_sigma = coarse_mlp(samples_enc, viewdirs_enc)
|
113 |
+
else:
|
114 |
+
viewdirs_enc = None
|
115 |
+
raw_rgb, raw_sigma = coarse_mlp(samples_enc)
|
116 |
+
# Add noises to regularize the density predictions if needed
|
117 |
+
key, rng_0 = random.split(rng_0)
|
118 |
+
raw_sigma = model_utils.add_gaussian_noise(
|
119 |
+
key,
|
120 |
+
raw_sigma,
|
121 |
+
self.noise_std,
|
122 |
+
randomized,
|
123 |
+
)
|
124 |
+
rgb = self.rgb_activation(raw_rgb)
|
125 |
+
sigma = self.sigma_activation(raw_sigma)
|
126 |
+
# Volumetric rendering.
|
127 |
+
comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
|
128 |
+
rgb,
|
129 |
+
sigma,
|
130 |
+
z_vals,
|
131 |
+
rays.directions,
|
132 |
+
white_bkgd=self.white_bkgd,
|
133 |
+
)
|
134 |
+
|
135 |
+
ret = [
|
136 |
+
(comp_rgb, disp, acc),
|
137 |
+
]
|
138 |
+
|
139 |
+
if self.num_fine_samples > 0 and not(rgb_only):
|
140 |
+
z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
141 |
+
key, rng_1 = random.split(rng_1)
|
142 |
+
|
143 |
+
z_vals, samples = model_utils.sample_pdf(
|
144 |
+
key,
|
145 |
+
z_vals_mid,
|
146 |
+
weights[..., 1:-1],
|
147 |
+
rays.origins,
|
148 |
+
rays.directions,
|
149 |
+
z_vals,
|
150 |
+
self.num_fine_samples,
|
151 |
+
randomized,
|
152 |
+
)
|
153 |
+
samples_enc = model_utils.posenc(
|
154 |
+
samples,
|
155 |
+
self.min_deg_point,
|
156 |
+
self.max_deg_point,
|
157 |
+
self.legacy_posenc_order,
|
158 |
+
)
|
159 |
+
|
160 |
+
# Construct the "fine" MLP.
|
161 |
+
fine_mlp = model_utils.MLP(
|
162 |
+
net_depth=self.net_depth,
|
163 |
+
net_width=self.net_width,
|
164 |
+
net_depth_condition=self.net_depth_condition,
|
165 |
+
net_width_condition=self.net_width_condition,
|
166 |
+
net_activation=self.net_activation,
|
167 |
+
skip_layer=self.skip_layer,
|
168 |
+
num_rgb_channels=self.num_rgb_channels,
|
169 |
+
num_sigma_channels=self.num_sigma_channels)
|
170 |
+
|
171 |
+
if self.use_viewdirs:
|
172 |
+
raw_rgb, raw_sigma = fine_mlp(samples_enc, viewdirs_enc)
|
173 |
+
else:
|
174 |
+
raw_rgb, raw_sigma = fine_mlp(samples_enc)
|
175 |
+
key, rng_1 = random.split(rng_1)
|
176 |
+
raw_sigma = model_utils.add_gaussian_noise(
|
177 |
+
key,
|
178 |
+
raw_sigma,
|
179 |
+
self.noise_std,
|
180 |
+
randomized,
|
181 |
+
)
|
182 |
+
rgb = self.rgb_activation(raw_rgb)
|
183 |
+
sigma = self.sigma_activation(raw_sigma)
|
184 |
+
|
185 |
+
comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
|
186 |
+
rgb,
|
187 |
+
sigma,
|
188 |
+
z_vals,
|
189 |
+
rays.directions,
|
190 |
+
white_bkgd=self.white_bkgd,
|
191 |
+
)
|
192 |
+
ret.append((comp_rgb, disp, acc))
|
193 |
+
if rgb_only:
|
194 |
+
#return [ret[0][0], ret[1][0]]
|
195 |
+
return [None, ret[0][0]]
|
196 |
+
return ret
|
197 |
+
|
198 |
+
def construct_nerf(key, example_batch, args):
|
199 |
+
"""Construct a Neural Radiance Field.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
key: jnp.ndarray. Random number generator.
|
203 |
+
example_batch: dict, an example of a batch of data.
|
204 |
+
args: FLAGS class. Hyperparameters of nerf.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
model: nn.Model. Nerf model with parameters.
|
208 |
+
state: flax.Module.state. Nerf model state for stateful parameters.
|
209 |
+
"""
|
210 |
+
net_activation = getattr(nn, str(args.net_activation))
|
211 |
+
rgb_activation = getattr(nn, str(args.rgb_activation))
|
212 |
+
sigma_activation = getattr(nn, str(args.sigma_activation))
|
213 |
+
|
214 |
+
# Assert that rgb_activation always produces outputs in [0, 1], and
|
215 |
+
# sigma_activation always produce non-negative outputs.
|
216 |
+
x = jnp.exp(jnp.linspace(-90, 90, 1024))
|
217 |
+
x = jnp.concatenate([-x[::-1], x], 0)
|
218 |
+
|
219 |
+
rgb = rgb_activation(x)
|
220 |
+
if jnp.any(rgb < 0) or jnp.any(rgb > 1):
|
221 |
+
raise NotImplementedError(
|
222 |
+
"Choice of rgb_activation `{}` produces colors outside of [0, 1]"
|
223 |
+
.format(args.rgb_activation))
|
224 |
+
|
225 |
+
sigma = sigma_activation(x)
|
226 |
+
if jnp.any(sigma < 0):
|
227 |
+
raise NotImplementedError(
|
228 |
+
"Choice of sigma_activation `{}` produces negative densities".format(
|
229 |
+
args.sigma_activation))
|
230 |
+
|
231 |
+
model = NerfModel(
|
232 |
+
min_deg_point=args.min_deg_point,
|
233 |
+
max_deg_point=args.max_deg_point,
|
234 |
+
deg_view=args.deg_view,
|
235 |
+
num_coarse_samples=args.num_coarse_samples,
|
236 |
+
num_fine_samples=args.num_fine_samples,
|
237 |
+
use_viewdirs=args.use_viewdirs,
|
238 |
+
near=args.near,
|
239 |
+
far=args.far,
|
240 |
+
noise_std=args.noise_std,
|
241 |
+
white_bkgd=args.white_bkgd,
|
242 |
+
net_depth=args.net_depth,
|
243 |
+
net_width=args.net_width,
|
244 |
+
net_depth_condition=args.net_depth_condition,
|
245 |
+
net_width_condition=args.net_width_condition,
|
246 |
+
skip_layer=args.skip_layer,
|
247 |
+
num_rgb_channels=args.num_rgb_channels,
|
248 |
+
num_sigma_channels=args.num_sigma_channels,
|
249 |
+
lindisp=args.lindisp,
|
250 |
+
net_activation=net_activation,
|
251 |
+
rgb_activation=rgb_activation,
|
252 |
+
sigma_activation=sigma_activation,
|
253 |
+
legacy_posenc_order=args.legacy_posenc_order)
|
254 |
+
rays = example_batch["rays"]
|
255 |
+
key1, key2, key3 = random.split(key, num=3)
|
256 |
+
|
257 |
+
init_variables = model.init(
|
258 |
+
key1,
|
259 |
+
rng_0=key2,
|
260 |
+
rng_1=key3,
|
261 |
+
rays=utils.namedtuple_map(lambda x: x[0], rays),
|
262 |
+
randomized=args.randomized)
|
263 |
+
|
264 |
+
return model, init_variables
|
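`construct_nerf` returns a flax module plus its initialized variables; rendering a batch then amounts to calling `model.apply` with two sampling keys. The helper below is a hedged sketch (the `rays` argument is assumed to be a `utils.Rays` batch shaped like the one used during init), not the repository's own rendering path.

# Sketch of applying a model built by construct_nerf (illustrative only).
import jax

def render_fine_rgb(model, variables, rays, rng, randomized=False):
    """Run coarse+fine passes on one batch of rays and return the fine RGB."""
    key_0, key_1 = jax.random.split(rng)
    ret = model.apply(variables, key_0, key_1, rays, randomized)
    # ret[0] is the coarse (rgb, disp, acc); ret[-1] is the fine pass output
    # whenever num_fine_samples > 0.
    return ret[-1][0]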
nerf/utils.py
ADDED
@@ -0,0 +1,454 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Lint as: python3
|
17 |
+
"""Utility functions."""
|
18 |
+
import collections
|
19 |
+
import os
|
20 |
+
from os import path
|
21 |
+
import pickle
|
22 |
+
from absl import flags
|
23 |
+
import flax
|
24 |
+
import jax
|
25 |
+
import jax.numpy as jnp
|
26 |
+
import jax.scipy as jsp
|
27 |
+
import numpy as np
|
28 |
+
from PIL import Image
|
29 |
+
import yaml
|
30 |
+
from nerf import datasets
|
31 |
+
|
32 |
+
BASE_DIR = ""
|
33 |
+
INTERNAL = False
|
34 |
+
|
35 |
+
|
36 |
+
@flax.struct.dataclass
|
37 |
+
class TrainState:
|
38 |
+
optimizer: flax.optim.Optimizer
|
39 |
+
|
40 |
+
|
41 |
+
@flax.struct.dataclass
|
42 |
+
class Stats:
|
43 |
+
loss: float
|
44 |
+
psnr: float
|
45 |
+
loss_c: float
|
46 |
+
psnr_c: float
|
47 |
+
weight_l2: float
|
48 |
+
|
49 |
+
|
50 |
+
Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))
|
51 |
+
|
52 |
+
|
53 |
+
def namedtuple_map(fn, tup):
|
54 |
+
"""Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
|
55 |
+
return type(tup)(*map(fn, tup))
|
56 |
+
|
57 |
+
|
58 |
+
def define_flags():
|
59 |
+
"""Define flags for both training and evaluation modes."""
|
60 |
+
flags.DEFINE_string("train_dir", None, "where to store ckpts and logs")
|
61 |
+
flags.DEFINE_string("data_dir", None, "input data directory.")
|
62 |
+
flags.DEFINE_string("config", None,
|
63 |
+
"using config files to set hyperparameters.")
|
64 |
+
|
65 |
+
# CLIP part Flags
|
66 |
+
flags.DEFINE_bool("use_semantic_loss", True,
|
67 |
+
"whether use semantic loss or not")
|
68 |
+
flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32", "model type for CLIP")
|
69 |
+
flags.DEFINE_string("clip_output_dtype", "float32",
|
70 |
+
"float32/ float16 (float16 for memory saving)")
|
71 |
+
flags.DEFINE_integer("sc_loss_every", 16,
|
72 |
+
"no. of steps to take before performing semantic loss evaluation")
|
73 |
+
flags.DEFINE_float("sc_loss_mult", 1e-3,
|
74 |
+
"weighting for semantic loss from CLIP")
|
75 |
+
|
76 |
+
# Dataset Flags
|
77 |
+
# TODO(pratuls): rename to dataset_loader and consider cleaning up
|
78 |
+
flags.DEFINE_enum("dataset", "blender",
|
79 |
+
list(k for k in datasets.dataset_dict.keys()),
|
80 |
+
"The type of dataset feed to nerf.")
|
81 |
+
flags.DEFINE_enum(
|
82 |
+
"batching", "single_image", ["single_image", "all_images"],
|
83 |
+
"source of ray sampling when collecting training batch,"
|
84 |
+
"single_image for sampling from only one image in a batch,"
|
85 |
+
"all_images for sampling from all the training images.")
|
86 |
+
flags.DEFINE_bool(
|
87 |
+
"white_bkgd", True, "using white color as default background."
|
88 |
+
"(used in the blender dataset only)")
|
89 |
+
flags.DEFINE_integer("batch_size", 1024,
|
90 |
+
"the number of rays in a mini-batch (for training).")
|
91 |
+
flags.DEFINE_integer("factor", 4,
|
92 |
+
"the downsample factor of images, 0 for no downsample.")
|
93 |
+
flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
|
94 |
+
flags.DEFINE_bool(
|
95 |
+
"render_path", False, "render generated path if set true."
|
96 |
+
"(used in the llff dataset only)")
|
97 |
+
flags.DEFINE_integer(
|
98 |
+
"llffhold", 8, "will take every 1/N images as LLFF test set."
|
99 |
+
"(used in the llff dataset only)")
|
100 |
+
flags.DEFINE_bool(
|
101 |
+
"use_pixel_centers", False,
|
102 |
+
"If True, generate rays through the center of each pixel. Note: While "
|
103 |
+
"this is the correct way to handle rays, it is not the way rays are "
|
104 |
+
"handled in the original NeRF paper. Setting this TRUE yields ~ +1 PSNR "
|
105 |
+
"compared to Vanilla NeRF.")
|
106 |
+
|
107 |
+
# Model Flags
|
108 |
+
flags.DEFINE_string("model", "nerf", "name of model to use.")
|
109 |
+
flags.DEFINE_float("near", 2., "near clip of volumetric rendering.")
|
110 |
+
flags.DEFINE_float("far", 6., "far clip of volumentric rendering.")
|
111 |
+
flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.")
|
112 |
+
flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.")
|
113 |
+
flags.DEFINE_integer("net_depth_condition", 1,
|
114 |
+
"depth of the second part of MLP.")
|
115 |
+
flags.DEFINE_integer("net_width_condition", 128,
|
116 |
+
"width of the second part of MLP.")
|
117 |
+
flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay")
|
118 |
+
flags.DEFINE_integer(
|
119 |
+
"skip_layer", 4, "add a skip connection to the output vector of every"
|
120 |
+
"skip_layer layers.")
|
121 |
+
flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
|
122 |
+
flags.DEFINE_integer("num_sigma_channels", 1,
|
123 |
+
"the number of density channels.")
|
124 |
+
flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
|
125 |
+
flags.DEFINE_integer("min_deg_point", 0,
|
126 |
+
"Minimum degree of positional encoding for points.")
|
127 |
+
flags.DEFINE_integer("max_deg_point", 10,
|
128 |
+
"Maximum degree of positional encoding for points.")
|
129 |
+
flags.DEFINE_integer("deg_view", 4,
|
130 |
+
"Degree of positional encoding for viewdirs.")
|
131 |
+
flags.DEFINE_integer(
|
132 |
+
"num_coarse_samples", 64,
|
133 |
+
"the number of samples on each ray for the coarse model.")
|
134 |
+
flags.DEFINE_integer("num_fine_samples", 128,
|
135 |
+
"the number of samples on each ray for the fine model.")
|
136 |
+
flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
|
137 |
+
flags.DEFINE_float(
|
138 |
+
"noise_std", None, "std dev of noise added to regularize sigma output."
|
139 |
+
"(used in the llff dataset only)")
|
140 |
+
flags.DEFINE_bool("lindisp", False,
|
141 |
+
"sampling linearly in disparity rather than depth.")
|
142 |
+
flags.DEFINE_string("net_activation", "relu",
|
143 |
+
"activation function used within the MLP.")
|
144 |
+
flags.DEFINE_string("rgb_activation", "sigmoid",
|
145 |
+
"activation function used to produce RGB.")
|
146 |
+
flags.DEFINE_string("sigma_activation", "relu",
|
147 |
+
"activation function used to produce density.")
|
148 |
+
flags.DEFINE_bool(
|
149 |
+
"legacy_posenc_order", False,
|
150 |
+
"If True, revert the positional encoding feature order to an older version of this codebase."
|
151 |
+
)
|
152 |
+
|
153 |
+
# Train Flags
|
154 |
+
flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
|
155 |
+
flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
|
156 |
+
flags.DEFINE_integer(
|
157 |
+
"lr_delay_steps", 0, "The number of steps at the beginning of "
|
158 |
+
"training to reduce the learning rate by lr_delay_mult")
|
159 |
+
flags.DEFINE_float(
|
160 |
+
"lr_delay_mult", 1., "A multiplier on the learning rate when the step "
|
161 |
+
"is < lr_delay_steps")
|
162 |
+
flags.DEFINE_float("grad_max_norm", 0.,
|
163 |
+
"The gradient clipping magnitude (disabled if == 0).")
|
164 |
+
flags.DEFINE_float("grad_max_val", 0.,
|
165 |
+
"The gradient clipping value (disabled if == 0).")
|
166 |
+
|
167 |
+
flags.DEFINE_integer("max_steps", 1000000,
|
168 |
+
"the number of optimization steps.")
|
169 |
+
flags.DEFINE_integer("save_every", 10000,
|
170 |
+
"the number of steps to save a checkpoint.")
|
171 |
+
flags.DEFINE_integer("print_every", 100,
|
172 |
+
"the number of steps between reports to tensorboard.")
|
173 |
+
flags.DEFINE_integer(
|
174 |
+
"render_every", 5000, "the number of steps to render a test image,"
|
175 |
+
"better to be x00 for accurate step time record.")
|
176 |
+
flags.DEFINE_integer("gc_every", 10000,
|
177 |
+
"the number of steps to run python garbage collection.")
|
178 |
+
flags.DEFINE_integer("few_shot", -1,
|
179 |
+
"the number of images.")
|
180 |
+
|
181 |
+
# Eval Flags
|
182 |
+
flags.DEFINE_bool(
|
183 |
+
"eval_once", True,
|
184 |
+
"evaluate the model only once if true, otherwise keeping evaluating new"
|
185 |
+
"checkpoints if there's any.")
|
186 |
+
flags.DEFINE_bool("save_output", True,
|
187 |
+
"save predicted images to disk if True.")
|
188 |
+
flags.DEFINE_integer(
|
189 |
+
"chunk", 1024,
|
190 |
+
"the size of chunks for evaluation inferences, set to the value that"
|
191 |
+
"fits your GPU/TPU memory.")
|
192 |
+
flags.DEFINE_bool("generate_gif_only", False,
|
193 |
+
"in eval.py, we only generate GIF file for the trained model")
|
194 |
+
|
195 |
+
|
196 |
+
def update_flags(args):
|
197 |
+
"""Update the flags in `args` with the contents of the config YAML file."""
|
198 |
+
pth = path.join(BASE_DIR, args.config + ".yaml")
|
199 |
+
with open_file(pth, "r") as fin:
|
200 |
+
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
201 |
+
# Only allow args to be updated if they already exist.
|
202 |
+
invalid_args = list(set(configs.keys()) - set(dir(args)))
|
203 |
+
if invalid_args:
|
204 |
+
raise ValueError(f"Invalid args {invalid_args} in {pth}.")
|
205 |
+
args.__dict__.update(configs)
|
206 |
+
|
207 |
+
def open_file(pth, mode="r"):
|
208 |
+
if not INTERNAL:
|
209 |
+
return open(pth, mode=mode)
|
210 |
+
|
211 |
+
|
212 |
+
def file_exists(pth):
|
213 |
+
if not INTERNAL:
|
214 |
+
return path.exists(pth)
|
215 |
+
|
216 |
+
|
217 |
+
def listdir(pth):
|
218 |
+
if not INTERNAL:
|
219 |
+
return os.listdir(pth)
|
220 |
+
|
221 |
+
|
222 |
+
def isdir(pth):
|
223 |
+
if not INTERNAL:
|
224 |
+
return path.isdir(pth)
|
225 |
+
|
226 |
+
|
227 |
+
def makedirs(pth):
|
228 |
+
if not INTERNAL:
|
229 |
+
os.makedirs(pth)
|
230 |
+
|
231 |
+
|
232 |
+
def render_image(render_fn, rays, rng, normalize_disp, chunk=8192):
  """Render all the pixels of an image (in test mode).

  Args:
    render_fn: function, jit-ed render function.
    rays: a `Rays` namedtuple, the rays to be rendered.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
  height, width = rays[0].shape[:2]
  num_rays = height * width
  rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays)
  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  host_id = jax.host_id()
  results = []
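  # Render in fixed-size chunks: each chunk is padded to a multiple of the
  # device count, split evenly across hosts, and sharded across local devices
  # before calling the pmapped render function.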
  for i in range(0, num_rays, chunk):
    # pylint: disable=cell-var-from-loop
    chunk_rays = namedtuple_map(lambda r: r[i:i + chunk], rays)
    chunk_size = chunk_rays[0].shape[0]
    rays_remaining = chunk_size % jax.device_count()
    if rays_remaining != 0:
      padding = jax.device_count() - rays_remaining
      chunk_rays = namedtuple_map(
          lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays)
    else:
      padding = 0
    # After padding the number of chunk_rays is always divisible by
    # host_count.
    rays_per_host = chunk_rays[0].shape[0] // jax.process_count()
    start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
    chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays)
    chunk_results = render_fn(key_0, key_1, chunk_rays)[-1]
    results.append([unshard(x, padding) for x in chunk_results])
    # pylint: enable=cell-var-from-loop
  rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if normalize_disp:
    disp = (disp - disp.min()) / (disp.max() - disp.min())
  return (rgb.reshape((height, width, -1)), disp.reshape(
      (height, width, -1)), acc.reshape((height, width, -1)))


def compute_psnr(mse):
  """Compute psnr value given mse (we assume the maximum pixel value is 1).

  Args:
    mse: float, mean square error of pixels.

  Returns:
    psnr: float, the psnr value.
  """
  return -10. * jnp.log(mse) / jnp.log(10.)


def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
  """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: Bool. If True, will cause the per-pixel SSIM "map" to be
      returned.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
  """
  # Construct a 1D Gaussian blur filter.
  hw = filter_size // 2
  shift = (2 * hw - filter_size + 1) / 2
  f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2
  filt = jnp.exp(-0.5 * f_i)
  filt /= jnp.sum(filt)

  # Blur in x and y (faster than the 2D convolution).
  filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
  filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

  # Vmap the blurs to the tensor size, and then compose them.
  num_dims = len(img0.shape)
  map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
  for d in map_axes:
    filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
    filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
  filt_fn = lambda z: filt_fn1(filt_fn2(z))

  mu0 = filt_fn(img0)
  mu1 = filt_fn(img1)
  mu00 = mu0 * mu0
  mu11 = mu1 * mu1
  mu01 = mu0 * mu1
  sigma00 = filt_fn(img0 ** 2) - mu00
  sigma11 = filt_fn(img1 ** 2) - mu11
  sigma01 = filt_fn(img0 * img1) - mu01

  # Clip the variances and covariances to valid values.
  # Variance must be non-negative:
  sigma00 = jnp.maximum(0., sigma00)
  sigma11 = jnp.maximum(0., sigma11)
  sigma01 = jnp.sign(sigma01) * jnp.minimum(
      jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))

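  # SSIM = ((2*mu0*mu1 + c1) * (2*sigma01 + c2)) /
  #        ((mu0^2 + mu1^2 + c1) * (sigma00 + sigma11 + c2))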
  c1 = (k1 * max_val) ** 2
  c2 = (k2 * max_val) ** 2
  numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
  denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
  ssim_map = numer / denom
  ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
  return ssim_map if return_map else ssim


def save_img(img, pth):
  """Save an image to disk.

  Args:
    img: jnp.ndarray, [height, width, channels], img will be clipped to [0, 1]
      before being saved to pth.
    pth: string, path to save the image to.
  """
  with open_file(pth, "wb") as imgout:
    Image.fromarray(np.array(
        (np.clip(img, 0., 1.) * 255.).astype(jnp.uint8))).save(imgout, "PNG")


def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
  """Continuous learning rate decay function.

  The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
  is log-linearly interpolated elsewhere (equivalent to exponential decay).
  If lr_delay_steps>0 then the learning rate will be scaled by some smooth
  function of lr_delay_mult, such that the initial learning rate is
  lr_init*lr_delay_mult at the beginning of optimization but will be eased back
  to the normal learning rate when steps>lr_delay_steps.

  Args:
    step: int, the current optimization step.
    lr_init: float, the initial learning rate.
    lr_final: float, the final learning rate.
    max_steps: int, the number of steps during optimization.
    lr_delay_steps: int, the number of steps to delay the full learning rate.
    lr_delay_mult: float, the multiplier on the rate when delaying it.

  Returns:
    lr: the learning rate for the current step 'step'.
  """
  if lr_delay_steps > 0:
    # A kind of reverse cosine decay.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.
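  # Log-linear interpolation (exponential decay):
  # lr(t) = exp((1 - t) * log(lr_init) + t * log(lr_final)), t = step / max_steps.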
  t = np.clip(step / max_steps, 0, 1)
  log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp


def shard(xs):
  """Split data into shards for multiple devices along the first dimension."""
  '''
  if 'embedding' in xs:
    xs['pixels'] = jax.tree_map(lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs['pixels'])
    xs['rays'] = jax.tree_map(lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs['rays'])
    xs['embedding'] = np.stack([xs['embedding']]*jax.local_device_count(), 0)
    xs['random_rays'] = jax.tree_map(lambda x: np.stack([x]*jax.local_device_count(), 0), xs['random_rays'])
  else:
    xs = jax.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]) if len(x.shape) != 0 else x,
        xs)

  return xs
  '''
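  # Reshape [N, ...] -> [local_device_count, N // local_device_count, ...] so
  # jax.pmap can map over the leading device axis; scalars pass through unchanged.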
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]) if len(x.shape) != 0 else x,
      xs)


def to_device(xs):
  """Transfer data to devices (GPU/TPU)."""
  return jax.tree_map(jnp.array, xs)


def unshard(x, padding=0):
  """Collect the sharded tensor to the shape before sharding."""
  y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))
  if padding > 0:
    y = y[:-padding]
  return y


def write_pickle(data, fn):
  with open(fn, 'wb') as f:
    pickle.dump(data, f)
  return None


def read_pickle(fn):
  with open(fn, 'rb') as f:
    data = pickle.load(f)
  return data
requirements.txt
ADDED
@@ -0,0 +1,14 @@
numpy>=1.16.4
jax>=0.2.6
jaxlib>=0.1.57
flax>=0.2.2
opencv-python>=4.4.0
Pillow>=7.2.0
pyyaml>=5.3.1
tensorboard>=2.4.0
tensorflow>=2.3.1
tensorflow-hub>=0.11.0
transformers==4.8.2
wandb==0.10.33
tqdm==4.61.2
# pip install git+https://github.com/deepmind/jmp # mixed precision for JAX
run.sh
ADDED
@@ -0,0 +1,33 @@
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash
set -e
set -x

virtualenv -p python3 .
source ./bin/activate

pip install -r jaxnerf/requirements.txt
pip uninstall jax
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python -m jaxnerf.train \
  --data_dir=/mnt/data/NeRF_Data/nerf_synthetic/lego \
  --train_dir=test_output \
  --max_steps=5 \
  --factor=2 \
  --batch_size=512 \
  --config=configs/orig_nerf_tpu_vm_test \
  --precompute_pkl_path /mnt/data/NeRF_Data/nerf_synthetic/lego/clip_cache_train_factor4_float32.pkl
train.py
ADDED
@@ -0,0 +1,347 @@
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training script for Nerf."""
import functools
import gc
import time
from absl import app
from absl import flags
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from jax import config
from jax import random
import jax.numpy as jnp
import numpy as np
# import wandb
from tqdm import tqdm

from nerf import datasets
from nerf import models
from nerf import utils
from nerf import clip_utils

FLAGS = flags.FLAGS

utils.define_flags()
config.parse_flags_with_absl()

# set up TPU for colab
import os
if "COLAB_TPU_ADDR" in os.environ:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
  print(f"detected device: {jax.local_devices()}")


def train_step(model, clip_model, rng, state, batch, lr, step, K):
  # TODO: make clip_grad configurable as an input.
  """One optimization step.

  Args:
    model: The linen model.
    rng: jnp.ndarray, random number generator.
    state: utils.TrainState, state of the model/optimizer.
    batch: dict, a mini-batch of data for training.
    lr: float, real-time learning rate.

  Returns:
    new_state: utils.TrainState, new training state.
    stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)].
    rng: jnp.ndarray, updated random number generator.
  """
  rng, key_0, key_1 = random.split(rng, 3)

  def loss_fn(variables):
    rays = batch["rays"]
    ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
    if len(ret) not in (1, 2):
      raise ValueError(
          "ret should contain either 1 set of output (coarse only), or 2 sets"
          " of output (coarse as ret[0] and fine as ret[1]).")
    # The main prediction is always at the end of the ret list.
    rgb, unused_disp, unused_acc = ret[-1]
    loss = ((rgb - batch["pixels"][Ellipsis, :3]) ** 2).mean()
    psnr = utils.compute_psnr(loss)
    if len(ret) > 1:
      # If there are both coarse and fine predictions, we compute the loss for
      # the coarse prediction (ret[0]) as well.
      rgb_c, unused_disp_c, unused_acc_c = ret[0]
      loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3]) ** 2).mean()
      psnr_c = utils.compute_psnr(loss_c)
    else:
      loss_c = 0.
      psnr_c = 0.

    def tree_sum_fn(fn):
      return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                       variables, initializer=0)

    weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z ** 2)) /
                 tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

    total_loss = loss + loss_c + FLAGS.weight_decay_mult * weight_l2
    stats = utils.Stats(loss=loss, psnr=psnr, loss_c=loss_c,
                        psnr_c=psnr_c, weight_l2=weight_l2)
    return total_loss, stats

  (_, stats), grad = (
      jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
  # grad = jax.lax.pmean(grad, axis_name="batch")
  stats = jax.lax.pmean(stats, axis_name="batch")

  # Clip the gradient by value.
  if FLAGS.grad_max_val > 0:
    clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val, FLAGS.grad_max_val)
    grad = jax.tree_util.tree_map(clip_fn, grad)

  # Clip the (possibly value-clipped) gradient by norm.
  if FLAGS.grad_max_norm > 0:
    grad_norm = jnp.sqrt(
        jax.tree_util.tree_reduce(
            lambda x, y: x + jnp.sum(y ** 2), grad, initializer=0))
    mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
    grad = jax.tree_util.tree_map(lambda z: mult * z, grad)

  return grad, stats, rng
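  # NOTE: the statements below are unreachable; the optimizer update is applied
  # in update_step after the per-device gradients have been combined.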
  new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
  new_state = state.replace(optimizer=new_optimizer)
  return new_state, stats, rng

def update_step(state, grad, lr):
  grad = jax.lax.pmean(grad, axis_name="batch")
  new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
  new_state = state.replace(optimizer=new_optimizer)
  return new_state

def main(unused_argv):
  # wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by different
  # hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  # setup CLIP model
  if FLAGS.use_semantic_loss:
    clip_model = clip_utils.init_CLIP(FLAGS.clip_output_dtype,
                                      FLAGS.clip_model_name)
    print(f'semantic loss ACTIVATED, CLIP is set up '
          f'(sc_loss_mult: {FLAGS.sc_loss_mult})')
  else:
    clip_model = None
    print('semantic loss DEACTIVATED, CLIP is set to None')

  dataset = datasets.get_dataset("train", FLAGS, clip_model)
  test_dataset = datasets.get_dataset("test", FLAGS, clip_model)

  # setup NeRF model
  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables
  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(
      functools.partial(train_step, model, clip_model),
      axis_name="batch",
      in_axes=(0, 0, 0, None, None, None),
      donate_argnums=(2,))
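  # In train_pstep, keys, state and batch are mapped over the per-device axis
  # (in_axes=0) while lr, step and the semantic-loss interval are broadcast.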

  update_pstep = jax.pmap(
      functools.partial(update_step,),
      axis_name="batch",
      in_axes=(0, 0, None),
      donate_argnums=(0,))

  def render_fn(variables, key_0, key_1, rays):
    return model.apply(variables, key_0, key_1, rays, FLAGS.randomized)

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch")

  def render_fn_(variables, key_0, key_1, rays):
    return model.apply(variables, key_0, key_1, rays, False, True)

  render_pfn_ = jax.pmap(
      render_fn_,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch")

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1

  # for distributed training
  state = flax.jax_utils.replicate(state)
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # Prefetch_buffer_size = 3 x batch_size
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True

  # for semantic loss update
  sc_image = None
  sc_loss = 0.

  for step, batch in tqdm(zip(range(init_step, FLAGS.max_steps + 1), pdataset)):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)

    grad, stats, keys = train_pstep(keys, state, batch, lr, step, FLAGS.sc_loss_every)
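    # train_pstep returns per-device gradients; when scheduled, they are merged
    # with the semantic gradient below, then device-averaged and applied by
    # update_pstep.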

    if step % FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
      sc_batch = dataset.get_clip_data()
      if jax.local_device_count() > 1:
        sc_loss, sc_grad, sc_image = clip_utils.semantic_step_multi(render_pfn_, clip_model, keys[0], state, sc_batch, lr)
      else:
        sc_loss, sc_grad, sc_image = clip_utils.semantic_step_single(model, clip_model, keys[0], state, sc_batch, lr)

      if jax.host_id() == 0 and step % FLAGS.print_every:
        for mlp_k, mlp in grad['params'].items():
          for layer_k, layer_g in mlp.items():
            summary_writer.scalar("%s/%s/kernel_grad" % (mlp_k, layer_k), jnp.linalg.norm(jnp.mean(layer_g['kernel'], 0)), step)
        for mlp_k, mlp in sc_grad['params'].items():
          for layer_k, layer_g in mlp.items():
            summary_writer.scalar("%s/%s/kernel_sc_grad" % (mlp_k, layer_k), jnp.linalg.norm(layer_g['kernel']), step)

      leaves, treedef = jax.tree_flatten(grad)
      sc_leaves, _ = jax.tree_flatten(sc_grad)
      grad = treedef.unflatten(g + jnp.expand_dims(sc_g, 0) for g, sc_g in zip(leaves, sc_leaves))
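      # expand_dims gives the (single-copy) semantic gradient a leading device
      # axis so it broadcasts when added to the per-device MSE gradients.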

    state = update_pstep(state, grad, lr)

    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("loss/train", stats.loss[0], step)
        summary_writer.scalar("sc_loss", sc_loss, step)
        summary_writer.scalar("psnr/train", stats.psnr[0], step)
        summary_writer.scalar("train_coarse/loss", stats.loss_c[0], step)
        summary_writer.scalar("train_coarse/psnr", stats.psnr_c[0], step)
        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg/loss", avg_loss, step)
        summary_writer.scalar("train_avg/psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " +
              # f"sc_loss={sc_loss:0.4f}, " +
              f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
                                                   state)).optimizer.target
      test_case = next(test_dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, eval_variables),
          test_case["rays"],
          keys[0],
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)

      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"]) ** 2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("psnr/test", psnr, step)
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("ssim/ssim", ssim, step)
        summary_writer.scalar("test_ssim", ssim, step)
        if sc_image is not None:
          summary_writer.image("random_ray_image", sc_image, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_target", test_case["pixels"], step)

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)


if __name__ == "__main__":
  app.run(main)
train.sh
ADDED
@@ -0,0 +1,34 @@
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#!/bin/bash
CONFIG=$1
DATA_ROOT=$2
ROOT_DIR=/tmp/jaxnerf/"$CONFIG"
if [ $CONFIG == "llff" ]
then
  SCENES="room fern leaves fortress orchids flower trex horns"
  DATA_FOLDER="nerf_llff_data"
else
  SCENES="lego chair drums ficus hotdog materials mic ship"
  DATA_FOLDER="nerf_synthetic"
fi

# launch training jobs for all scenes.
for scene in $SCENES; do
  python -m jaxnerf.train \
    --data_dir="$DATA_ROOT"/"$DATA_FOLDER"/"$scene" \
    --train_dir="$ROOT_DIR"/"$scene" \
    --config=configs/"$CONFIG"
done